diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp')
-rw-r--r-- | gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp | 703 |
1 files changed, 703 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp b/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp new file mode 100644 index 0000000000..2943117314 --- /dev/null +++ b/gfx/angle/checkout/src/compiler/translator/tree_util/FindPreciseNodes.cpp @@ -0,0 +1,703 @@ +// +// Copyright 2021 The ANGLE Project Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// FindPreciseNodes.cpp: Propagates |precise| to AST nodes. +// +// The high level algorithm is as follows. For every node that "assigns" to a precise object, +// subobject (a precise struct whose field is being assigned) or superobject (a struct with a +// precise field), two things happen: +// +// - The operation is marked precise if it's an arithmetic operation +// - The right hand side of the assignment is made precise. If only a subobject is precise, only +// the corresponding subobject of the right hand side is made precise. +// + +#include "compiler/translator/tree_util/FindPreciseNodes.h" + +#include "common/hash_utils.h" +#include "compiler/translator/Compiler.h" +#include "compiler/translator/IntermNode.h" +#include "compiler/translator/Symbol.h" +#include "compiler/translator/tree_util/IntermTraverse.h" + +namespace sh +{ + +namespace +{ + +// An access chain applied to a variable. The |precise|-ness of a node does not change when +// indexing arrays, selecting matrix columns or swizzle vectors. This access chain thus only +// includes block field selections. The access chain is used to identify the part of an object +// that is or should be |precise|. If both a.b.c and a.b are precise, only a.b is every considered. +class AccessChain +{ + public: + AccessChain() = default; + + bool operator==(const AccessChain &other) const { return mChain == other.mChain; } + + const TVariable *build(TIntermTyped *lvalue); + + const TVector<size_t> &getChain() const { return mChain; } + + void reduceChain(size_t newSize) + { + ASSERT(newSize <= mChain.size()); + mChain.resize(newSize); + } + void clear() { reduceChain(0); } + void push_back(size_t index) { mChain.push_back(index); } + void pop_front(size_t n); + void append(const AccessChain &other) + { + mChain.insert(mChain.end(), other.mChain.begin(), other.mChain.end()); + } + bool removePrefix(const AccessChain &other); + + private: + TVector<size_t> mChain; +}; + +bool IsIndexOp(TOperator op) +{ + switch (op) + { + case EOpIndexDirect: + case EOpIndexDirectStruct: + case EOpIndexDirectInterfaceBlock: + case EOpIndexIndirect: + return true; + default: + return false; + } +} + +const TVariable *AccessChain::build(TIntermTyped *lvalue) +{ + if (lvalue->getAsSwizzleNode()) + { + return build(lvalue->getAsSwizzleNode()->getOperand()); + } + if (lvalue->getAsSymbolNode()) + { + const TVariable *var = &lvalue->getAsSymbolNode()->variable(); + + // For fields of nameless interface blocks, add the field index too. + if (var->getType().getInterfaceBlock() != nullptr) + { + mChain.push_back(var->getType().getInterfaceBlockFieldIndex()); + } + + return var; + } + TIntermBinary *binary = lvalue->getAsBinaryNode(); + ASSERT(binary); + + TOperator op = binary->getOp(); + ASSERT(IsIndexOp(op)); + + const TVariable *var = build(binary->getLeft()); + + if (op == EOpIndexDirectStruct || op == EOpIndexDirectInterfaceBlock) + { + int fieldIndex = binary->getRight()->getAsConstantUnion()->getIConst(0); + mChain.push_back(fieldIndex); + } + + return var; +} + +void AccessChain::pop_front(size_t n) +{ + std::rotate(mChain.begin(), mChain.begin() + n, mChain.end()); + reduceChain(mChain.size() - n); +} + +bool AccessChain::removePrefix(const AccessChain &other) +{ + // First, make sure the common part of the two access chains match. + size_t commonSize = std::min(mChain.size(), other.mChain.size()); + + for (size_t index = 0; index < commonSize; ++index) + { + if (mChain[index] != other.mChain[index]) + { + return false; + } + } + + // Remove the common part from the access chain. If other is a deeper access chain, this access + // chain will become empty. + pop_front(commonSize); + + return true; +} + +AccessChain GetAssignmentAccessChain(TIntermOperator *node) +{ + // The assignment is either a unary or a binary node, and the lvalue is always the first child. + AccessChain lvalueAccessChain; + lvalueAccessChain.build(node->getChildNode(0)->getAsTyped()); + return lvalueAccessChain; +} + +template <typename Traverser> +void TraverseIndexNodesOnly(TIntermNode *node, Traverser *traverser) +{ + if (node->getAsSwizzleNode()) + { + node = node->getAsSwizzleNode()->getOperand(); + } + + if (node->getAsSymbolNode()) + { + return; + } + + TIntermBinary *binary = node->getAsBinaryNode(); + ASSERT(binary); + + TOperator op = binary->getOp(); + ASSERT(IsIndexOp(op)); + + if (op == EOpIndexIndirect) + { + binary->getRight()->traverse(traverser); + } + + TraverseIndexNodesOnly(binary->getLeft(), traverser); +} + +// An object, which could be a sub-object of a variable. +struct ObjectAndAccessChain +{ + const TVariable *variable; + AccessChain accessChain; +}; + +bool operator==(const ObjectAndAccessChain &a, const ObjectAndAccessChain &b) +{ + return a.variable == b.variable && a.accessChain == b.accessChain; +} + +struct ObjectAndAccessChainHash +{ + size_t operator()(const ObjectAndAccessChain &object) const + { + size_t result = angle::ComputeGenericHash(&object.variable, sizeof(object.variable)); + if (!object.accessChain.getChain().empty()) + { + result = + result ^ angle::ComputeGenericHash(object.accessChain.getChain().data(), + object.accessChain.getChain().size() * + sizeof(object.accessChain.getChain()[0])); + } + return result; + } +}; + +// A map from variables to AST nodes that modify them (i.e. nodes where IsAssignment(op)). +using VariableToAssignmentNodeMap = angle::HashMap<const TVariable *, TVector<TIntermOperator *>>; +// A set of |return| nodes from functions with a |precise| return value. +using PreciseReturnNodes = angle::HashSet<TIntermBranch *>; +// A set of precise objects that need processing, or have been processed. +using PreciseObjectSet = angle::HashSet<ObjectAndAccessChain, ObjectAndAccessChainHash>; + +struct ASTInfo +{ + // Generic information about the tree: + VariableToAssignmentNodeMap variableAssignmentNodeMap; + // Information pertaining to |precise| expressions: + PreciseReturnNodes preciseReturnNodes; + PreciseObjectSet preciseObjectsToProcess; + PreciseObjectSet preciseObjectsVisited; +}; + +int GetObjectPreciseSubChainLength(const ObjectAndAccessChain &object) +{ + const TType &type = object.variable->getType(); + + if (type.isPrecise()) + { + return 0; + } + + const TFieldListCollection *block = type.getInterfaceBlock(); + if (block == nullptr) + { + block = type.getStruct(); + } + const TVector<size_t> &accessChain = object.accessChain.getChain(); + + for (size_t length = 0; length < accessChain.size(); ++length) + { + ASSERT(block != nullptr); + + const TField *field = block->fields()[accessChain[length]]; + if (field->type()->isPrecise()) + { + return static_cast<int>(length + 1); + } + + block = field->type()->getStruct(); + } + + return -1; +} + +void AddPreciseObject(ASTInfo *info, const ObjectAndAccessChain &object) +{ + if (info->preciseObjectsVisited.count(object) > 0) + { + return; + } + + info->preciseObjectsToProcess.insert(object); + info->preciseObjectsVisited.insert(object); +} + +void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object); + +void AddObjectIfPrecise(ASTInfo *info, const ObjectAndAccessChain &object) +{ + // See if the access chain is already precise, and if so add the minimum access chain that is + // precise. + int preciseSubChainLength = GetObjectPreciseSubChainLength(object); + if (preciseSubChainLength == -1) + { + // If the access chain is not precise, see if there are any fields of it that are precise, + // and add those individually. + AddPreciseSubObjects(info, object); + return; + } + + ObjectAndAccessChain preciseObject = object; + preciseObject.accessChain.reduceChain(preciseSubChainLength); + + AddPreciseObject(info, preciseObject); +} + +void AddPreciseSubObjects(ASTInfo *info, const ObjectAndAccessChain &object) +{ + const TFieldListCollection *block = object.variable->getType().getInterfaceBlock(); + if (block == nullptr) + { + block = object.variable->getType().getStruct(); + } + const TVector<size_t> &accessChain = object.accessChain.getChain(); + + for (size_t length = 0; length < accessChain.size(); ++length) + { + block = block->fields()[accessChain[length]]->type()->getStruct(); + } + + if (block == nullptr) + { + return; + } + + for (size_t fieldIndex = 0; fieldIndex < block->fields().size(); ++fieldIndex) + { + ObjectAndAccessChain subObject = object; + subObject.accessChain.push_back(fieldIndex); + + // If the field is precise, add it as a precise subobject. Otherwise recurse. + if (block->fields()[fieldIndex]->type()->isPrecise()) + { + AddPreciseObject(info, subObject); + } + else + { + AddPreciseSubObjects(info, subObject); + } + } +} + +bool IsArithmeticOp(TOperator op) +{ + switch (op) + { + case EOpNegative: + + case EOpPostIncrement: + case EOpPostDecrement: + case EOpPreIncrement: + case EOpPreDecrement: + + case EOpAdd: + case EOpSub: + case EOpMul: + case EOpDiv: + case EOpIMod: + + case EOpVectorTimesScalar: + case EOpVectorTimesMatrix: + case EOpMatrixTimesVector: + case EOpMatrixTimesScalar: + case EOpMatrixTimesMatrix: + + case EOpAddAssign: + case EOpSubAssign: + + case EOpMulAssign: + case EOpVectorTimesMatrixAssign: + case EOpVectorTimesScalarAssign: + case EOpMatrixTimesScalarAssign: + case EOpMatrixTimesMatrixAssign: + + case EOpDivAssign: + case EOpIModAssign: + + case EOpDot: + return true; + default: + return false; + } +} + +// A traverser that gathers the following information, used to kick off processing: +// +// - For each variable, the AST nodes that modify it. +// - The set of |precise| return AST node. +// - The set of |precise| access chains assigned to. +// +class InfoGatherTraverser : public TIntermTraverser +{ + public: + InfoGatherTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {} + + bool visitUnary(Visit visit, TIntermUnary *node) override + { + // If the node is an assignment (i.e. ++ and --), store the relevant information. + if (!IsAssignment(node->getOp())) + { + return true; + } + + visitLvalue(node, node->getOperand()); + return false; + } + + bool visitBinary(Visit visit, TIntermBinary *node) override + { + if (IsAssignment(node->getOp())) + { + visitLvalue(node, node->getLeft()); + + node->getRight()->traverse(this); + + return false; + } + + return true; + } + + bool visitDeclaration(Visit visit, TIntermDeclaration *node) override + { + const TIntermSequence &sequence = *(node->getSequence()); + TIntermSymbol *symbol = sequence.front()->getAsSymbolNode(); + TIntermBinary *initNode = sequence.front()->getAsBinaryNode(); + TIntermTyped *initExpression = nullptr; + + if (symbol == nullptr) + { + ASSERT(initNode->getOp() == EOpInitialize); + + symbol = initNode->getLeft()->getAsSymbolNode(); + initExpression = initNode->getRight(); + } + + ASSERT(symbol); + ObjectAndAccessChain object = {&symbol->variable(), {}}; + AddObjectIfPrecise(mInfo, object); + + if (initExpression) + { + mInfo->variableAssignmentNodeMap[object.variable].push_back(initNode); + + // Visit the init expression, which may itself have assignments. + initExpression->traverse(this); + } + + return false; + } + + bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override + { + mCurrentFunction = node->getFunction(); + + for (size_t paramIndex = 0; paramIndex < mCurrentFunction->getParamCount(); ++paramIndex) + { + ObjectAndAccessChain param = {mCurrentFunction->getParam(paramIndex), {}}; + AddObjectIfPrecise(mInfo, param); + } + + return true; + } + + bool visitBranch(Visit visit, TIntermBranch *node) override + { + if (node->getFlowOp() == EOpReturn && node->getChildCount() == 1 && + mCurrentFunction->getReturnType().isPrecise()) + { + mInfo->preciseReturnNodes.insert(node); + } + + return true; + } + + bool visitGlobalQualifierDeclaration(Visit visit, + TIntermGlobalQualifierDeclaration *node) override + { + if (node->isPrecise()) + { + ObjectAndAccessChain preciseObject = {&node->getSymbol()->variable(), {}}; + AddPreciseObject(mInfo, preciseObject); + } + + return false; + } + + private: + void visitLvalue(TIntermOperator *assignmentNode, TIntermTyped *lvalueNode) + { + AccessChain lvalueChain; + const TVariable *lvalueBase = lvalueChain.build(lvalueNode); + mInfo->variableAssignmentNodeMap[lvalueBase].push_back(assignmentNode); + + ObjectAndAccessChain lvalue = {lvalueBase, lvalueChain}; + AddObjectIfPrecise(mInfo, lvalue); + + TraverseIndexNodesOnly(lvalueNode, this); + } + + ASTInfo *mInfo = nullptr; + const TFunction *mCurrentFunction = nullptr; +}; + +// A traverser that, given an access chain, traverses an expression and marks parts of it |precise|. +// For example, in the expression |Struct1(a, Struct2(b, c), d)|: +// +// - Given access chain [1], both |b| and |c| are marked precise. +// - Given access chain [1, 0], only |b| is marked precise. +// +// When access chain is empty, arithmetic nodes are marked |precise| and any access chains found in +// their children is recursively added for processing. +// +// The access chain given to the traverser is derived from the left hand side of an assignment, +// while the traverser is run on the right hand side. +class PropagatePreciseTraverser : public TIntermTraverser +{ + public: + PropagatePreciseTraverser(ASTInfo *info) : TIntermTraverser(true, false, false), mInfo(info) {} + + void propagatePrecise(TIntermNode *expression, const AccessChain &accessChain) + { + mCurrentAccessChain = accessChain; + expression->traverse(this); + } + + bool visitUnary(Visit visit, TIntermUnary *node) override + { + // Unary operations cannot be applied to structures. + ASSERT(mCurrentAccessChain.getChain().empty()); + + // Mark arithmetic nodes as |precise|. + if (IsArithmeticOp(node->getOp())) + { + node->setIsPrecise(); + } + + // Mark the operand itself |precise| too. + return true; + } + + bool visitBinary(Visit visit, TIntermBinary *node) override + { + if (IsIndexOp(node->getOp())) + { + // Append the remaining access chain with that of the node, and mark that as |precise|. + // For example, if we are evaluating an expression and expecting to mark the access + // chain [1, 3] as |precise|, and the node itself has access chain [0, 2] applied to + // variable V, then what ends up being |precise| is V with access chain [0, 2, 1, 3]. + AccessChain nodeAccessChain; + const TVariable *baseVariable = nodeAccessChain.build(node); + nodeAccessChain.append(mCurrentAccessChain); + + ObjectAndAccessChain preciseObject = {baseVariable, nodeAccessChain}; + AddPreciseObject(mInfo, preciseObject); + + // Visit index nodes, each of which should be considered |precise| in its entirety. + mCurrentAccessChain.clear(); + TraverseIndexNodesOnly(node, this); + + return false; + } + + if (node->getOp() == EOpComma) + { + // For expr1,expr2, consider only expr2 as that's the one whose calculation is relevant. + node->getRight()->traverse(this); + return false; + } + + // Mark arithmetic nodes as |precise|. + if (IsArithmeticOp(node->getOp())) + { + node->setIsPrecise(); + } + + if (IsAssignment(node->getOp()) || node->getOp() == EOpInitialize) + { + // If the node itself is a[...] op= expr, consider only expr as |precise|, as that's the + // one whose calculation is significant. + node->getRight()->traverse(this); + + // The indices used on the left hand side are also significant in their entirety. + mCurrentAccessChain.clear(); + TraverseIndexNodesOnly(node->getLeft(), this); + + return false; + } + + // Binary operations cannot be applied to structures. + ASSERT(mCurrentAccessChain.getChain().empty()); + + // Mark the operands themselves |precise| too. + return true; + } + + void visitSymbol(TIntermSymbol *symbol) override + { + // Mark the symbol together with the current access chain as |precise|. + ObjectAndAccessChain preciseObject = {&symbol->variable(), mCurrentAccessChain}; + AddPreciseObject(mInfo, preciseObject); + } + + bool visitAggregate(Visit visit, TIntermAggregate *node) override + { + // If this is a struct constructor and the access chain is not empty, only apply |precise| + // to the field selected by the access chain. + const TType &type = node->getType(); + const bool isStructConstructor = + node->getOp() == EOpConstruct && type.getStruct() != nullptr && !type.isArray(); + + if (!mCurrentAccessChain.getChain().empty() && isStructConstructor) + { + size_t selectedFieldIndex = mCurrentAccessChain.getChain().front(); + mCurrentAccessChain.pop_front(1); + + ASSERT(selectedFieldIndex < node->getChildCount()); + + // Visit only said field. + node->getChildNode(selectedFieldIndex)->traverse(this); + return false; + } + + // If this is an array constructor, each element is equally |precise| with the same access + // chain. Otherwise there cannot be any access chain for constructors. + if (node->getOp() == EOpConstruct) + { + ASSERT(type.isArray() || mCurrentAccessChain.getChain().empty()); + return true; + } + + // Otherwise this is a function call. The access chain is irrelevant and every (non-out) + // parameter of the function call should be considered |precise|. + mCurrentAccessChain.clear(); + + const TFunction *function = node->getFunction(); + ASSERT(function); + + for (size_t paramIndex = 0; paramIndex < function->getParamCount(); ++paramIndex) + { + if (function->getParam(paramIndex)->getType().getQualifier() != EvqParamOut) + { + node->getChildNode(paramIndex)->traverse(this); + } + } + + // Mark arithmetic nodes as |precise|. + if (IsArithmeticOp(node->getOp())) + { + node->setIsPrecise(); + } + + return false; + } + + private: + ASTInfo *mInfo = nullptr; + AccessChain mCurrentAccessChain; +}; +} // anonymous namespace + +void FindPreciseNodes(TCompiler *compiler, TIntermBlock *root) +{ + ASTInfo info; + + InfoGatherTraverser infoGather(&info); + root->traverse(&infoGather); + + PropagatePreciseTraverser propagator(&info); + + // First, get return expressions out of the way by propagating |precise|. + for (TIntermBranch *returnNode : info.preciseReturnNodes) + { + ASSERT(returnNode->getChildCount() == 1); + propagator.propagatePrecise(returnNode->getChildNode(0), {}); + } + + // Now take |precise| access chains one by one, and propagate their |precise|-ness to the right + // hand side of all assignments in which they are on the left hand side, as well as the + // arithmetic expression that assigns to them. + + while (!info.preciseObjectsToProcess.empty()) + { + // Get one |precise| object to process. + auto first = info.preciseObjectsToProcess.begin(); + const ObjectAndAccessChain toProcess = *first; + info.preciseObjectsToProcess.erase(first); + + // Propagate |precise| to every node where it's assigned to. + const TVector<TIntermOperator *> &assignmentNodes = + info.variableAssignmentNodeMap[toProcess.variable]; + for (TIntermOperator *assignmentNode : assignmentNodes) + { + AccessChain assignmentAccessChain = GetAssignmentAccessChain(assignmentNode); + + // There are two possibilities: + // + // - The assignment is to a bigger access chain than that which is being processed, in + // which case the entire right hand side is marked |precise|, + // - The assignment is to a smaller access chain, in which case only the subobject of + // the right hand side that corresponds to the remaining part of the access chain must + // be marked |precise|. + // + // For example, if processing |a.b.c| as a |precise| access chain: + // + // - If the assignment is to |a.b.c.d|, then the entire right hand side must be + // |precise|. + // - If the assignment is to |a.b|, only the |.c| part of the right hand side expression + // must be |precise|. + // - If the assignment is to |a.e|, there is nothing to do. + // + AccessChain remainingAccessChain = toProcess.accessChain; + if (!remainingAccessChain.removePrefix(assignmentAccessChain)) + { + continue; + } + + propagator.propagatePrecise(assignmentNode, remainingAccessChain); + } + } + + // The AST nodes now contain information gathered by this post-processing step, and so the tree + // must no longer be transformed. + compiler->enableValidateNoMoreTransformations(); +} + +} // namespace sh |