diff options
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp')
-rw-r--r-- | gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp | 613 |
1 files changed, 613 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp b/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp new file mode 100644 index 0000000000..11c8b72002 --- /dev/null +++ b/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp @@ -0,0 +1,613 @@ +// +// Copyright 2021 The ANGLE Project Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with +// parameters that are incompatible with both Vulkan GLSL and Metal. +// + +#include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h" + +#include "compiler/translator/ImmutableStringBuilder.h" +#include "compiler/translator/SymbolTable.h" +#include "compiler/translator/tree_util/IntermNode_util.h" +#include "compiler/translator/tree_util/IntermTraverse.h" +#include "compiler/translator/tree_util/ReplaceVariable.h" + +namespace sh +{ +namespace +{ +struct Argument +{ + size_t argumentIndex; + TIntermTyped *argument; +}; + +struct FunctionData +{ + // Whether the original function is used. If this is false, the function can be removed because + // all callers have been modified. + bool isOriginalUsed; + // The original definition of the function, used to create the monomorphized version. + TIntermFunctionDefinition *originalDefinition; + // List of monomorphized versions of this function. They will be added next to the original + // version (or replace it). + TVector<TIntermFunctionDefinition *> monomorphizedDefinitions; +}; + +using FunctionMap = angle::HashMap<const TFunction *, FunctionData>; + +// Traverse the function definitions and initialize the map. Allows visitAggregate to have access +// to TIntermFunctionDefinition even when the function is only forward declared at that point. +void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut) +{ + TIntermSequence &sequence = *root->getSequence(); + + for (TIntermNode *node : sequence) + { + TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition(); + if (asFuncDef != nullptr) + { + const TFunction *function = asFuncDef->getFunction(); + ASSERT(function && functionMapOut->find(function) == functionMapOut->end()); + (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}}; + } + } +} + +const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut) +{ + *isSamplerInStructOut = false; + + while (node->getAsBinaryNode()) + { + TIntermBinary *asBinary = node->getAsBinaryNode(); + + TOperator op = asBinary->getOp(); + + // No opaque uniform can be inside an interface block. + if (op == EOpIndexDirectInterfaceBlock) + { + return nullptr; + } + + if (op == EOpIndexDirectStruct) + { + *isSamplerInStructOut = true; + } + + node = asBinary->getLeft(); + } + + // Only interested in uniform opaque types. If a function call within another function uses + // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the + // calling function is monomorphized. + if (node->getType().getQualifier() != EvqUniform) + { + return nullptr; + } + + ASSERT(IsOpaqueType(node->getType().getBasicType()) || + node->getType().isStructureContainingSamplers()); + + TIntermSymbol *asSymbol = node->getAsSymbolNode(); + ASSERT(asSymbol); + + return &asSymbol->variable(); +} + +TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable, + TIntermTyped *node, + TIntermSequence *replacementIndices) +{ + TIntermTyped *withoutSideEffects = node->deepCopy(); + + for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary; + asBinary = asBinary->getLeft()->getAsBinaryNode()) + { + TOperator op = asBinary->getOp(); + TIntermTyped *index = asBinary->getRight(); + + if (op == EOpIndexDirectStruct) + { + break; + } + + // No side effects with constant expressions. + if (op == EOpIndexDirect) + { + ASSERT(index->getAsConstantUnion()); + continue; + } + + ASSERT(op == EOpIndexIndirect); + + // If the index is a symbol, there's no side effect, so leave it as-is. + if (index->getAsSymbolNode()) + { + continue; + } + + // Otherwise create a temp variable initialized with the index and use that temp variable as + // the index. + TIntermDeclaration *tempDecl = nullptr; + TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl); + + replacementIndices->push_back(tempDecl); + asBinary->replaceChildNode(index, new TIntermSymbol(tempVar)); + } + + return withoutSideEffects; +} + +void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments, + const TVector<Argument> &replacedArguments, + TIntermSequence *substituteArgsOut) +{ + size_t nextReplacedArg = 0; + for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex) + { + if (nextReplacedArg >= replacedArguments.size() || + argIndex != replacedArguments[nextReplacedArg].argumentIndex) + { + // Not replaced, keep argument as is. + substituteArgsOut->push_back(originalCallArguments[argIndex]); + } + else + { + TIntermTyped *argument = replacedArguments[nextReplacedArg].argument; + + // Iterate over indices of the argument and create a new arg for every non-const + // index. Note that the index itself may be an expression, and it may require further + // substitution in the next pass. + while (argument->getAsBinaryNode()) + { + TIntermBinary *asBinary = argument->getAsBinaryNode(); + if (asBinary->getOp() == EOpIndexIndirect) + { + TIntermTyped *index = asBinary->getRight(); + substituteArgsOut->push_back(index->deepCopy()); + } + argument = asBinary->getLeft(); + } + + ++nextReplacedArg; + } + } +} + +const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable, + const TFunction *original, + TVector<Argument> *replacedArguments, + VariableReplacementMap *argumentMapOut) +{ + TFunction *substituteFunction = + new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal, + &original->getReturnType(), original->isKnownToNotHaveSideEffects()); + + size_t nextReplacedArg = 0; + for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex) + { + const TVariable *originalParam = original->getParam(paramIndex); + + if (nextReplacedArg >= replacedArguments->size() || + paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex) + { + TVariable *substituteArgument = + new TVariable(symbolTable, originalParam->name(), &originalParam->getType(), + originalParam->symbolType()); + // Not replaced, add an identical parameter. + substituteFunction->addParameter(substituteArgument); + (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument); + } + else + { + TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument; + (*argumentMapOut)[originalParam] = substituteArgument; + + // Iterate over indices of the argument and create a new parameter for every non-const + // index (which may be an expression). Replace the symbol in the argument with a + // variable of the index type. This is later used to replace the parameter in the + // function body. + while (substituteArgument->getAsBinaryNode()) + { + TIntermBinary *asBinary = substituteArgument->getAsBinaryNode(); + if (asBinary->getOp() == EOpIndexIndirect) + { + TIntermTyped *index = asBinary->getRight(); + TType *indexType = new TType(index->getType()); + indexType->setQualifier(EvqParamIn); + + TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType, + SymbolType::AngleInternal); + substituteFunction->addParameter(param); + + // The argument now uses the function parameters as indices. + asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param)); + } + substituteArgument = asBinary->getLeft(); + } + + ++nextReplacedArg; + } + } + + return substituteFunction; +} + +class MonomorphizeTraverser final : public TIntermTraverser +{ + public: + explicit MonomorphizeTraverser(TCompiler *compiler, + TSymbolTable *symbolTable, + const ShCompileOptions &compileOptions, + UnsupportedFunctionArgsBitSet unsupportedFunctionArgs, + FunctionMap *functionMap) + : TIntermTraverser(true, false, false, symbolTable), + mCompiler(compiler), + mCompileOptions(compileOptions), + mUnsupportedFunctionArgs(unsupportedFunctionArgs), + mFunctionMap(functionMap) + {} + + bool visitAggregate(Visit visit, TIntermAggregate *node) override + { + if (node->getOp() != EOpCallFunctionInAST) + { + return true; + } + + const TFunction *function = node->getFunction(); + ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end()); + + FunctionData &data = (*mFunctionMap)[function]; + + TIntermFunctionDefinition *monomorphized = + processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed); + if (monomorphized) + { + data.monomorphizedDefinitions.push_back(monomorphized); + } + + return true; + } + + bool getAnyMonomorphized() const { return mAnyMonomorphized; } + + private: + bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const + { + // Only interested in opaque uniforms and structs that contain samplers. + const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType()); + const bool isStructContainingSamplers = + funcArgument->getType().isStructureContainingSamplers(); + if (!isOpaqueType && !isStructContainingSamplers) + { + return false; + } + + // If not uniform (the variable was itself a function parameter), don't process it in + // this pass, as we don't know which actual uniform it corresponds to. + bool isSamplerInStruct = false; + const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct); + if (uniform == nullptr) + { + return false; + } + + const TType &type = uniform->getType(); + + if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers]) + { + // Monomorphize if the parameter is a structure that contains samplers (so in + // RewriteStructSamplers we don't need to rewrite the functions to accept multiple + // parameters split from the struct). + if (isStructContainingSamplers) + { + return true; + } + } + + if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage]) + { + // Monomorphize if: + // + // - The opaque uniform is a sampler in a struct (which can create an array-of-array + // situation), and the function expects an array of samplers, or + // + // - The opaque uniform is an array of array of sampler or image, and it's partially + // subscripted (i.e. the function itself expects an array) + // + const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray(); + const bool isArrayOfArrayOfSamplerOrImage = + (type.isSampler() || type.isImage()) && type.isArrayOfArrays(); + if (isSamplerInStruct && isParameterArrayOfOpaqueType) + { + return true; + } + if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType) + { + return true; + } + } + + if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter]) + { + if (type.isAtomicCounter()) + { + return true; + } + } + + if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::SamplerCubeEmulation]) + { + // Monomorphize if the opaque uniform is a samplerCube and ES2's cube sampling emulation + // is requested. + if (type.isSamplerCube() && mCompileOptions.emulateSeamfulCubeMapSampling) + { + return true; + } + } + + if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image]) + { + if (type.isImage()) + { + return true; + } + } + + if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage]) + { + if (type.isPixelLocal()) + { + return true; + } + } + + return false; + } + + TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall, + TIntermFunctionDefinition *originalDefinition, + bool *isOriginalUsedOut) + { + const TFunction *function = functionCall->getFunction(); + const TIntermSequence &callArguments = *functionCall->getSequence(); + + TVector<Argument> replacedArguments; + TIntermSequence replacementIndices; + + // Go through function call arguments, and see if any is used in an unsupported way. + for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex) + { + TIntermTyped *callArgument = callArguments[argIndex]->getAsTyped(); + const TVariable *funcArgument = function->getParam(argIndex); + if (isUnsupportedArgument(callArgument, funcArgument)) + { + // Copy the argument and extract the side effects. + TIntermTyped *argument = + ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices); + + replacedArguments.push_back({argIndex, argument}); + } + } + + if (replacedArguments.empty()) + { + *isOriginalUsedOut = true; + return nullptr; + } + + mAnyMonomorphized = true; + + insertStatementsInParentBlock(replacementIndices); + + // Create the arguments for the substitute function call. Done before monomorphizing the + // function, which transforms the arguments to what needs to be replaced in the function + // body. + TIntermSequence newCallArgs; + CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs); + + // Duplicate the function and substitute the replaced arguments with only the non-const + // indices. Additionally, substitute the non-const indices of arguments with the new + // function parameters. + VariableReplacementMap argumentMap; + const TFunction *monomorphized = + MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap); + + // Replace this function call with a call to the new one. + queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs), + OriginalNode::IS_DROPPED); + + // Create a new function definition, with the body of the old function but with the replaced + // parameters substituted with the calling expressions. + TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized); + TIntermBlock *substituteBlock = originalDefinition->getBody()->deepCopy(); + GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap); + bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap); + ASSERT(valid); + + return new TIntermFunctionDefinition(substitutePrototype, substituteBlock); + } + + TCompiler *mCompiler; + const ShCompileOptions &mCompileOptions; + UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs; + bool mAnyMonomorphized = false; + + // Map of original to monomorphized functions. + FunctionMap *mFunctionMap; +}; + +class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser +{ + public: + explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable, + const FunctionMap &functionMap) + : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap) + {} + + void visitFunctionPrototype(TIntermFunctionPrototype *node) override + { + const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr; + if (isInFunctionDefinition) + { + return; + } + + // Add to and possibly replace the function prototype with replacement prototypes. + const TFunction *function = node->getFunction(); + ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end()); + + const FunctionData &data = mFunctionMap.at(function); + + // If nothing to do, leave it be. + if (data.monomorphizedDefinitions.empty()) + { + ASSERT(data.isOriginalUsed); + return; + } + + // Replace the prototype with itself (if function is still used) as well as any + // monomorphized versions. + TIntermSequence replacement; + if (data.isOriginalUsed) + { + replacement.push_back(node); + } + for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions) + { + replacement.push_back(new TIntermFunctionPrototype( + monomorphizedDefinition->getFunctionPrototype()->getFunction())); + } + mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, + std::move(replacement)); + } + + bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override + { + // Add to and possibly replace the function definition with replacement definitions. + const TFunction *function = node->getFunction(); + ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end()); + + const FunctionData &data = mFunctionMap.at(function); + + // If nothing to do, leave it be. + if (data.monomorphizedDefinitions.empty()) + { + ASSERT(data.isOriginalUsed || function->name() == "main"); + return false; + } + + // Replace the definition with itself (if function is still used) as well as any + // monomorphized versions. + TIntermSequence replacement; + if (data.isOriginalUsed) + { + replacement.push_back(node); + } + for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions) + { + replacement.push_back(monomorphizedDefinition); + } + mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, + std::move(replacement)); + + return false; + } + + private: + const FunctionMap &mFunctionMap; +}; + +void SortDeclarations(TIntermBlock *root) +{ + TIntermSequence *original = root->getSequence(); + + TIntermSequence replacement; + TIntermSequence functionDefs; + + // Accumulate non-function-definition declarations in |replacement| and function definitions in + // |functionDefs|. + for (TIntermNode *node : *original) + { + if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode()) + { + functionDefs.push_back(node); + } + else + { + replacement.push_back(node); + } + } + + // Append function definitions to |replacement|. + replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end()); + + // Replace root's sequence with |replacement|. + root->replaceAllChildren(replacement); +} + +bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler, + TIntermBlock *root, + TSymbolTable *symbolTable, + const ShCompileOptions &compileOptions, + UnsupportedFunctionArgsBitSet unsupportedFunctionArgs) +{ + // First, sort out the declarations such that all non-function declarations are placed before + // function definitions. This way when the function is replaced with one that references said + // declarations (i.e. uniforms), the uniform declaration is already present above it. + SortDeclarations(root); + + while (true) + { + FunctionMap functionMap; + InitializeFunctionMap(root, &functionMap); + + MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions, + unsupportedFunctionArgs, &functionMap); + root->traverse(&monomorphizer); + + if (!monomorphizer.getAnyMonomorphized()) + { + break; + } + + if (!monomorphizer.updateTree(compiler, root)) + { + return false; + } + + UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap); + root->traverse(&functionUpdater); + + if (!functionUpdater.updateTree(compiler, root)) + { + return false; + } + } + + return true; +} +} // anonymous namespace + +bool MonomorphizeUnsupportedFunctions(TCompiler *compiler, + TIntermBlock *root, + TSymbolTable *symbolTable, + const ShCompileOptions &compileOptions, + UnsupportedFunctionArgsBitSet unsupportedFunctionArgs) +{ + // This function actually applies multiple transformation, and the AST may not be valid until + // the transformations are entirely done. Some validation is momentarily disabled. + bool enableValidateFunctionCall = compiler->disableValidateFunctionCall(); + + bool result = MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, compileOptions, + unsupportedFunctionArgs); + + compiler->restoreValidateFunctionCall(enableValidateFunctionCall); + return result && compiler->validateAST(root); +} +} // namespace sh |