diff options
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp')
-rw-r--r-- | gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp | 597 |
1 files changed, 597 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp b/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp new file mode 100644 index 0000000000..fda6de6f48 --- /dev/null +++ b/gfx/angle/checkout/src/compiler/translator/tree_ops/RemoveDynamicIndexing.cpp @@ -0,0 +1,597 @@ +// +// Copyright 2002 The ANGLE Project Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and +// matrices, replacing them with calls to functions that choose which component to return or write. +// We don't need to consider dynamic indexing in SSBO since it can be directly as part of the offset +// of RWByteAddressBuffer. +// + +#include "compiler/translator/tree_ops/RemoveDynamicIndexing.h" + +#include "compiler/translator/Compiler.h" +#include "compiler/translator/Diagnostics.h" +#include "compiler/translator/InfoSink.h" +#include "compiler/translator/StaticType.h" +#include "compiler/translator/SymbolTable.h" +#include "compiler/translator/tree_util/IntermNodePatternMatcher.h" +#include "compiler/translator/tree_util/IntermNode_util.h" +#include "compiler/translator/tree_util/IntermTraverse.h" + +namespace sh +{ + +namespace +{ + +using DynamicIndexingNodeMatcher = std::function<bool(TIntermBinary *)>; + +const TType *kIndexType = StaticType::Get<EbtInt, EbpHigh, EvqParamIn, 1, 1>(); + +constexpr const ImmutableString kBaseName("base"); +constexpr const ImmutableString kIndexName("index"); +constexpr const ImmutableString kValueName("value"); + +std::string GetIndexFunctionName(const TType &type, bool write) +{ + TInfoSinkBase nameSink; + nameSink << "dyn_index_"; + if (write) + { + nameSink << "write_"; + } + if (type.isMatrix()) + { + nameSink << "mat" << static_cast<uint32_t>(type.getCols()) << "x" + << static_cast<uint32_t>(type.getRows()); + } + else + { + switch (type.getBasicType()) + { + case EbtInt: + nameSink << "ivec"; + break; + case EbtBool: + nameSink << "bvec"; + break; + case EbtUInt: + nameSink << "uvec"; + break; + case EbtFloat: + nameSink << "vec"; + break; + default: + UNREACHABLE(); + } + nameSink << static_cast<uint32_t>(type.getNominalSize()); + } + return nameSink.str(); +} + +TIntermConstantUnion *CreateIntConstantNode(int i) +{ + TConstantUnion *constant = new TConstantUnion(); + constant->setIConst(i); + return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh)); +} + +TIntermTyped *EnsureSignedInt(TIntermTyped *node) +{ + if (node->getBasicType() == EbtInt) + return node; + + TIntermSequence arguments; + arguments.push_back(node); + return TIntermAggregate::CreateConstructor(TType(EbtInt), &arguments); +} + +TType *GetFieldType(const TType &indexedType) +{ + TType *fieldType = new TType(indexedType); + if (indexedType.isMatrix()) + { + fieldType->toMatrixColumnType(); + } + else + { + ASSERT(indexedType.isVector()); + fieldType->toComponentType(); + } + // Default precision to highp if not specified. For example in |vec3(0)[i], i < 0|, there is no + // precision assigned to vec3(0). + if (fieldType->getPrecision() == EbpUndefined) + { + fieldType->setPrecision(EbpHigh); + } + return fieldType; +} + +const TType *GetBaseType(const TType &type, bool write) +{ + TType *baseType = new TType(type); + // Conservatively use highp here, even if the indexed type is not highp. That way the code can't + // end up using mediump version of an indexing function for a highp value, if both mediump and + // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in + // principle this code could be used with multiple backends. + baseType->setPrecision(EbpHigh); + baseType->setQualifier(EvqParamInOut); + if (!write) + baseType->setQualifier(EvqParamIn); + return baseType; +} + +// Generate a read or write function for one field in a vector/matrix. +// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range +// indices in other places. +// Note that indices can be either int or uint. We create only int versions of the functions, +// and convert uint indices to int at the call site. +// read function example: +// float dyn_index_vec2(in vec2 base, in int index) +// { +// switch(index) +// { +// case (0): +// return base[0]; +// case (1): +// return base[1]; +// default: +// break; +// } +// if (index < 0) +// return base[0]; +// return base[1]; +// } +// write function example: +// void dyn_index_write_vec2(inout vec2 base, in int index, in float value) +// { +// switch(index) +// { +// case (0): +// base[0] = value; +// return; +// case (1): +// base[1] = value; +// return; +// default: +// break; +// } +// if (index < 0) +// { +// base[0] = value; +// return; +// } +// base[1] = value; +// } +// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. +TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type, + bool write, + const TFunction &func, + TSymbolTable *symbolTable) +{ + ASSERT(!type.isArray()); + + uint8_t numCases = 0; + if (type.isMatrix()) + { + numCases = type.getCols(); + } + else + { + numCases = type.getNominalSize(); + } + + std::string functionName = GetIndexFunctionName(type, write); + TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func); + + TIntermSymbol *baseParam = new TIntermSymbol(func.getParam(0)); + TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1)); + TIntermSymbol *valueParam = nullptr; + if (write) + { + valueParam = new TIntermSymbol(func.getParam(2)); + } + + TIntermBlock *statementList = new TIntermBlock(); + for (uint8_t i = 0; i < numCases; ++i) + { + TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i)); + statementList->getSequence()->push_back(caseNode); + + TIntermBinary *indexNode = + new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i)); + if (write) + { + TIntermBinary *assignNode = + new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy()); + statementList->getSequence()->push_back(assignNode); + TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); + statementList->getSequence()->push_back(returnNode); + } + else + { + TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode); + statementList->getSequence()->push_back(returnNode); + } + } + + // Default case + TIntermCase *defaultNode = new TIntermCase(nullptr); + statementList->getSequence()->push_back(defaultNode); + TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr); + statementList->getSequence()->push_back(breakNode); + + TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList); + + TIntermBlock *bodyNode = new TIntermBlock(); + bodyNode->getSequence()->push_back(switchNode); + + TIntermBinary *cond = + new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0)); + + // Two blocks: one accesses (either reads or writes) the first element and returns, + // the other accesses the last element. + TIntermBlock *useFirstBlock = new TIntermBlock(); + TIntermBlock *useLastBlock = new TIntermBlock(); + TIntermBinary *indexFirstNode = + new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0)); + TIntermBinary *indexLastNode = + new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1)); + if (write) + { + TIntermBinary *assignFirstNode = + new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy()); + useFirstBlock->getSequence()->push_back(assignFirstNode); + TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); + useFirstBlock->getSequence()->push_back(returnNode); + + TIntermBinary *assignLastNode = + new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy()); + useLastBlock->getSequence()->push_back(assignLastNode); + } + else + { + TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode); + useFirstBlock->getSequence()->push_back(returnFirstNode); + + TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode); + useLastBlock->getSequence()->push_back(returnLastNode); + } + TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr); + bodyNode->getSequence()->push_back(ifNode); + bodyNode->getSequence()->push_back(useLastBlock); + + TIntermFunctionDefinition *indexingFunction = + new TIntermFunctionDefinition(prototypeNode, bodyNode); + return indexingFunction; +} + +class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser +{ + public: + RemoveDynamicIndexingTraverser(DynamicIndexingNodeMatcher &&matcher, + TSymbolTable *symbolTable, + PerformanceDiagnostics *perfDiagnostics); + + bool visitBinary(Visit visit, TIntermBinary *node) override; + + void insertHelperDefinitions(TIntermNode *root); + + void nextIteration(); + + bool usedTreeInsertion() const { return mUsedTreeInsertion; } + + protected: + // Maps of types that are indexed to the indexing function ids used for them. Note that these + // can not store multiple variants of the same type with different precisions - only one + // precision gets stored. + std::map<TType, TFunction *> mIndexedVecAndMatrixTypes; + std::map<TType, TFunction *> mWrittenVecAndMatrixTypes; + + bool mUsedTreeInsertion; + + // When true, the traverser will remove side effects from any indexing expression. + // This is done so that in code like + // V[j++][i]++. + // where V is an array of vectors, j++ will only be evaluated once. + bool mRemoveIndexSideEffectsInSubtree; + + DynamicIndexingNodeMatcher mMatcher; + PerformanceDiagnostics *mPerfDiagnostics; +}; + +RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser( + DynamicIndexingNodeMatcher &&matcher, + TSymbolTable *symbolTable, + PerformanceDiagnostics *perfDiagnostics) + : TLValueTrackingTraverser(true, false, false, symbolTable), + mUsedTreeInsertion(false), + mRemoveIndexSideEffectsInSubtree(false), + mMatcher(matcher), + mPerfDiagnostics(perfDiagnostics) +{} + +void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root) +{ + TIntermBlock *rootBlock = root->getAsBlock(); + ASSERT(rootBlock != nullptr); + TIntermSequence insertions; + for (auto &type : mIndexedVecAndMatrixTypes) + { + insertions.push_back( + GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable)); + } + for (auto &type : mWrittenVecAndMatrixTypes) + { + insertions.push_back( + GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable)); + } + rootBlock->insertChildNodes(0, insertions); +} + +// Create a call to dyn_index_*() based on an indirect indexing op node +TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node, + TIntermTyped *index, + TFunction *indexingFunction) +{ + ASSERT(node->getOp() == EOpIndexIndirect); + TIntermSequence arguments; + arguments.push_back(node->getLeft()); + arguments.push_back(index); + + TIntermAggregate *indexingCall = + TIntermAggregate::CreateFunctionCall(*indexingFunction, &arguments); + indexingCall->setLine(node->getLine()); + return indexingCall; +} + +TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node, + TVariable *index, + TVariable *writtenValue, + TFunction *indexedWriteFunction) +{ + ASSERT(node->getOp() == EOpIndexIndirect); + TIntermSequence arguments; + // Deep copy the child nodes so that two pointers to the same node don't end up in the tree. + arguments.push_back(node->getLeft()->deepCopy()); + arguments.push_back(CreateTempSymbolNode(index)); + arguments.push_back(CreateTempSymbolNode(writtenValue)); + + TIntermAggregate *indexedWriteCall = + TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, &arguments); + indexedWriteCall->setLine(node->getLine()); + return indexedWriteCall; +} + +bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node) +{ + if (mUsedTreeInsertion) + return false; + + if (node->getOp() == EOpIndexIndirect) + { + if (mRemoveIndexSideEffectsInSubtree) + { + ASSERT(node->getRight()->hasSideEffects()); + // In case we're just removing index side effects, convert + // v_expr[index_expr] + // to this: + // int s0 = index_expr; v_expr[s0]; + // Now v_expr[s0] can be safely executed several times without unintended side effects. + TIntermDeclaration *indexVariableDeclaration = nullptr; + TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(), + EvqTemporary, &indexVariableDeclaration); + insertStatementInParentBlock(indexVariableDeclaration); + mUsedTreeInsertion = true; + + // Replace the index with the temp variable + TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable); + queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED); + } + else if (mMatcher(node)) + { + if (mPerfDiagnostics) + { + mPerfDiagnostics->warning(node->getLine(), + "Performance: dynamic indexing of vectors and " + "matrices is emulated and can be slow.", + "[]"); + } + bool write = isLValueRequiredHere(); + +#if defined(ANGLE_ENABLE_ASSERTS) + // Make sure that IntermNodePatternMatcher is consistent with the slightly differently + // implemented checks in this traverser. + IntermNodePatternMatcher matcher( + IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue); + ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write); +#endif + + const TType &type = node->getLeft()->getType(); + ImmutableString indexingFunctionName(GetIndexFunctionName(type, false)); + TFunction *indexingFunction = nullptr; + if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end()) + { + indexingFunction = + new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal, + GetFieldType(type), true); + indexingFunction->addParameter(new TVariable( + mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal)); + indexingFunction->addParameter( + new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal)); + mIndexedVecAndMatrixTypes[type] = indexingFunction; + } + else + { + indexingFunction = mIndexedVecAndMatrixTypes[type]; + } + + if (write) + { + // Convert: + // v_expr[index_expr]++; + // to this: + // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++; + // dyn_index_write(v_expr, s0, s1); + // This works even if index_expr has some side effects. + if (node->getLeft()->hasSideEffects()) + { + // If v_expr has side effects, those need to be removed before proceeding. + // Otherwise the side effects of v_expr would be evaluated twice. + // The only case where an l-value can have side effects is when it is + // indexing. For example, it can be V[j++] where V is an array of vectors. + mRemoveIndexSideEffectsInSubtree = true; + return true; + } + + TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode(); + if (leftBinary != nullptr && mMatcher(leftBinary)) + { + // This is a case like: + // mat2 m; + // m[a][b]++; + // Process the child node m[a] first. + return true; + } + + // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value + // only writes it and doesn't need the previous value. http://anglebug.com/1116 + + TFunction *indexedWriteFunction = nullptr; + if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end()) + { + ImmutableString functionName( + GetIndexFunctionName(node->getLeft()->getType(), true)); + indexedWriteFunction = + new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal, + StaticType::GetBasic<EbtVoid, EbpUndefined>(), false); + indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName, + GetBaseType(type, true), + SymbolType::AngleInternal)); + indexedWriteFunction->addParameter(new TVariable( + mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal)); + TType *valueType = GetFieldType(type); + valueType->setQualifier(EvqParamIn); + indexedWriteFunction->addParameter(new TVariable( + mSymbolTable, kValueName, static_cast<const TType *>(valueType), + SymbolType::AngleInternal)); + mWrittenVecAndMatrixTypes[type] = indexedWriteFunction; + } + else + { + indexedWriteFunction = mWrittenVecAndMatrixTypes[type]; + } + + TIntermSequence insertionsBefore; + TIntermSequence insertionsAfter; + + // Store the index in a temporary signed int variable. + // s0 = index_expr; + TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight()); + TIntermDeclaration *indexVariableDeclaration = nullptr; + TVariable *indexVariable = DeclareTempVariable( + mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration); + insertionsBefore.push_back(indexVariableDeclaration); + + // s1 = dyn_index(v_expr, s0); + TIntermAggregate *indexingCall = CreateIndexFunctionCall( + node, CreateTempSymbolNode(indexVariable), indexingFunction); + TIntermDeclaration *fieldVariableDeclaration = nullptr; + TVariable *fieldVariable = DeclareTempVariable( + mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration); + insertionsBefore.push_back(fieldVariableDeclaration); + + // dyn_index_write(v_expr, s0, s1); + TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall( + node, indexVariable, fieldVariable, indexedWriteFunction); + insertionsAfter.push_back(indexedWriteCall); + insertStatementsInParentBlock(insertionsBefore, insertionsAfter); + + // replace the node with s1 + queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED); + mUsedTreeInsertion = true; + } + else + { + // The indexed value is not being written, so we can simply convert + // v_expr[index_expr] + // into + // dyn_index(v_expr, index_expr) + // If the index_expr is unsigned, we'll convert it to signed. + ASSERT(!mRemoveIndexSideEffectsInSubtree); + TIntermAggregate *indexingCall = CreateIndexFunctionCall( + node, EnsureSignedInt(node->getRight()), indexingFunction); + queueReplacement(indexingCall, OriginalNode::IS_DROPPED); + } + } + } + return !mUsedTreeInsertion; +} + +void RemoveDynamicIndexingTraverser::nextIteration() +{ + mUsedTreeInsertion = false; + mRemoveIndexSideEffectsInSubtree = false; +} + +bool RemoveDynamicIndexingIf(DynamicIndexingNodeMatcher &&matcher, + TCompiler *compiler, + TIntermNode *root, + TSymbolTable *symbolTable, + PerformanceDiagnostics *perfDiagnostics) +{ + // This transformation adds function declarations after the fact and so some validation is + // momentarily disabled. + bool enableValidateFunctionCall = compiler->disableValidateFunctionCall(); + + RemoveDynamicIndexingTraverser traverser(std::move(matcher), symbolTable, perfDiagnostics); + do + { + traverser.nextIteration(); + root->traverse(&traverser); + if (!traverser.updateTree(compiler, root)) + { + return false; + } + } while (traverser.usedTreeInsertion()); + // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle + // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are + // function call nodes with no corresponding definition nodes. This needs special handling in + // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a + // superficial reading of the code. + traverser.insertHelperDefinitions(root); + + compiler->restoreValidateFunctionCall(enableValidateFunctionCall); + return compiler->validateAST(root); +} + +} // namespace + +[[nodiscard]] bool RemoveDynamicIndexingOfNonSSBOVectorOrMatrix( + TCompiler *compiler, + TIntermNode *root, + TSymbolTable *symbolTable, + PerformanceDiagnostics *perfDiagnostics) +{ + DynamicIndexingNodeMatcher matcher = [](TIntermBinary *node) { + return IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(node); + }; + return RemoveDynamicIndexingIf(std::move(matcher), compiler, root, symbolTable, + perfDiagnostics); +} + +[[nodiscard]] bool RemoveDynamicIndexingOfSwizzledVector(TCompiler *compiler, + TIntermNode *root, + TSymbolTable *symbolTable, + PerformanceDiagnostics *perfDiagnostics) +{ + DynamicIndexingNodeMatcher matcher = [](TIntermBinary *node) { + return IntermNodePatternMatcher::IsDynamicIndexingOfSwizzledVector(node); + }; + return RemoveDynamicIndexingIf(std::move(matcher), compiler, root, symbolTable, + perfDiagnostics); +} + +} // namespace sh |