// // Copyright 2021 The ANGLE Project Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // // FindPreciseNodes.cpp: Propagates |precise| to AST nodes. // // The high level algorithm is as follows. For every node that "assigns" to a precise object, // subobject (a precise struct whose field is being assigned) or superobject (a struct with a // precise field), two things happen: // // - The operation is marked precise if it's an arithmetic operation // - The right hand side of the assignment is made precise. If only a subobject is precise, only // the corresponding subobject of the right hand side is made precise. // #include "compiler/translator/tree_util/FindPreciseNodes.h" #include "common/hash_utils.h" #include "compiler/translator/Compiler.h" #include "compiler/translator/IntermNode.h" #include "compiler/translator/Symbol.h" #include "compiler/translator/tree_util/IntermTraverse.h" namespace sh { namespace { // An access chain applied to a variable. The |precise|-ness of a node does not change when // indexing arrays, selecting matrix columns or swizzle vectors. This access chain thus only // includes block field selections. The access chain is used to identify the part of an object // that is or should be |precise|. If both a.b.c and a.b are precise, only a.b is every considered. class AccessChain { public: AccessChain() = default; bool operator==(const AccessChain &other) const { return mChain == other.mChain; } const TVariable *build(TIntermTyped *lvalue); const TVector &getChain() const { return mChain; } void reduceChain(size_t newSize) { ASSERT(newSize <= mChain.size()); mChain.resize(newSize); } void clear() { reduceChain(0); } void push_back(size_t index) { mChain.push_back(index); } void pop_front(size_t n); void append(const AccessChain &other) { mChain.insert(mChain.end(), other.mChain.begin(), other.mChain.end()); } bool removePrefix(const AccessChain &other); private: TVector mChain; }; bool IsIndexOp(TOperator op) { switch (op) { case EOpIndexDirect: case EOpIndexDirectStruct: case EOpIndexDirectInterfaceBlock: case EOpIndexIndirect: return true; default: return false; } } const TVariable *AccessChain::build(TIntermTyped *lvalue) { if (lvalue->getAsSwizzleNode()) { return build(lvalue->getAsSwizzleNode()->getOperand()); } if (lvalue->getAsSymbolNode()) { const TVariable *var = &lvalue->getAsSymbolNode()->variable(); // For fields of nameless interface blocks, add the field index too. if (var->getType().getInterfaceBlock() != nullptr) { mChain.push_back(var->getType().getInterfaceBlockFieldIndex()); } return var; } TIntermBinary *binary = lvalue->getAsBinaryNode(); ASSERT(binary); TOperator op = binary->getOp(); ASSERT(IsIndexOp(op)); const TVariable *var = build(binary->getLeft()); if (op == EOpIndexDirectStruct || op == EOpIndexDirectInterfaceBlock) { int fieldIndex = binary->getRight()->getAsConstantUnion()->getIConst(0); mChain.push_back(fieldIndex); } return var; } void AccessChain::pop_front(size_t n) { std::rotate(mChain.begin(), mChain.begin() + n, mChain.end()); reduceChain(mChain.size() - n); } bool AccessChain::removePrefix(const AccessChain &other) { // First, make sure the common part of the two access chains match. size_t commonSize = std::min(mChain.size(), other.mChain.size()); for (size_t index = 0; index < commonSize; ++index) { if (mChain[index] != other.mChain[index]) { return false; } } // Remove the common part from the access chain. If other is a deeper access chain, this access // chain will become empty. pop_front(commonSize); return true; } AccessChain GetAssignmentAccessChain(TIntermOperator *node) { // The assignment is either a unary or a binary node, and the lvalue is always the first child. AccessChain lvalueAccessChain; lvalueAccessChain.build(node->getChildNode(0)->getAsTyped()); return lvalueAccessChain; } template void TraverseIndexNodesOnly(TIntermNode *node, Traverser *traverser) { if (node->getAsSwizzleNode()) { node = node->getAsSwizzleNode()->getOperand(); } if (node->getAsSymbolNode()) { return; } TIntermBinary *binary = node->getAsBinaryNode(); ASSERT(binary); TOperator op = binary->getOp(); ASSERT(IsIndexOp(op)); if (op == EOpIndexIndirect) { binary->getRight()->traverse(traverser); } TraverseIndexNodesOnly(binary->getLeft(), traverser); } // An object, which could be a sub-object of a variable. struct ObjectAndAccessChain { const TVariable *variable; AccessChain accessChain; }; bool operator==(const ObjectAndAccessChain &a, const ObjectAndAccessChain &b) { return a.variable == b.variable && a.accessChain == b.accessChain; } struct ObjectAndAccessChainHash { size_t operator()(const ObjectAndAccessChain &object) const { size_t result = angle::ComputeGenericHash(&object.variable, sizeof(object.variable)); if (!object.accessChain.getChain().empty()) { result = result ^ angle::ComputeGenericHash(object.accessChain.getChain().data(), object.accessChain.getChain().size() * sizeof(object.accessChain.getChain()[0])); } return result; } }; // A map from variables to AST nodes that modify them (i.e. nodes where IsAssignment(op)). using VariableToAssignmentNodeMap = angle::HashMap>; // A set of |return| nodes from functions with a |precise| return value. using PreciseReturnNodes = angle::HashSet; // A set of precise objects that need processing, or have been processed. using PreciseObjectSet = angle::HashSet; struct ASTInfo { // Generic information about the tree: VariableToAssignmentNodeMap variableAssignmentNodeMap; // Information pertaining to |precise| expressions: PreciseReturnNodes preciseReturnNodes; PreciseObjectSet preciseObjectsToProcess; PreciseObjectSet preciseObjectsVisited; }; int GetObjectPreciseSubChainLength(const ObjectAndAccessChain &object) { const TType &type = object.variable->getType(); if (type.isPrecise()) { return 0; } const TFieldListCollection *block = type.getInterfaceBlock(); if (block == nullptr) { block = type.getStruct(); } const TVector &accessChain = object.accessChain.getChain(); for (size_t length = 0; length < accessChain.size(); ++length) { ASSERT(block != nullptr); const TField *field = block->fields()[accessChain[length]]; if (field->type()->isPrecise()) { return static_cast(length + 1); } block = field->type()->getStruct(); } return -1; } void AddPreciseObject(ASTInfo *info, const ObjectAndAccessChain &object) { if (info->preciseObjectsVisited.count(object) > 0) { return; } info->preciseObjectsToProcess.insert(object); info->preciseObjectsVisited.insert(object); } void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object); void AddObjectIfPrecise(ASTInfo *info, const ObjectAndAccessChain &object) { // See if the access chain is already precise, and if so add the minimum access chain that is // precise. int preciseSubChainLength = GetObjectPreciseSubChainLength(object); if (preciseSubChainLength == -1) { // If the access chain is not precise, see if there are any fields of it that are precise, // and add those individually. AddPreciseSubObjects(info, object); return; } ObjectAndAccessChain preciseObject = object; preciseObject.accessChain.reduceChain(preciseSubChainLength); AddPreciseObject(info, preciseObject); } void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object) { const TFieldListCollection *block = object.variable->getType().getInterfaceBlock(); if (block == nullptr) { block = object.variable->getType().getStruct(); } const TVector &accessChain = object.accessChain.getChain(); for (size_t length = 0; length < accessChain.size(); ++length) { block = block->fields()[accessChain[length]]->type()->getStruct(); } if (block == nullptr) { return; } for (size_t fieldIndex = 0; fieldIndex < block->fields().size(); ++fieldIndex) { ObjectAndAccessChain subObject = object; subObject.accessChain.push_back(fieldIndex); // If the field is precise, add it as a precise subobject. Otherwise recurse. if (block->fields()[fieldIndex]->type()->isPrecise()) { AddPreciseObject(info, subObject); } else { AddPreciseSubObjects(info, subObject); } } } bool IsArithmeticOp(TOperator op) { switch (op) { case EOpNegative: case EOpPostIncrement: case EOpPostDecrement: case EOpPreIncrement: case EOpPreDecrement: case EOpAdd: case EOpSub: case EOpMul: case EOpDiv: case EOpIMod: case EOpVectorTimesScalar: case EOpVectorTimesMatrix: case EOpMatrixTimesVector: case EOpMatrixTimesScalar: case EOpMatrixTimesMatrix: case EOpAddAssign: case EOpSubAssign: case EOpMulAssign: case EOpVectorTimesMatrixAssign: case EOpVectorTimesScalarAssign: case EOpMatrixTimesScalarAssign: case EOpMatrixTimesMatrixAssign: case EOpDivAssign: case EOpIModAssign: case EOpDot: return true; default: return false; } } // A traverser that gathers the following information, used to kick off processing: // // - For each variable, the AST nodes that modify it. // - The set of |precise| return AST node. // - The set of |precise| access chains assigned to. // class InfoGatherTraverser : public TIntermTraverser { public: InfoGatherTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {} bool visitUnary(Visit visit, TIntermUnary *node) override { // If the node is an assignment (i.e. ++ and --), store the relevant information. if (!IsAssignment(node->getOp())) { return true; } visitLvalue(node, node->getOperand()); return false; } bool visitBinary(Visit visit, TIntermBinary *node) override { if (IsAssignment(node->getOp())) { visitLvalue(node, node->getLeft()); node->getRight()->traverse(this); return false; } return true; } bool visitDeclaration(Visit visit, TIntermDeclaration *node) override { const TIntermSequence &sequence = *(node->getSequence()); TIntermSymbol *symbol = sequence.front()->getAsSymbolNode(); TIntermBinary *initNode = sequence.front()->getAsBinaryNode(); TIntermTyped *initExpression = nullptr; if (symbol == nullptr) { ASSERT(initNode->getOp() == EOpInitialize); symbol = initNode->getLeft()->getAsSymbolNode(); initExpression = initNode->getRight(); } ASSERT(symbol); ObjectAndAccessChain object = {&symbol->variable(), {}}; AddObjectIfPrecise(mInfo, object); if (initExpression) { mInfo->variableAssignmentNodeMap[object.variable].push_back(initNode); // Visit the init expression, which may itself have assignments. initExpression->traverse(this); } return false; } bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override { mCurrentFunction = node->getFunction(); for (size_t paramIndex = 0; paramIndex < mCurrentFunction->getParamCount(); ++paramIndex) { ObjectAndAccessChain param = {mCurrentFunction->getParam(paramIndex), {}}; AddObjectIfPrecise(mInfo, param); } return true; } bool visitBranch(Visit visit, TIntermBranch *node) override { if (node->getFlowOp() == EOpReturn && node->getChildCount() == 1 && mCurrentFunction->getReturnType().isPrecise()) { mInfo->preciseReturnNodes.insert(node); } return true; } bool visitGlobalQualifierDeclaration(Visit visit, TIntermGlobalQualifierDeclaration *node) override { if (node->isPrecise()) { ObjectAndAccessChain preciseObject = {&node->getSymbol()->variable(), {}}; AddPreciseObject(mInfo, preciseObject); } return false; } private: void visitLvalue(TIntermOperator *assignmentNode, TIntermTyped *lvalueNode) { AccessChain lvalueChain; const TVariable *lvalueBase = lvalueChain.build(lvalueNode); mInfo->variableAssignmentNodeMap[lvalueBase].push_back(assignmentNode); ObjectAndAccessChain lvalue = {lvalueBase, lvalueChain}; AddObjectIfPrecise(mInfo, lvalue); TraverseIndexNodesOnly(lvalueNode, this); } ASTInfo *mInfo = nullptr; const TFunction *mCurrentFunction = nullptr; }; // A traverser that, given an access chain, traverses an expression and marks parts of it |precise|. // For example, in the expression |Struct1(a, Struct2(b, c), d)|: // // - Given access chain [1], both |b| and |c| are marked precise. // - Given access chain [1, 0], only |b| is marked precise. // // When access chain is empty, arithmetic nodes are marked |precise| and any access chains found in // their children is recursively added for processing. // // The access chain given to the traverser is derived from the left hand side of an assignment, // while the traverser is run on the right hand side. class PropagatePreciseTraverser : public TIntermTraverser { public: PropagatePreciseTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {} void propagatePrecise(TIntermNode *expression, const AccessChain &accessChain) { mCurrentAccessChain = accessChain; expression->traverse(this); } bool visitUnary(Visit visit, TIntermUnary *node) override { // Unary operations cannot be applied to structures. ASSERT(mCurrentAccessChain.getChain().empty()); // Mark arithmetic nodes as |precise|. if (IsArithmeticOp(node->getOp())) { node->setIsPrecise(); } // Mark the operand itself |precise| too. return true; } bool visitBinary(Visit visit, TIntermBinary *node) override { if (IsIndexOp(node->getOp())) { // Append the remaining access chain with that of the node, and mark that as |precise|. // For example, if we are evaluating an expression and expecting to mark the access // chain [1, 3] as |precise|, and the node itself has access chain [0, 2] applied to // variable V, then what ends up being |precise| is V with access chain [0, 2, 1, 3]. AccessChain nodeAccessChain; const TVariable *baseVariable = nodeAccessChain.build(node); nodeAccessChain.append(mCurrentAccessChain); ObjectAndAccessChain preciseObject = {baseVariable, nodeAccessChain}; AddPreciseObject(mInfo, preciseObject); // Visit index nodes, each of which should be considered |precise| in its entirety. mCurrentAccessChain.clear(); TraverseIndexNodesOnly(node, this); return false; } if (node->getOp() == EOpComma) { // For expr1,expr2, consider only expr2 as that's the one whose calculation is relevant. node->getRight()->traverse(this); return false; } // Mark arithmetic nodes as |precise|. if (IsArithmeticOp(node->getOp())) { node->setIsPrecise(); } if (IsAssignment(node->getOp()) || node->getOp() == EOpInitialize) { // If the node itself is a[...] op= expr, consider only expr as |precise|, as that's the // one whose calculation is significant. node->getRight()->traverse(this); // The indices used on the left hand side are also significant in their entirety. mCurrentAccessChain.clear(); TraverseIndexNodesOnly(node->getLeft(), this); return false; } // Binary operations cannot be applied to structures. ASSERT(mCurrentAccessChain.getChain().empty()); // Mark the operands themselves |precise| too. return true; } void visitSymbol(TIntermSymbol *symbol) override { // Mark the symbol together with the current access chain as |precise|. ObjectAndAccessChain preciseObject = {&symbol->variable(), mCurrentAccessChain}; AddPreciseObject(mInfo, preciseObject); } bool visitAggregate(Visit visit, TIntermAggregate *node) override { // If this is a struct constructor and the access chain is not empty, only apply |precise| // to the field selected by the access chain. const TType &type = node->getType(); const bool isStructConstructor = node->getOp() == EOpConstruct && type.getStruct() != nullptr && !type.isArray(); if (!mCurrentAccessChain.getChain().empty() && isStructConstructor) { size_t selectedFieldIndex = mCurrentAccessChain.getChain().front(); mCurrentAccessChain.pop_front(1); ASSERT(selectedFieldIndex < node->getChildCount()); // Visit only said field. node->getChildNode(selectedFieldIndex)->traverse(this); return false; } // If this is an array constructor, each element is equally |precise| with the same access // chain. Otherwise there cannot be any access chain for constructors. if (node->getOp() == EOpConstruct) { ASSERT(type.isArray() || mCurrentAccessChain.getChain().empty()); return true; } // Otherwise this is a function call. The access chain is irrelevant and every (non-out) // parameter of the function call should be considered |precise|. mCurrentAccessChain.clear(); const TFunction *function = node->getFunction(); ASSERT(function); for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex) { if (function->getParam(paramIndex)->getType().getQualifier() != EvqParamOut) { node->getChildNode(paramIndex)->traverse(this); } } // Mark arithmetic nodes as |precise|. if (IsArithmeticOp(node->getOp())) { node->setIsPrecise(); } return false; } private: ASTInfo *mInfo = nullptr; AccessChain mCurrentAccessChain; }; } // anonymous namespace void FindPreciseNodes(TCompiler *compiler, TIntermBlock *root) { ASTInfo info; InfoGatherTraverser infoGather(&info); root->traverse(&infoGather); PropagatePreciseTraverser propagator(&info); // First, get return expressions out of the way by propagating |precise|. for (TIntermBranch *returnNode : info.preciseReturnNodes) { ASSERT(returnNode->getChildCount() == 1); propagator.propagatePrecise(returnNode->getChildNode(0), {}); } // Now take |precise| access chains one by one, and propagate their |precise|-ness to the right // hand side of all assignments in which they are on the left hand side, as well as the // arithmetic expression that assigns to them. while (!info.preciseObjectsToProcess.empty()) { // Get one |precise| object to process. auto first = info.preciseObjectsToProcess.begin(); const ObjectAndAccessChain toProcess = *first; info.preciseObjectsToProcess.erase(first); // Propagate |precise| to every node where it's assigned to. const TVector &assignmentNodes = info.variableAssignmentNodeMap[toProcess.variable]; for (TIntermOperator *assignmentNode : assignmentNodes) { AccessChain assignmentAccessChain = GetAssignmentAccessChain(assignmentNode); // There are two possibilities: // // - The assignment is to a bigger access chain than that which is being processed, in // which case the entire right hand side is marked |precise|, // - The assignment is to a smaller access chain, in which case only the subobject of // the right hand side that corresponds to the remaining part of the access chain must // be marked |precise|. // // For example, if processing |a.b.c| as a |precise| access chain: // // - If the assignment is to |a.b.c.d|, then the entire right hand side must be // |precise|. // - If the assignment is to |a.b|, only the |.c| part of the right hand side expression // must be |precise|. // - If the assignment is to |a.e|, there is nothing to do. // AccessChain remainingAccessChain = toProcess.accessChain; if (!remainingAccessChain.removePrefix(assignmentAccessChain)) { continue; } propagator.propagatePrecise(assignmentNode, remainingAccessChain); } } // The AST nodes now contain information gathered by this post-processing step, and so the tree // must no longer be transformed. compiler->enableValidateNoMoreTransformations(); } } // namespace sh