// // Copyright 2002 The ANGLE Project Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // // RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of non-SSBO vectors and // matrices, replacing them with calls to functions that choose which component to return or write. // We don't need to consider dynamic indexing in SSBO since it can be directly as part of the offset // of RWByteAddressBuffer. // #include "compiler/translator/tree_ops/RemoveDynamicIndexing.h" #include "compiler/translator/Compiler.h" #include "compiler/translator/Diagnostics.h" #include "compiler/translator/InfoSink.h" #include "compiler/translator/StaticType.h" #include "compiler/translator/SymbolTable.h" #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" #include "compiler/translator/tree_util/IntermNode_util.h" #include "compiler/translator/tree_util/IntermTraverse.h" namespace sh { namespace { using DynamicIndexingNodeMatcher = std::function; const TType *kIndexType = StaticType::Get(); constexpr const ImmutableString kBaseName("base"); constexpr const ImmutableString kIndexName("index"); constexpr const ImmutableString kValueName("value"); std::string GetIndexFunctionName(const TType &type, bool write) { TInfoSinkBase nameSink; nameSink << "dyn_index_"; if (write) { nameSink << "write_"; } if (type.isMatrix()) { nameSink << "mat" << static_cast(type.getCols()) << "x" << static_cast(type.getRows()); } else { switch (type.getBasicType()) { case EbtInt: nameSink << "ivec"; break; case EbtBool: nameSink << "bvec"; break; case EbtUInt: nameSink << "uvec"; break; case EbtFloat: nameSink << "vec"; break; default: UNREACHABLE(); } nameSink << static_cast(type.getNominalSize()); } return nameSink.str(); } TIntermConstantUnion *CreateIntConstantNode(int i) { TConstantUnion *constant = new TConstantUnion(); constant->setIConst(i); return new TIntermConstantUnion(constant, TType(EbtInt, EbpHigh)); } TIntermTyped *EnsureSignedInt(TIntermTyped *node) { if (node->getBasicType() == EbtInt) return node; TIntermSequence arguments; arguments.push_back(node); return TIntermAggregate::CreateConstructor(TType(EbtInt), &arguments); } TType *GetFieldType(const TType &indexedType) { TType *fieldType = new TType(indexedType); if (indexedType.isMatrix()) { fieldType->toMatrixColumnType(); } else { ASSERT(indexedType.isVector()); fieldType->toComponentType(); } // Default precision to highp if not specified. For example in |vec3(0)[i], i < 0|, there is no // precision assigned to vec3(0). if (fieldType->getPrecision() == EbpUndefined) { fieldType->setPrecision(EbpHigh); } return fieldType; } const TType *GetBaseType(const TType &type, bool write) { TType *baseType = new TType(type); // Conservatively use highp here, even if the indexed type is not highp. That way the code can't // end up using mediump version of an indexing function for a highp value, if both mediump and // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in // principle this code could be used with multiple backends. baseType->setPrecision(EbpHigh); baseType->setQualifier(EvqParamInOut); if (!write) baseType->setQualifier(EvqParamIn); return baseType; } // Generate a read or write function for one field in a vector/matrix. // Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range // indices in other places. // Note that indices can be either int or uint. We create only int versions of the functions, // and convert uint indices to int at the call site. // read function example: // float dyn_index_vec2(in vec2 base, in int index) // { // switch(index) // { // case (0): // return base[0]; // case (1): // return base[1]; // default: // break; // } // if (index < 0) // return base[0]; // return base[1]; // } // write function example: // void dyn_index_write_vec2(inout vec2 base, in int index, in float value) // { // switch(index) // { // case (0): // base[0] = value; // return; // case (1): // base[1] = value; // return; // default: // break; // } // if (index < 0) // { // base[0] = value; // return; // } // base[1] = value; // } // Note that else is not used in above functions to avoid the RewriteElseBlocks transformation. TIntermFunctionDefinition *GetIndexFunctionDefinition(const TType &type, bool write, const TFunction &func, TSymbolTable *symbolTable) { ASSERT(!type.isArray()); uint8_t numCases = 0; if (type.isMatrix()) { numCases = type.getCols(); } else { numCases = type.getNominalSize(); } std::string functionName = GetIndexFunctionName(type, write); TIntermFunctionPrototype *prototypeNode = CreateInternalFunctionPrototypeNode(func); TIntermSymbol *baseParam = new TIntermSymbol(func.getParam(0)); TIntermSymbol *indexParam = new TIntermSymbol(func.getParam(1)); TIntermSymbol *valueParam = nullptr; if (write) { valueParam = new TIntermSymbol(func.getParam(2)); } TIntermBlock *statementList = new TIntermBlock(); for (uint8_t i = 0; i < numCases; ++i) { TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i)); statementList->getSequence()->push_back(caseNode); TIntermBinary *indexNode = new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(i)); if (write) { TIntermBinary *assignNode = new TIntermBinary(EOpAssign, indexNode, valueParam->deepCopy()); statementList->getSequence()->push_back(assignNode); TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); statementList->getSequence()->push_back(returnNode); } else { TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode); statementList->getSequence()->push_back(returnNode); } } // Default case TIntermCase *defaultNode = new TIntermCase(nullptr); statementList->getSequence()->push_back(defaultNode); TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr); statementList->getSequence()->push_back(breakNode); TIntermSwitch *switchNode = new TIntermSwitch(indexParam->deepCopy(), statementList); TIntermBlock *bodyNode = new TIntermBlock(); bodyNode->getSequence()->push_back(switchNode); TIntermBinary *cond = new TIntermBinary(EOpLessThan, indexParam->deepCopy(), CreateIntConstantNode(0)); // Two blocks: one accesses (either reads or writes) the first element and returns, // the other accesses the last element. TIntermBlock *useFirstBlock = new TIntermBlock(); TIntermBlock *useLastBlock = new TIntermBlock(); TIntermBinary *indexFirstNode = new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(0)); TIntermBinary *indexLastNode = new TIntermBinary(EOpIndexDirect, baseParam->deepCopy(), CreateIndexNode(numCases - 1)); if (write) { TIntermBinary *assignFirstNode = new TIntermBinary(EOpAssign, indexFirstNode, valueParam->deepCopy()); useFirstBlock->getSequence()->push_back(assignFirstNode); TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr); useFirstBlock->getSequence()->push_back(returnNode); TIntermBinary *assignLastNode = new TIntermBinary(EOpAssign, indexLastNode, valueParam->deepCopy()); useLastBlock->getSequence()->push_back(assignLastNode); } else { TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode); useFirstBlock->getSequence()->push_back(returnFirstNode); TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode); useLastBlock->getSequence()->push_back(returnLastNode); } TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr); bodyNode->getSequence()->push_back(ifNode); bodyNode->getSequence()->push_back(useLastBlock); TIntermFunctionDefinition *indexingFunction = new TIntermFunctionDefinition(prototypeNode, bodyNode); return indexingFunction; } class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser { public: RemoveDynamicIndexingTraverser(DynamicIndexingNodeMatcher &&matcher, TSymbolTable *symbolTable, PerformanceDiagnostics *perfDiagnostics); bool visitBinary(Visit visit, TIntermBinary *node) override; void insertHelperDefinitions(TIntermNode *root); void nextIteration(); bool usedTreeInsertion() const { return mUsedTreeInsertion; } protected: // Maps of types that are indexed to the indexing function ids used for them. Note that these // can not store multiple variants of the same type with different precisions - only one // precision gets stored. std::map mIndexedVecAndMatrixTypes; std::map mWrittenVecAndMatrixTypes; bool mUsedTreeInsertion; // When true, the traverser will remove side effects from any indexing expression. // This is done so that in code like // V[j++][i]++. // where V is an array of vectors, j++ will only be evaluated once. bool mRemoveIndexSideEffectsInSubtree; DynamicIndexingNodeMatcher mMatcher; PerformanceDiagnostics *mPerfDiagnostics; }; RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser( DynamicIndexingNodeMatcher &&matcher, TSymbolTable *symbolTable, PerformanceDiagnostics *perfDiagnostics) : TLValueTrackingTraverser(true, false, false, symbolTable), mUsedTreeInsertion(false), mRemoveIndexSideEffectsInSubtree(false), mMatcher(matcher), mPerfDiagnostics(perfDiagnostics) {} void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root) { TIntermBlock *rootBlock = root->getAsBlock(); ASSERT(rootBlock != nullptr); TIntermSequence insertions; for (auto &type : mIndexedVecAndMatrixTypes) { insertions.push_back( GetIndexFunctionDefinition(type.first, false, *type.second, mSymbolTable)); } for (auto &type : mWrittenVecAndMatrixTypes) { insertions.push_back( GetIndexFunctionDefinition(type.first, true, *type.second, mSymbolTable)); } rootBlock->insertChildNodes(0, insertions); } // Create a call to dyn_index_*() based on an indirect indexing op node TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node, TIntermTyped *index, TFunction *indexingFunction) { ASSERT(node->getOp() == EOpIndexIndirect); TIntermSequence arguments; arguments.push_back(node->getLeft()); arguments.push_back(index); TIntermAggregate *indexingCall = TIntermAggregate::CreateFunctionCall(*indexingFunction, &arguments); indexingCall->setLine(node->getLine()); return indexingCall; } TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node, TVariable *index, TVariable *writtenValue, TFunction *indexedWriteFunction) { ASSERT(node->getOp() == EOpIndexIndirect); TIntermSequence arguments; // Deep copy the child nodes so that two pointers to the same node don't end up in the tree. arguments.push_back(node->getLeft()->deepCopy()); arguments.push_back(CreateTempSymbolNode(index)); arguments.push_back(CreateTempSymbolNode(writtenValue)); TIntermAggregate *indexedWriteCall = TIntermAggregate::CreateFunctionCall(*indexedWriteFunction, &arguments); indexedWriteCall->setLine(node->getLine()); return indexedWriteCall; } bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node) { if (mUsedTreeInsertion) return false; if (node->getOp() == EOpIndexIndirect) { if (mRemoveIndexSideEffectsInSubtree) { ASSERT(node->getRight()->hasSideEffects()); // In case we're just removing index side effects, convert // v_expr[index_expr] // to this: // int s0 = index_expr; v_expr[s0]; // Now v_expr[s0] can be safely executed several times without unintended side effects. TIntermDeclaration *indexVariableDeclaration = nullptr; TVariable *indexVariable = DeclareTempVariable(mSymbolTable, node->getRight(), EvqTemporary, &indexVariableDeclaration); insertStatementInParentBlock(indexVariableDeclaration); mUsedTreeInsertion = true; // Replace the index with the temp variable TIntermSymbol *tempIndex = CreateTempSymbolNode(indexVariable); queueReplacementWithParent(node, node->getRight(), tempIndex, OriginalNode::IS_DROPPED); } else if (mMatcher(node)) { if (mPerfDiagnostics) { mPerfDiagnostics->warning(node->getLine(), "Performance: dynamic indexing of vectors and " "matrices is emulated and can be slow.", "[]"); } bool write = isLValueRequiredHere(); #if defined(ANGLE_ENABLE_ASSERTS) // Make sure that IntermNodePatternMatcher is consistent with the slightly differently // implemented checks in this traverser. IntermNodePatternMatcher matcher( IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue); ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write); #endif const TType &type = node->getLeft()->getType(); ImmutableString indexingFunctionName(GetIndexFunctionName(type, false)); TFunction *indexingFunction = nullptr; if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end()) { indexingFunction = new TFunction(mSymbolTable, indexingFunctionName, SymbolType::AngleInternal, GetFieldType(type), true); indexingFunction->addParameter(new TVariable( mSymbolTable, kBaseName, GetBaseType(type, false), SymbolType::AngleInternal)); indexingFunction->addParameter( new TVariable(mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal)); mIndexedVecAndMatrixTypes[type] = indexingFunction; } else { indexingFunction = mIndexedVecAndMatrixTypes[type]; } if (write) { // Convert: // v_expr[index_expr]++; // to this: // int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++; // dyn_index_write(v_expr, s0, s1); // This works even if index_expr has some side effects. if (node->getLeft()->hasSideEffects()) { // If v_expr has side effects, those need to be removed before proceeding. // Otherwise the side effects of v_expr would be evaluated twice. // The only case where an l-value can have side effects is when it is // indexing. For example, it can be V[j++] where V is an array of vectors. mRemoveIndexSideEffectsInSubtree = true; return true; } TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode(); if (leftBinary != nullptr && mMatcher(leftBinary)) { // This is a case like: // mat2 m; // m[a][b]++; // Process the child node m[a] first. return true; } // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value // only writes it and doesn't need the previous value. http://anglebug.com/1116 TFunction *indexedWriteFunction = nullptr; if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end()) { ImmutableString functionName( GetIndexFunctionName(node->getLeft()->getType(), true)); indexedWriteFunction = new TFunction(mSymbolTable, functionName, SymbolType::AngleInternal, StaticType::GetBasic(), false); indexedWriteFunction->addParameter(new TVariable(mSymbolTable, kBaseName, GetBaseType(type, true), SymbolType::AngleInternal)); indexedWriteFunction->addParameter(new TVariable( mSymbolTable, kIndexName, kIndexType, SymbolType::AngleInternal)); TType *valueType = GetFieldType(type); valueType->setQualifier(EvqParamIn); indexedWriteFunction->addParameter(new TVariable( mSymbolTable, kValueName, static_cast(valueType), SymbolType::AngleInternal)); mWrittenVecAndMatrixTypes[type] = indexedWriteFunction; } else { indexedWriteFunction = mWrittenVecAndMatrixTypes[type]; } TIntermSequence insertionsBefore; TIntermSequence insertionsAfter; // Store the index in a temporary signed int variable. // s0 = index_expr; TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight()); TIntermDeclaration *indexVariableDeclaration = nullptr; TVariable *indexVariable = DeclareTempVariable( mSymbolTable, indexInitializer, EvqTemporary, &indexVariableDeclaration); insertionsBefore.push_back(indexVariableDeclaration); // s1 = dyn_index(v_expr, s0); TIntermAggregate *indexingCall = CreateIndexFunctionCall( node, CreateTempSymbolNode(indexVariable), indexingFunction); TIntermDeclaration *fieldVariableDeclaration = nullptr; TVariable *fieldVariable = DeclareTempVariable( mSymbolTable, indexingCall, EvqTemporary, &fieldVariableDeclaration); insertionsBefore.push_back(fieldVariableDeclaration); // dyn_index_write(v_expr, s0, s1); TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall( node, indexVariable, fieldVariable, indexedWriteFunction); insertionsAfter.push_back(indexedWriteCall); insertStatementsInParentBlock(insertionsBefore, insertionsAfter); // replace the node with s1 queueReplacement(CreateTempSymbolNode(fieldVariable), OriginalNode::IS_DROPPED); mUsedTreeInsertion = true; } else { // The indexed value is not being written, so we can simply convert // v_expr[index_expr] // into // dyn_index(v_expr, index_expr) // If the index_expr is unsigned, we'll convert it to signed. ASSERT(!mRemoveIndexSideEffectsInSubtree); TIntermAggregate *indexingCall = CreateIndexFunctionCall( node, EnsureSignedInt(node->getRight()), indexingFunction); queueReplacement(indexingCall, OriginalNode::IS_DROPPED); } } } return !mUsedTreeInsertion; } void RemoveDynamicIndexingTraverser::nextIteration() { mUsedTreeInsertion = false; mRemoveIndexSideEffectsInSubtree = false; } bool RemoveDynamicIndexingIf(DynamicIndexingNodeMatcher &&matcher, TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable, PerformanceDiagnostics *perfDiagnostics) { // This transformation adds function declarations after the fact and so some validation is // momentarily disabled. bool enableValidateFunctionCall = compiler->disableValidateFunctionCall(); RemoveDynamicIndexingTraverser traverser(std::move(matcher), symbolTable, perfDiagnostics); do { traverser.nextIteration(); root->traverse(&traverser); if (!traverser.updateTree(compiler, root)) { return false; } } while (traverser.usedTreeInsertion()); // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are // function call nodes with no corresponding definition nodes. This needs special handling in // TIntermLValueTrackingTraverser, and creates intricacies that are not easily apparent from a // superficial reading of the code. traverser.insertHelperDefinitions(root); compiler->restoreValidateFunctionCall(enableValidateFunctionCall); return compiler->validateAST(root); } } // namespace [[nodiscard]] bool RemoveDynamicIndexingOfNonSSBOVectorOrMatrix( TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable, PerformanceDiagnostics *perfDiagnostics) { DynamicIndexingNodeMatcher matcher = [](TIntermBinary *node) { return IntermNodePatternMatcher::IsDynamicIndexingOfNonSSBOVectorOrMatrix(node); }; return RemoveDynamicIndexingIf(std::move(matcher), compiler, root, symbolTable, perfDiagnostics); } [[nodiscard]] bool RemoveDynamicIndexingOfSwizzledVector(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable, PerformanceDiagnostics *perfDiagnostics) { DynamicIndexingNodeMatcher matcher = [](TIntermBinary *node) { return IntermNodePatternMatcher::IsDynamicIndexingOfSwizzledVector(node); }; return RemoveDynamicIndexingIf(std::move(matcher), compiler, root, symbolTable, perfDiagnostics); } } // namespace sh