summaryrefslogtreecommitdiffstats
path: root/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp')
-rw-r--r--gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp703
1 files changed, 703 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp b/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp
new file mode 100644
index 0000000000..2943117314
--- /dev/null
+++ b/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp
@@ -0,0 +1,703 @@
+//
+// Copyright 2021 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// FindPreciseNodes.cpp: Propagates |precise| to AST nodes.
+//
+// The high level algorithm is as follows. For every node that "assigns" to a precise object,
+// subobject (a precise struct whose field is being assigned) or superobject (a struct with a
+// precise field), two things happen:
+//
+// - The operation is marked precise if it's an arithmetic operation
+// - The right hand side of the assignment is made precise. If only a subobject is precise, only
+// the corresponding subobject of the right hand side is made precise.
+//
+
+#include "compiler/translator/tree_util/FindPreciseNodes.h"
+
+#include "common/hash_utils.h"
+#include "compiler/translator/Compiler.h"
+#include "compiler/translator/IntermNode.h"
+#include "compiler/translator/Symbol.h"
+#include "compiler/translator/tree_util/IntermTraverse.h"
+
+namespace sh
+{
+
+namespace
+{
+
+// An access chain applied to a variable. The |precise|-ness of a node does not change when
+// indexing arrays, selecting matrix columns or swizzle vectors. This access chain thus only
+// includes block field selections. The access chain is used to identify the part of an object
+// that is or should be |precise|. If both a.b.c and a.b are precise, only a.b is every considered.
+class AccessChain
+{
+ public:
+ AccessChain() = default;
+
+ bool operator==(const AccessChain &other) const { return mChain == other.mChain; }
+
+ const TVariable *build(TIntermTyped *lvalue);
+
+ const TVector<size_t> &getChain() const { return mChain; }
+
+ void reduceChain(size_t newSize)
+ {
+ ASSERT(newSize <= mChain.size());
+ mChain.resize(newSize);
+ }
+ void clear() { reduceChain(0); }
+ void push_back(size_t index) { mChain.push_back(index); }
+ void pop_front(size_t n);
+ void append(const AccessChain &other)
+ {
+ mChain.insert(mChain.end(), other.mChain.begin(), other.mChain.end());
+ }
+ bool removePrefix(const AccessChain &other);
+
+ private:
+ TVector<size_t> mChain;
+};
+
+bool IsIndexOp(TOperator op)
+{
+ switch (op)
+ {
+ case EOpIndexDirect:
+ case EOpIndexDirectStruct:
+ case EOpIndexDirectInterfaceBlock:
+ case EOpIndexIndirect:
+ return true;
+ default:
+ return false;
+ }
+}
+
+const TVariable *AccessChain::build(TIntermTyped *lvalue)
+{
+ if (lvalue->getAsSwizzleNode())
+ {
+ return build(lvalue->getAsSwizzleNode()->getOperand());
+ }
+ if (lvalue->getAsSymbolNode())
+ {
+ const TVariable *var = &lvalue->getAsSymbolNode()->variable();
+
+ // For fields of nameless interface blocks, add the field index too.
+ if (var->getType().getInterfaceBlock() != nullptr)
+ {
+ mChain.push_back(var->getType().getInterfaceBlockFieldIndex());
+ }
+
+ return var;
+ }
+ TIntermBinary *binary = lvalue->getAsBinaryNode();
+ ASSERT(binary);
+
+ TOperator op = binary->getOp();
+ ASSERT(IsIndexOp(op));
+
+ const TVariable *var = build(binary->getLeft());
+
+ if (op == EOpIndexDirectStruct || op == EOpIndexDirectInterfaceBlock)
+ {
+ int fieldIndex = binary->getRight()->getAsConstantUnion()->getIConst(0);
+ mChain.push_back(fieldIndex);
+ }
+
+ return var;
+}
+
+void AccessChain::pop_front(size_t n)
+{
+ std::rotate(mChain.begin(), mChain.begin() + n, mChain.end());
+ reduceChain(mChain.size() - n);
+}
+
+bool AccessChain::removePrefix(const AccessChain &other)
+{
+ // First, make sure the common part of the two access chains match.
+ size_t commonSize = std::min(mChain.size(), other.mChain.size());
+
+ for (size_t index = 0; index < commonSize; ++index)
+ {
+ if (mChain[index] != other.mChain[index])
+ {
+ return false;
+ }
+ }
+
+ // Remove the common part from the access chain. If other is a deeper access chain, this access
+ // chain will become empty.
+ pop_front(commonSize);
+
+ return true;
+}
+
+AccessChain GetAssignmentAccessChain(TIntermOperator *node)
+{
+ // The assignment is either a unary or a binary node, and the lvalue is always the first child.
+ AccessChain lvalueAccessChain;
+ lvalueAccessChain.build(node->getChildNode(0)->getAsTyped());
+ return lvalueAccessChain;
+}
+
+template <typename Traverser>
+void TraverseIndexNodesOnly(TIntermNode *node, Traverser *traverser)
+{
+ if (node->getAsSwizzleNode())
+ {
+ node = node->getAsSwizzleNode()->getOperand();
+ }
+
+ if (node->getAsSymbolNode())
+ {
+ return;
+ }
+
+ TIntermBinary *binary = node->getAsBinaryNode();
+ ASSERT(binary);
+
+ TOperator op = binary->getOp();
+ ASSERT(IsIndexOp(op));
+
+ if (op == EOpIndexIndirect)
+ {
+ binary->getRight()->traverse(traverser);
+ }
+
+ TraverseIndexNodesOnly(binary->getLeft(), traverser);
+}
+
+// An object, which could be a sub-object of a variable.
+struct ObjectAndAccessChain
+{
+ const TVariable *variable;
+ AccessChain accessChain;
+};
+
+bool operator==(const ObjectAndAccessChain &a, const ObjectAndAccessChain &b)
+{
+ return a.variable == b.variable && a.accessChain == b.accessChain;
+}
+
+struct ObjectAndAccessChainHash
+{
+ size_t operator()(const ObjectAndAccessChain &object) const
+ {
+ size_t result = angle::ComputeGenericHash(&object.variable, sizeof(object.variable));
+ if (!object.accessChain.getChain().empty())
+ {
+ result =
+ result ^ angle::ComputeGenericHash(object.accessChain.getChain().data(),
+ object.accessChain.getChain().size() *
+ sizeof(object.accessChain.getChain()[0]));
+ }
+ return result;
+ }
+};
+
+// A map from variables to AST nodes that modify them (i.e. nodes where IsAssignment(op)).
+using VariableToAssignmentNodeMap = angle::HashMap<const TVariable *, TVector<TIntermOperator *>>;
+// A set of |return| nodes from functions with a |precise| return value.
+using PreciseReturnNodes = angle::HashSet<TIntermBranch *>;
+// A set of precise objects that need processing, or have been processed.
+using PreciseObjectSet = angle::HashSet<ObjectAndAccessChain, ObjectAndAccessChainHash>;
+
+struct ASTInfo
+{
+ // Generic information about the tree:
+ VariableToAssignmentNodeMap variableAssignmentNodeMap;
+ // Information pertaining to |precise| expressions:
+ PreciseReturnNodes preciseReturnNodes;
+ PreciseObjectSet preciseObjectsToProcess;
+ PreciseObjectSet preciseObjectsVisited;
+};
+
+int GetObjectPreciseSubChainLength(const ObjectAndAccessChain &object)
+{
+ const TType &type = object.variable->getType();
+
+ if (type.isPrecise())
+ {
+ return 0;
+ }
+
+ const TFieldListCollection *block = type.getInterfaceBlock();
+ if (block == nullptr)
+ {
+ block = type.getStruct();
+ }
+ const TVector<size_t> &accessChain = object.accessChain.getChain();
+
+ for (size_t length = 0; length < accessChain.size(); ++length)
+ {
+ ASSERT(block != nullptr);
+
+ const TField *field = block->fields()[accessChain[length]];
+ if (field->type()->isPrecise())
+ {
+ return static_cast<int>(length + 1);
+ }
+
+ block = field->type()->getStruct();
+ }
+
+ return -1;
+}
+
+void AddPreciseObject(ASTInfo *info, const ObjectAndAccessChain &object)
+{
+ if (info->preciseObjectsVisited.count(object) > 0)
+ {
+ return;
+ }
+
+ info->preciseObjectsToProcess.insert(object);
+ info->preciseObjectsVisited.insert(object);
+}
+
+void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object);
+
+void AddObjectIfPrecise(ASTInfo *info, const ObjectAndAccessChain &object)
+{
+ // See if the access chain is already precise, and if so add the minimum access chain that is
+ // precise.
+ int preciseSubChainLength = GetObjectPreciseSubChainLength(object);
+ if (preciseSubChainLength == -1)
+ {
+ // If the access chain is not precise, see if there are any fields of it that are precise,
+ // and add those individually.
+ AddPreciseSubObjects(info, object);
+ return;
+ }
+
+ ObjectAndAccessChain preciseObject = object;
+ preciseObject.accessChain.reduceChain(preciseSubChainLength);
+
+ AddPreciseObject(info, preciseObject);
+}
+
+void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object)
+{
+ const TFieldListCollection *block = object.variable->getType().getInterfaceBlock();
+ if (block == nullptr)
+ {
+ block = object.variable->getType().getStruct();
+ }
+ const TVector<size_t> &accessChain = object.accessChain.getChain();
+
+ for (size_t length = 0; length < accessChain.size(); ++length)
+ {
+ block = block->fields()[accessChain[length]]->type()->getStruct();
+ }
+
+ if (block == nullptr)
+ {
+ return;
+ }
+
+ for (size_t fieldIndex = 0; fieldIndex < block->fields().size(); ++fieldIndex)
+ {
+ ObjectAndAccessChain subObject = object;
+ subObject.accessChain.push_back(fieldIndex);
+
+ // If the field is precise, add it as a precise subobject. Otherwise recurse.
+ if (block->fields()[fieldIndex]->type()->isPrecise())
+ {
+ AddPreciseObject(info, subObject);
+ }
+ else
+ {
+ AddPreciseSubObjects(info, subObject);
+ }
+ }
+}
+
+bool IsArithmeticOp(TOperator op)
+{
+ switch (op)
+ {
+ case EOpNegative:
+
+ case EOpPostIncrement:
+ case EOpPostDecrement:
+ case EOpPreIncrement:
+ case EOpPreDecrement:
+
+ case EOpAdd:
+ case EOpSub:
+ case EOpMul:
+ case EOpDiv:
+ case EOpIMod:
+
+ case EOpVectorTimesScalar:
+ case EOpVectorTimesMatrix:
+ case EOpMatrixTimesVector:
+ case EOpMatrixTimesScalar:
+ case EOpMatrixTimesMatrix:
+
+ case EOpAddAssign:
+ case EOpSubAssign:
+
+ case EOpMulAssign:
+ case EOpVectorTimesMatrixAssign:
+ case EOpVectorTimesScalarAssign:
+ case EOpMatrixTimesScalarAssign:
+ case EOpMatrixTimesMatrixAssign:
+
+ case EOpDivAssign:
+ case EOpIModAssign:
+
+ case EOpDot:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// A traverser that gathers the following information, used to kick off processing:
+//
+// - For each variable, the AST nodes that modify it.
+// - The set of |precise| return AST node.
+// - The set of |precise| access chains assigned to.
+//
+class InfoGatherTraverser : public TIntermTraverser
+{
+ public:
+ InfoGatherTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {}
+
+ bool visitUnary(Visit visit, TIntermUnary *node) override
+ {
+ // If the node is an assignment (i.e. ++ and --), store the relevant information.
+ if (!IsAssignment(node->getOp()))
+ {
+ return true;
+ }
+
+ visitLvalue(node, node->getOperand());
+ return false;
+ }
+
+ bool visitBinary(Visit visit, TIntermBinary *node) override
+ {
+ if (IsAssignment(node->getOp()))
+ {
+ visitLvalue(node, node->getLeft());
+
+ node->getRight()->traverse(this);
+
+ return false;
+ }
+
+ return true;
+ }
+
+ bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
+ {
+ const TIntermSequence &sequence = *(node->getSequence());
+ TIntermSymbol *symbol = sequence.front()->getAsSymbolNode();
+ TIntermBinary *initNode = sequence.front()->getAsBinaryNode();
+ TIntermTyped *initExpression = nullptr;
+
+ if (symbol == nullptr)
+ {
+ ASSERT(initNode->getOp() == EOpInitialize);
+
+ symbol = initNode->getLeft()->getAsSymbolNode();
+ initExpression = initNode->getRight();
+ }
+
+ ASSERT(symbol);
+ ObjectAndAccessChain object = {&symbol->variable(), {}};
+ AddObjectIfPrecise(mInfo, object);
+
+ if (initExpression)
+ {
+ mInfo->variableAssignmentNodeMap[object.variable].push_back(initNode);
+
+ // Visit the init expression, which may itself have assignments.
+ initExpression->traverse(this);
+ }
+
+ return false;
+ }
+
+ bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
+ {
+ mCurrentFunction = node->getFunction();
+
+ for (size_t paramIndex = 0; paramIndex < mCurrentFunction->getParamCount(); ++paramIndex)
+ {
+ ObjectAndAccessChain param = {mCurrentFunction->getParam(paramIndex), {}};
+ AddObjectIfPrecise(mInfo, param);
+ }
+
+ return true;
+ }
+
+ bool visitBranch(Visit visit, TIntermBranch *node) override
+ {
+ if (node->getFlowOp() == EOpReturn && node->getChildCount() == 1 &&
+ mCurrentFunction->getReturnType().isPrecise())
+ {
+ mInfo->preciseReturnNodes.insert(node);
+ }
+
+ return true;
+ }
+
+ bool visitGlobalQualifierDeclaration(Visit visit,
+ TIntermGlobalQualifierDeclaration *node) override
+ {
+ if (node->isPrecise())
+ {
+ ObjectAndAccessChain preciseObject = {&node->getSymbol()->variable(), {}};
+ AddPreciseObject(mInfo, preciseObject);
+ }
+
+ return false;
+ }
+
+ private:
+ void visitLvalue(TIntermOperator *assignmentNode, TIntermTyped *lvalueNode)
+ {
+ AccessChain lvalueChain;
+ const TVariable *lvalueBase = lvalueChain.build(lvalueNode);
+ mInfo->variableAssignmentNodeMap[lvalueBase].push_back(assignmentNode);
+
+ ObjectAndAccessChain lvalue = {lvalueBase, lvalueChain};
+ AddObjectIfPrecise(mInfo, lvalue);
+
+ TraverseIndexNodesOnly(lvalueNode, this);
+ }
+
+ ASTInfo *mInfo = nullptr;
+ const TFunction *mCurrentFunction = nullptr;
+};
+
+// A traverser that, given an access chain, traverses an expression and marks parts of it |precise|.
+// For example, in the expression |Struct1(a, Struct2(b, c), d)|:
+//
+// - Given access chain [1], both |b| and |c| are marked precise.
+// - Given access chain [1, 0], only |b| is marked precise.
+//
+// When access chain is empty, arithmetic nodes are marked |precise| and any access chains found in
+// their children is recursively added for processing.
+//
+// The access chain given to the traverser is derived from the left hand side of an assignment,
+// while the traverser is run on the right hand side.
+class PropagatePreciseTraverser : public TIntermTraverser
+{
+ public:
+ PropagatePreciseTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {}
+
+ void propagatePrecise(TIntermNode *expression, const AccessChain &accessChain)
+ {
+ mCurrentAccessChain = accessChain;
+ expression->traverse(this);
+ }
+
+ bool visitUnary(Visit visit, TIntermUnary *node) override
+ {
+ // Unary operations cannot be applied to structures.
+ ASSERT(mCurrentAccessChain.getChain().empty());
+
+ // Mark arithmetic nodes as |precise|.
+ if (IsArithmeticOp(node->getOp()))
+ {
+ node->setIsPrecise();
+ }
+
+ // Mark the operand itself |precise| too.
+ return true;
+ }
+
+ bool visitBinary(Visit visit, TIntermBinary *node) override
+ {
+ if (IsIndexOp(node->getOp()))
+ {
+ // Append the remaining access chain with that of the node, and mark that as |precise|.
+ // For example, if we are evaluating an expression and expecting to mark the access
+ // chain [1, 3] as |precise|, and the node itself has access chain [0, 2] applied to
+ // variable V, then what ends up being |precise| is V with access chain [0, 2, 1, 3].
+ AccessChain nodeAccessChain;
+ const TVariable *baseVariable = nodeAccessChain.build(node);
+ nodeAccessChain.append(mCurrentAccessChain);
+
+ ObjectAndAccessChain preciseObject = {baseVariable, nodeAccessChain};
+ AddPreciseObject(mInfo, preciseObject);
+
+ // Visit index nodes, each of which should be considered |precise| in its entirety.
+ mCurrentAccessChain.clear();
+ TraverseIndexNodesOnly(node, this);
+
+ return false;
+ }
+
+ if (node->getOp() == EOpComma)
+ {
+ // For expr1,expr2, consider only expr2 as that's the one whose calculation is relevant.
+ node->getRight()->traverse(this);
+ return false;
+ }
+
+ // Mark arithmetic nodes as |precise|.
+ if (IsArithmeticOp(node->getOp()))
+ {
+ node->setIsPrecise();
+ }
+
+ if (IsAssignment(node->getOp()) || node->getOp() == EOpInitialize)
+ {
+ // If the node itself is a[...] op= expr, consider only expr as |precise|, as that's the
+ // one whose calculation is significant.
+ node->getRight()->traverse(this);
+
+ // The indices used on the left hand side are also significant in their entirety.
+ mCurrentAccessChain.clear();
+ TraverseIndexNodesOnly(node->getLeft(), this);
+
+ return false;
+ }
+
+ // Binary operations cannot be applied to structures.
+ ASSERT(mCurrentAccessChain.getChain().empty());
+
+ // Mark the operands themselves |precise| too.
+ return true;
+ }
+
+ void visitSymbol(TIntermSymbol *symbol) override
+ {
+ // Mark the symbol together with the current access chain as |precise|.
+ ObjectAndAccessChain preciseObject = {&symbol->variable(), mCurrentAccessChain};
+ AddPreciseObject(mInfo, preciseObject);
+ }
+
+ bool visitAggregate(Visit visit, TIntermAggregate *node) override
+ {
+ // If this is a struct constructor and the access chain is not empty, only apply |precise|
+ // to the field selected by the access chain.
+ const TType &type = node->getType();
+ const bool isStructConstructor =
+ node->getOp() == EOpConstruct && type.getStruct() != nullptr && !type.isArray();
+
+ if (!mCurrentAccessChain.getChain().empty() && isStructConstructor)
+ {
+ size_t selectedFieldIndex = mCurrentAccessChain.getChain().front();
+ mCurrentAccessChain.pop_front(1);
+
+ ASSERT(selectedFieldIndex < node->getChildCount());
+
+ // Visit only said field.
+ node->getChildNode(selectedFieldIndex)->traverse(this);
+ return false;
+ }
+
+ // If this is an array constructor, each element is equally |precise| with the same access
+ // chain. Otherwise there cannot be any access chain for constructors.
+ if (node->getOp() == EOpConstruct)
+ {
+ ASSERT(type.isArray() || mCurrentAccessChain.getChain().empty());
+ return true;
+ }
+
+ // Otherwise this is a function call. The access chain is irrelevant and every (non-out)
+ // parameter of the function call should be considered |precise|.
+ mCurrentAccessChain.clear();
+
+ const TFunction *function = node->getFunction();
+ ASSERT(function);
+
+ for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex)
+ {
+ if (function->getParam(paramIndex)->getType().getQualifier() != EvqParamOut)
+ {
+ node->getChildNode(paramIndex)->traverse(this);
+ }
+ }
+
+ // Mark arithmetic nodes as |precise|.
+ if (IsArithmeticOp(node->getOp()))
+ {
+ node->setIsPrecise();
+ }
+
+ return false;
+ }
+
+ private:
+ ASTInfo *mInfo = nullptr;
+ AccessChain mCurrentAccessChain;
+};
+} // anonymous namespace
+
+void FindPreciseNodes(TCompiler *compiler, TIntermBlock *root)
+{
+ ASTInfo info;
+
+ InfoGatherTraverser infoGather(&info);
+ root->traverse(&infoGather);
+
+ PropagatePreciseTraverser propagator(&info);
+
+ // First, get return expressions out of the way by propagating |precise|.
+ for (TIntermBranch *returnNode : info.preciseReturnNodes)
+ {
+ ASSERT(returnNode->getChildCount() == 1);
+ propagator.propagatePrecise(returnNode->getChildNode(0), {});
+ }
+
+ // Now take |precise| access chains one by one, and propagate their |precise|-ness to the right
+ // hand side of all assignments in which they are on the left hand side, as well as the
+ // arithmetic expression that assigns to them.
+
+ while (!info.preciseObjectsToProcess.empty())
+ {
+ // Get one |precise| object to process.
+ auto first = info.preciseObjectsToProcess.begin();
+ const ObjectAndAccessChain toProcess = *first;
+ info.preciseObjectsToProcess.erase(first);
+
+ // Propagate |precise| to every node where it's assigned to.
+ const TVector<TIntermOperator *> &assignmentNodes =
+ info.variableAssignmentNodeMap[toProcess.variable];
+ for (TIntermOperator *assignmentNode : assignmentNodes)
+ {
+ AccessChain assignmentAccessChain = GetAssignmentAccessChain(assignmentNode);
+
+ // There are two possibilities:
+ //
+ // - The assignment is to a bigger access chain than that which is being processed, in
+ // which case the entire right hand side is marked |precise|,
+ // - The assignment is to a smaller access chain, in which case only the subobject of
+ // the right hand side that corresponds to the remaining part of the access chain must
+ // be marked |precise|.
+ //
+ // For example, if processing |a.b.c| as a |precise| access chain:
+ //
+ // - If the assignment is to |a.b.c.d|, then the entire right hand side must be
+ // |precise|.
+ // - If the assignment is to |a.b|, only the |.c| part of the right hand side expression
+ // must be |precise|.
+ // - If the assignment is to |a.e|, there is nothing to do.
+ //
+ AccessChain remainingAccessChain = toProcess.accessChain;
+ if (!remainingAccessChain.removePrefix(assignmentAccessChain))
+ {
+ continue;
+ }
+
+ propagator.propagatePrecise(assignmentNode, remainingAccessChain);
+ }
+ }
+
+ // The AST nodes now contain information gathered by this post-processing step, and so the tree
+ // must no longer be transformed.
+ compiler->enableValidateNoMoreTransformations();
+}
+
+} // namespace sh