summaryrefslogtreecommitdiffstats
path: root/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp')
-rw-r--r--gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp597
1 files changed, 597 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp b/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp
new file mode 100644
index 0000000000..fda6de6f48
--- /dev/null
+++ b/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp
@@ -0,0 +1,597 @@
+//
+// Copyright 2002 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and
+// matrices, replacing them with calls to functions that choose which component to return or write.
+// We don't need to consider dynamic indexing in SSBO since it can be handled directly as part of
+// the offset of RWByteAddressBuffer.
+//
+
+#include "compiler/translator/tree_ops/RemoveDynamicIndexing.h"
+
+#include "compiler/translator/Compiler.h"
+#include "compiler/translator/Diagnostics.h"
+#include "compiler/translator/InfoSink.h"
+#include "compiler/translator/StaticType.h"
+#include "compiler/translator/SymbolTable.h"
+#include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
+#include "compiler/translator/tree_util/IntermNode_util.h"
+#include "compiler/translator/tree_util/IntermTraverse.h"
+
+namespace sh
+{
+
+namespace
+{
+
+// Predicate that decides whether a given indexing node should be rewritten by the traverser.
+using DynamicIndexingNodeMatcher = std::function<bool(TIntermBinary *)>;
+
+// Type shared by the "index" parameter of every generated helper (EbtInt, EbpHigh, EvqParamIn).
+// The pointer itself is const so that this file-local global cannot be re-pointed by accident.
+const TType *const kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqParamIn, 1, 1>();
+
+// Names given to the parameters of the generated helper functions.
+constexpr const ImmutableString kBaseName("base");
+constexpr const ImmutableString kIndexName("index");
+constexpr const ImmutableString kValueName("value");
+
+// Builds the name of the helper that emulates dynamic indexing of |type|.
+// The name is "dyn_index_", followed by "write_" for the write variant, followed by either
+// "mat<cols>x<rows>" for matrices or "<ivec|bvec|uvec|vec><size>" for vectors.
+std::string GetIndexFunctionName(const TType &type, bool write)
+{
+    TInfoSinkBase out;
+    out << "dyn_index_";
+    if (write)
+    {
+        out << "write_";
+    }
+
+    // Matrices are named after their dimensions only.
+    if (type.isMatrix())
+    {
+        const uint32_t cols = static_cast<uint32_t>(type.getCols());
+        const uint32_t rows = static_cast<uint32_t>(type.getRows());
+        out << "mat" << cols << "x" << rows;
+        return out.str();
+    }
+
+    // Vectors are named after their component type and size. An unexpected basic type trips
+    // UNREACHABLE() and contributes no prefix.
+    const char *vecPrefix = "";
+    switch (type.getBasicType())
+    {
+        case EbtFloat:
+            vecPrefix = "vec";
+            break;
+        case EbtInt:
+            vecPrefix = "ivec";
+            break;
+        case EbtUInt:
+            vecPrefix = "uvec";
+            break;
+        case EbtBool:
+            vecPrefix = "bvec";
+            break;
+        default:
+            UNREACHABLE();
+    }
+    out << vecPrefix;
+    out << static_cast<uint32_t>(type.getNominalSize());
+    return out.str();
+}
+
+// Creates a constant node holding the int value |i|, typed as TType(EbtInt, EbpHigh).
+TIntermConstantUnion *CreateIntConstantNode(int i)
+{
+    TConstantUnion *value = new TConstantUnion();
+    value->setIConst(i);
+
+    const TType intType(EbtInt, EbpHigh);
+    return new TIntermConstantUnion(value, intType);
+}
+
+// Returns |node| itself when its basic type is already int. Otherwise returns a constructor
+// call of type TType(EbtInt) that takes |node| as its single argument.
+TIntermTyped *EnsureSignedInt(TIntermTyped *node)
+{
+    const bool alreadySigned = (node->getBasicType() == EbtInt);
+    if (!alreadySigned)
+    {
+        TIntermSequence ctorArgs;
+        ctorArgs.push_back(node);
+        return TIntermAggregate::CreateConstructor(TType(EbtInt), &ctorArgs);
+    }
+    return node;
+}
+
+// Returns a newly allocated type describing what indexing |indexedType| yields: the column type
+// for a matrix, the component type for a vector.
+TType *GetFieldType(const TType &indexedType)
+{
+    TType *elementType = new TType(indexedType);
+    if (!indexedType.isMatrix())
+    {
+        ASSERT(indexedType.isVector());
+        elementType->toComponentType();
+    }
+    else
+    {
+        elementType->toMatrixColumnType();
+    }
+
+    // Some indexed expressions carry no precision, e.g. vec3(0) in |vec3(0)[i], i < 0|.
+    // Fall back to highp in that case.
+    const bool precisionMissing = (elementType->getPrecision() == EbpUndefined);
+    if (precisionMissing)
+    {
+        elementType->setPrecision(EbpHigh);
+    }
+    return elementType;
+}
+
+// Returns a newly allocated copy of |type| to be used for the "base" parameter of a helper:
+// highp precision, qualified inout for the write variant and in for the read variant.
+const TType *GetBaseType(const TType &type, bool write)
+{
+    TType *paramType = new TType(type);
+
+    // Always use highp, even if the indexed type is not highp. Otherwise a shader indexing both
+    // mediump and highp values could end up calling a mediump helper with a highp value. HLSL
+    // does not care about precision, but this code could in principle serve other backends.
+    paramType->setPrecision(EbpHigh);
+
+    paramType->setQualifier(write ? EvqParamInOut : EvqParamIn);
+    return paramType;
+}
+
+// Generate a read or write function for one field in a vector/matrix.
+// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
+// indices in other places.
+// Note that indices can be either int or uint. We create only int versions of the functions,
+// and convert uint indices to int at the call site.
+// read function example:
+// float dyn_index_vec2(in vec2 base, in int index)
+// {
+//     switch(index)
+//     {
+//         case (0):
+//             return base[0];
+//         case (1):
+//             return base[1];
+//         default:
+//             break;
+//     }
+//     if (index < 0)
+//         return base[0];
+//     return base[1];
+// }
+// write function example:
+// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
+// {
+//     switch(index)
+//     {
+//         case (0):
+//             base[0] = value;
+//             return;
+//         case (1):
+//             base[1] = value;
+//             return;
+//         default:
+//             break;
+//     }
+//     if (index < 0)
+//     {
+//         base[0] = value;
+//         return;
+//     }
+//     base[1] = value;
+// }
+// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
+//
+// Parameters:
+//   type        - the indexed vector or matrix type; must not be an array.
+//   write       - true to generate the write variant, false for the read variant.
+//   func        - the helper's function symbol. Its parameters 0 and 1 are the base and the
+//                 index; parameter 2 is the written value and is only read when |write| is set.
+//   symbolTable - not used by this function.
+// Returns the newly created function definition node.
+TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type,
+                                                      bool write,
+                                                      const TFunction &func,
+                                                      TSymbolTable *symbolTable)
+{
+    ASSERT(!type.isArray());
+
+    // One switch case per matrix column or per vector component.
+    uint8_t numCases = 0;
+    if (type.isMatrix())
+    {
+        numCases = type.getCols();
+    }
+    else
+    {
+        numCases = type.getNominalSize();
+    }
+
+    TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func);
+
+    // Template symbol nodes for the parameters. Only deep copies of these are placed in the
+    // tree, so that no node ends up referenced from two places.
+    TIntermSymbol *baseParam  = new TIntermSymbol(func.getParam(0));
+    TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1));
+    TIntermSymbol *valueParam = nullptr;
+    if (write)
+    {
+        valueParam = new TIntermSymbol(func.getParam(2));
+    }
+
+    // switch (index) { case i: <access base[i] and return> ... default: break; }
+    TIntermBlock *statementList = new TIntermBlock();
+    for (uint8_t i = 0; i < numCases; ++i)
+    {
+        TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
+        statementList->getSequence()->push_back(caseNode);
+
+        TIntermBinary *indexNode =
+            new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i));
+        if (write)
+        {
+            TIntermBinary *assignNode =
+                new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy());
+            statementList->getSequence()->push_back(assignNode);
+            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
+            statementList->getSequence()->push_back(returnNode);
+        }
+        else
+        {
+            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
+            statementList->getSequence()->push_back(returnNode);
+        }
+    }
+
+    // Default case
+    TIntermCase *defaultNode = new TIntermCase(nullptr);
+    statementList->getSequence()->push_back(defaultNode);
+    TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
+    statementList->getSequence()->push_back(breakNode);
+
+    TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList);
+
+    TIntermBlock *bodyNode = new TIntermBlock();
+    bodyNode->getSequence()->push_back(switchNode);
+
+    // Out-of-range handling after the switch: index < 0 selects the first element, anything
+    // else falls through to the last element.
+    TIntermBinary *cond =
+        new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0));
+
+    // Two blocks: one accesses (either reads or writes) the first element and returns,
+    // the other accesses the last element.
+    TIntermBlock *useFirstBlock = new TIntermBlock();
+    TIntermBlock *useLastBlock  = new TIntermBlock();
+    TIntermBinary *indexFirstNode =
+        new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0));
+    TIntermBinary *indexLastNode =
+        new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1));
+    if (write)
+    {
+        TIntermBinary *assignFirstNode =
+            new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy());
+        useFirstBlock->getSequence()->push_back(assignFirstNode);
+        TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
+        useFirstBlock->getSequence()->push_back(returnNode);
+
+        TIntermBinary *assignLastNode =
+            new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy());
+        useLastBlock->getSequence()->push_back(assignLastNode);
+    }
+    else
+    {
+        TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
+        useFirstBlock->getSequence()->push_back(returnFirstNode);
+
+        TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
+        useLastBlock->getSequence()->push_back(returnLastNode);
+    }
+    TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
+    bodyNode->getSequence()->push_back(ifNode);
+    bodyNode->getSequence()->push_back(useLastBlock);
+
+    TIntermFunctionDefinition *indexingFunction =
+        new TIntermFunctionDefinition(prototypeNode, bodyNode);
+    return indexingFunction;
+}
+
+// Traverser that replaces indirect indexing nodes accepted by a matcher with calls to generated
+// dyn_index_*() helper functions. At most one rewrite that inserts statements is done per
+// traversal; the caller re-runs the traversal for as long as usedTreeInsertion() returns true and
+// finally calls insertHelperDefinitions() to add the helper bodies to the tree.
+//
+// The class is file-local and never derived from, so it is final and keeps its state private.
+class RemoveDynamicIndexingTraverser final : public TLValueTrackingTraverser
+{
+  public:
+    RemoveDynamicIndexingTraverser(DynamicIndexingNodeMatcher &&matcher,
+                                   TSymbolTable *symbolTable,
+                                   PerformanceDiagnostics *perfDiagnostics);
+
+    bool visitBinary(Visit visit, TIntermBinary *node) override;
+
+    // Inserts the definitions of all helper functions recorded so far at the start of |root|,
+    // which must be a block.
+    void insertHelperDefinitions(TIntermNode *root);
+
+    // Resets the per-traversal state. Must be called before each traversal.
+    void nextIteration();
+
+    // True if the last traversal inserted statements into the tree.
+    bool usedTreeInsertion() const { return mUsedTreeInsertion; }
+
+  private:
+    // Maps of types that are indexed to the indexing functions used for them. Note that these
+    // can not store multiple variants of the same type with different precisions - only one
+    // precision gets stored.
+    std::map<TType, TFunction *> mIndexedVecAndMatrixTypes;
+    std::map<TType, TFunction *> mWrittenVecAndMatrixTypes;
+
+    bool mUsedTreeInsertion;
+
+    // When true, the traverser will remove side effects from any indexing expression.
+    // This is done so that in code like
+    //   V[j++][i]++.
+    // where V is an array of vectors, j++ will only be evaluated once.
+    bool mRemoveIndexSideEffectsInSubtree;
+
+    DynamicIndexingNodeMatcher mMatcher;
+    PerformanceDiagnostics *mPerfDiagnostics;
+};
+
+// Sets up a traverser with both state flags cleared. The three booleans passed to the base
+// class are forwarded unchanged from the original code. |matcher| is a named rvalue reference,
+// i.e. an lvalue inside this constructor, so it has to be moved explicitly; otherwise the
+// std::function would be copied.
+RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(
+    DynamicIndexingNodeMatcher &&matcher,
+    TSymbolTable *symbolTable,
+    PerformanceDiagnostics *perfDiagnostics)
+    : TLValueTrackingTraverser(true, false, false, symbolTable),
+      mUsedTreeInsertion(false),
+      mRemoveIndexSideEffectsInSubtree(false),
+      mMatcher(std::move(matcher)),
+      mPerfDiagnostics(perfDiagnostics)
+{}
+
+// Generates the definition of every recorded helper (read helpers first, then write helpers)
+// and inserts them all at position 0 of |root|, which must be a block.
+void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
+{
+    TIntermBlock *rootBlock = root->getAsBlock();
+    ASSERT(rootBlock != nullptr);
+
+    TIntermSequence helperDefinitions;
+    for (auto &readEntry : mIndexedVecAndMatrixTypes)
+    {
+        TIntermFunctionDefinition *readHelper =
+            GetIndexFunctionDefinition(readEntry.first, false, *readEntry.second, mSymbolTable);
+        helperDefinitions.push_back(readHelper);
+    }
+    for (auto &writeEntry : mWrittenVecAndMatrixTypes)
+    {
+        TIntermFunctionDefinition *writeHelper =
+            GetIndexFunctionDefinition(writeEntry.first, true, *writeEntry.second, mSymbolTable);
+        helperDefinitions.push_back(writeHelper);
+    }
+
+    rootBlock->insertChildNodes(0, helperDefinitions);
+}
+
+// Builds the call dyn_index_*(v_expr, index) for the indirect indexing node |node|.
+// The left operand of |node| is passed as the first argument as is (not copied), and the call
+// takes over the source line of |node|.
+TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
+                                          TIntermTyped *index,
+                                          TFunction *indexingFunction)
+{
+    ASSERT(node->getOp() == EOpIndexIndirect);
+
+    TIntermSequence callArgs;
+    callArgs.push_back(node->getLeft());
+    callArgs.push_back(index);
+
+    TIntermAggregate *call = TIntermAggregate::CreateFunctionCall(*indexingFunction, &callArgs);
+    call->setLine(node->getLine());
+    return call;
+}
+
+// Builds the call dyn_index_write_*(v_expr, index, writtenValue) for the indirect indexing node
+// |node|. |index| and |writtenValue| are temporaries referenced through fresh symbol nodes, and
+// the call takes over the source line of |node|.
+TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
+                                                 TVariable *index,
+                                                 TVariable *writtenValue,
+                                                 TFunction *indexedWriteFunction)
+{
+    ASSERT(node->getOp() == EOpIndexIndirect);
+
+    // The indexed expression is deep copied so that the same node is never referenced from two
+    // places in the tree.
+    TIntermTyped *baseCopy = node->getLeft()->deepCopy();
+
+    TIntermSequence callArgs;
+    callArgs.push_back(baseCopy);
+    callArgs.push_back(CreateTempSymbolNode(index));
+    callArgs.push_back(CreateTempSymbolNode(writtenValue));
+
+    TIntermAggregate *writeCall =
+        TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, &callArgs);
+    writeCall->setLine(node->getLine());
+    return writeCall;
+}
+
+// Visits a binary node and, if it is an indirect indexing node, rewrites it in one of three ways:
+//  - while removing index side effects in a subtree: hoists the index into a temporary;
+//  - matched node in an l-value position: reads the field into a temporary before the statement
+//    and writes it back with dyn_index_write_*() after the statement;
+//  - matched node that is only read: replaces the node with a dyn_index_*() call.
+// The helper function symbols are created on first use per indexed type and cached in the maps.
+// Returns false once statements have been inserted during this traversal, which stops further
+// descent until the tree has been updated and the traversal restarted.
+bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
+{
+    if (mUsedTreeInsertion)
+        return false;
+
+    if (node->getOp() == EOpIndexIndirect)
+    {
+        if (mRemoveIndexSideEffectsInSubtree)
+        {
+            ASSERT(node->getRight()->hasSideEffects());
+            // In case we're just removing index side effects, convert
+            //   v_expr[index_expr]
+            // to this:
+            //   int s0 = index_expr; v_expr[s0];
+            // Now v_expr[s0] can be safely executed several times without unintended side effects.
+            TIntermDeclaration *indexVariableDeclaration = nullptr;
+            TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(),
+                                                           EvqTemporary, &indexVariableDeclaration);
+            insertStatementInParentBlock(indexVariableDeclaration);
+            mUsedTreeInsertion = true;
+
+            // Replace the index with the temp variable
+            TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable);
+            queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED);
+        }
+        else if (mMatcher(node))
+        {
+            if (mPerfDiagnostics)
+            {
+                mPerfDiagnostics->warning(node->getLine(),
+                                          "Performance: dynamic indexing of vectors and "
+                                          "matrices is emulated and can be slow.",
+                                          "[]");
+            }
+            bool write = isLValueRequiredHere();
+
+#if defined(ANGLE_ENABLE_ASSERTS)
+            // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
+            // implemented checks in this traverser.
+            IntermNodePatternMatcher matcher(
+                IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
+            ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
+#endif
+
+            const TType &type = node->getLeft()->getType();
+
+            // Look up the read helper for this type, creating it on first use. The helper name
+            // is only built on a cache miss, and the map is only searched once on a hit.
+            TFunction *indexingFunction = nullptr;
+            const auto readEntry        = mIndexedVecAndMatrixTypes.find(type);
+            if (readEntry == mIndexedVecAndMatrixTypes.end())
+            {
+                ImmutableString indexingFunctionName(GetIndexFunctionName(type, false));
+                indexingFunction =
+                    new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal,
+                                  GetFieldType(type), true);
+                indexingFunction->addParameter(new TVariable(
+                    mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal));
+                indexingFunction->addParameter(
+                    new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
+                mIndexedVecAndMatrixTypes[type] = indexingFunction;
+            }
+            else
+            {
+                indexingFunction = readEntry->second;
+            }
+
+            if (write)
+            {
+                // Convert:
+                //   v_expr[index_expr]++;
+                // to this:
+                //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
+                //   dyn_index_write(v_expr, s0, s1);
+                // This works even if index_expr has some side effects.
+                if (node->getLeft()->hasSideEffects())
+                {
+                    // If v_expr has side effects, those need to be removed before proceeding.
+                    // Otherwise the side effects of v_expr would be evaluated twice.
+                    // The only case where an l-value can have side effects is when it is
+                    // indexing. For example, it can be V[j++] where V is an array of vectors.
+                    mRemoveIndexSideEffectsInSubtree = true;
+                    return true;
+                }
+
+                TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
+                if (leftBinary != nullptr && mMatcher(leftBinary))
+                {
+                    // This is a case like:
+                    // mat2 m;
+                    // m[a][b]++;
+                    // Process the child node m[a] first.
+                    return true;
+                }
+
+                // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
+                // only writes it and doesn't need the previous value. http://anglebug.com/1116
+
+                // Look up the write helper for this type, creating it on first use.
+                TFunction *indexedWriteFunction = nullptr;
+                const auto writeEntry           = mWrittenVecAndMatrixTypes.find(type);
+                if (writeEntry == mWrittenVecAndMatrixTypes.end())
+                {
+                    ImmutableString functionName(GetIndexFunctionName(type, true));
+                    indexedWriteFunction =
+                        new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal,
+                                      StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
+                    indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName,
+                                                                     GetBaseType(type, true),
+                                                                     SymbolType::AngleInternal));
+                    indexedWriteFunction->addParameter(new TVariable(
+                        mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal));
+                    TType *valueType = GetFieldType(type);
+                    valueType->setQualifier(EvqParamIn);
+                    indexedWriteFunction->addParameter(new TVariable(
+                        mSymbolTable, kValueName, static_cast<const TType *>(valueType),
+                        SymbolType::AngleInternal));
+                    mWrittenVecAndMatrixTypes[type] = indexedWriteFunction;
+                }
+                else
+                {
+                    indexedWriteFunction = writeEntry->second;
+                }
+
+                TIntermSequence insertionsBefore;
+                TIntermSequence insertionsAfter;
+
+                // Store the index in a temporary signed int variable.
+                // s0 = index_expr;
+                TIntermTyped *indexInitializer               = EnsureSignedInt(node->getRight());
+                TIntermDeclaration *indexVariableDeclaration = nullptr;
+                TVariable *indexVariable                     = DeclareTempVariable(
+                    mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration);
+                insertionsBefore.push_back(indexVariableDeclaration);
+
+                // s1 = dyn_index(v_expr, s0);
+                TIntermAggregate *indexingCall = CreateIndexFunctionCall(
+                    node, CreateTempSymbolNode(indexVariable), indexingFunction);
+                TIntermDeclaration *fieldVariableDeclaration = nullptr;
+                TVariable *fieldVariable                     = DeclareTempVariable(
+                    mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration);
+                insertionsBefore.push_back(fieldVariableDeclaration);
+
+                // dyn_index_write(v_expr, s0, s1);
+                TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
+                    node, indexVariable, fieldVariable, indexedWriteFunction);
+                insertionsAfter.push_back(indexedWriteCall);
+                insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
+
+                // replace the node with s1
+                queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED);
+                mUsedTreeInsertion = true;
+            }
+            else
+            {
+                // The indexed value is not being written, so we can simply convert
+                //   v_expr[index_expr]
+                // into
+                //   dyn_index(v_expr, index_expr)
+                // If the index_expr is unsigned, we'll convert it to signed.
+                ASSERT(!mRemoveIndexSideEffectsInSubtree);
+                TIntermAggregate *indexingCall = CreateIndexFunctionCall(
+                    node, EnsureSignedInt(node->getRight()), indexingFunction);
+                queueReplacement(indexingCall, OriginalNode::IS_DROPPED);
+            }
+        }
+    }
+    return !mUsedTreeInsertion;
+}
+
+// Clears the per-traversal flags so that a new traversal starts from a clean state.
+void RemoveDynamicIndexingTraverser::nextIteration()
+{
+    mRemoveIndexSideEffectsInSubtree = false;
+    mUsedTreeInsertion               = false;
+}
+
+// Rewrites every indexing node accepted by |matcher| under |root|, then inserts the definitions
+// of the helper functions that the rewritten code calls.
+// The tree is traversed repeatedly, applying the queued updates after each traversal, until a
+// traversal no longer inserts statements.
+// Returns false if applying the updates fails; otherwise returns the result of validating the
+// final tree. Function call validation is disabled for the duration of the transformation and is
+// restored on every exit path.
+bool RemoveDynamicIndexingIf(DynamicIndexingNodeMatcher &&matcher,
+                             TCompiler *compiler,
+                             TIntermNode *root,
+                             TSymbolTable *symbolTable,
+                             PerformanceDiagnostics *perfDiagnostics)
+{
+    // This transformation adds function declarations after the fact and so some validation is
+    // momentarily disabled.
+    bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
+
+    RemoveDynamicIndexingTraverser traverser(std::move(matcher), symbolTable, perfDiagnostics);
+    do
+    {
+        traverser.nextIteration();
+        root->traverse(&traverser);
+        if (!traverser.updateTree(compiler, root))
+        {
+            // Do not leave function call validation disabled on the compiler when bailing out.
+            compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
+            return false;
+        }
+    } while (traverser.usedTreeInsertion());
+    // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
+    // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
+    // function call nodes with no corresponding definition nodes. This needs special handling in
+    // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
+    // superficial reading of the code.
+    traverser.insertHelperDefinitions(root);
+
+    compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
+    return compiler->validateAST(root);
+}
+
+} // namespace
+
+// Runs the dynamic indexing removal on every node for which
+// IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix() returns true.
+// Returns the result of RemoveDynamicIndexingIf().
+[[nodiscard]] bool RemoveDynamicIndexingOfNonSSBOVectorOrMatrix(
+    TCompiler *compiler,
+    TIntermNode *root,
+    TSymbolTable *symbolTable,
+    PerformanceDiagnostics *perfDiagnostics)
+{
+    return RemoveDynamicIndexingIf(
+        [](TIntermBinary *candidate) {
+            return IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(candidate);
+        },
+        compiler, root, symbolTable, perfDiagnostics);
+}
+
+// Runs the dynamic indexing removal on every node for which
+// IntermNodePatternMatcher::IsDynamicIndexingOfSwizzledVector() returns true.
+// Returns the result of RemoveDynamicIndexingIf().
+[[nodiscard]] bool RemoveDynamicIndexingOfSwizzledVector(TCompiler *compiler,
+                                                         TIntermNode *root,
+                                                         TSymbolTable *symbolTable,
+                                                         PerformanceDiagnostics *perfDiagnostics)
+{
+    return RemoveDynamicIndexingIf(
+        [](TIntermBinary *candidate) {
+            return IntermNodePatternMatcher::IsDynamicIndexingOfSwizzledVector(candidate);
+        },
+        compiler, root, symbolTable, perfDiagnostics);
+}
+
+} // namespace sh