summaryrefslogtreecommitdiffstats
path: root/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp')
-rw-r--r--gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp613
1 files changed, 613 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp b/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp
new file mode 100644
index 0000000000..11c8b72002
--- /dev/null
+++ b/gfx/angle/checkout/src/compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.cpp
@@ -0,0 +1,613 @@
+//
+// Copyright 2021 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// MonomorphizeUnsupportedFunctions: Monomorphize functions that are called with
+// parameters that are incompatible with both Vulkan GLSL and Metal.
+//
+
+#include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
+
+#include "compiler/translator/ImmutableStringBuilder.h"
+#include "compiler/translator/SymbolTable.h"
+#include "compiler/translator/tree_util/IntermNode_util.h"
+#include "compiler/translator/tree_util/IntermTraverse.h"
+#include "compiler/translator/tree_util/ReplaceVariable.h"
+
+namespace sh
+{
+namespace
+{
+struct Argument
+{
+ size_t argumentIndex;
+ TIntermTyped *argument;
+};
+
+struct FunctionData
+{
+ // Whether the original function is used. If this is false, the function can be removed because
+ // all callers have been modified.
+ bool isOriginalUsed;
+ // The original definition of the function, used to create the monomorphized version.
+ TIntermFunctionDefinition *originalDefinition;
+ // List of monomorphized versions of this function. They will be added next to the original
+ // version (or replace it).
+ TVector<TIntermFunctionDefinition *> monomorphizedDefinitions;
+};
+
+using FunctionMap = angle::HashMap<const TFunction *, FunctionData>;
+
+// Traverse the function definitions and initialize the map. Allows visitAggregate to have access
+// to TIntermFunctionDefinition even when the function is only forward declared at that point.
+void InitializeFunctionMap(TIntermBlock *root, FunctionMap *functionMapOut)
+{
+ TIntermSequence &sequence = *root->getSequence();
+
+ for (TIntermNode *node : sequence)
+ {
+ TIntermFunctionDefinition *asFuncDef = node->getAsFunctionDefinition();
+ if (asFuncDef != nullptr)
+ {
+ const TFunction *function = asFuncDef->getFunction();
+ ASSERT(function && functionMapOut->find(function) == functionMapOut->end());
+ (*functionMapOut)[function] = FunctionData{false, asFuncDef, {}};
+ }
+ }
+}
+
+const TVariable *GetBaseUniform(TIntermTyped *node, bool *isSamplerInStructOut)
+{
+ *isSamplerInStructOut = false;
+
+ while (node->getAsBinaryNode())
+ {
+ TIntermBinary *asBinary = node->getAsBinaryNode();
+
+ TOperator op = asBinary->getOp();
+
+ // No opaque uniform can be inside an interface block.
+ if (op == EOpIndexDirectInterfaceBlock)
+ {
+ return nullptr;
+ }
+
+ if (op == EOpIndexDirectStruct)
+ {
+ *isSamplerInStructOut = true;
+ }
+
+ node = asBinary->getLeft();
+ }
+
+ // Only interested in uniform opaque types. If a function call within another function uses
+ // opaque uniforms in an unsupported way, it will be replaced in a follow up pass after the
+ // calling function is monomorphized.
+ if (node->getType().getQualifier() != EvqUniform)
+ {
+ return nullptr;
+ }
+
+ ASSERT(IsOpaqueType(node->getType().getBasicType()) ||
+ node->getType().isStructureContainingSamplers());
+
+ TIntermSymbol *asSymbol = node->getAsSymbolNode();
+ ASSERT(asSymbol);
+
+ return &asSymbol->variable();
+}
+
+TIntermTyped *ExtractSideEffects(TSymbolTable *symbolTable,
+ TIntermTyped *node,
+ TIntermSequence *replacementIndices)
+{
+ TIntermTyped *withoutSideEffects = node->deepCopy();
+
+ for (TIntermBinary *asBinary = withoutSideEffects->getAsBinaryNode(); asBinary;
+ asBinary = asBinary->getLeft()->getAsBinaryNode())
+ {
+ TOperator op = asBinary->getOp();
+ TIntermTyped *index = asBinary->getRight();
+
+ if (op == EOpIndexDirectStruct)
+ {
+ break;
+ }
+
+ // No side effects with constant expressions.
+ if (op == EOpIndexDirect)
+ {
+ ASSERT(index->getAsConstantUnion());
+ continue;
+ }
+
+ ASSERT(op == EOpIndexIndirect);
+
+ // If the index is a symbol, there's no side effect, so leave it as-is.
+ if (index->getAsSymbolNode())
+ {
+ continue;
+ }
+
+ // Otherwise create a temp variable initialized with the index and use that temp variable as
+ // the index.
+ TIntermDeclaration *tempDecl = nullptr;
+ TVariable *tempVar = DeclareTempVariable(symbolTable, index, EvqTemporary, &tempDecl);
+
+ replacementIndices->push_back(tempDecl);
+ asBinary->replaceChildNode(index, new TIntermSymbol(tempVar));
+ }
+
+ return withoutSideEffects;
+}
+
+void CreateMonomorphizedFunctionCallArgs(const TIntermSequence &originalCallArguments,
+ const TVector<Argument> &replacedArguments,
+ TIntermSequence *substituteArgsOut)
+{
+ size_t nextReplacedArg = 0;
+ for (size_t argIndex = 0; argIndex < originalCallArguments.size(); ++argIndex)
+ {
+ if (nextReplacedArg >= replacedArguments.size() ||
+ argIndex != replacedArguments[nextReplacedArg].argumentIndex)
+ {
+ // Not replaced, keep argument as is.
+ substituteArgsOut->push_back(originalCallArguments[argIndex]);
+ }
+ else
+ {
+ TIntermTyped *argument = replacedArguments[nextReplacedArg].argument;
+
+ // Iterate over indices of the argument and create a new arg for every non-const
+ // index. Note that the index itself may be an expression, and it may require further
+ // substitution in the next pass.
+ while (argument->getAsBinaryNode())
+ {
+ TIntermBinary *asBinary = argument->getAsBinaryNode();
+ if (asBinary->getOp() == EOpIndexIndirect)
+ {
+ TIntermTyped *index = asBinary->getRight();
+ substituteArgsOut->push_back(index->deepCopy());
+ }
+ argument = asBinary->getLeft();
+ }
+
+ ++nextReplacedArg;
+ }
+ }
+}
+
+const TFunction *MonomorphizeFunction(TSymbolTable *symbolTable,
+ const TFunction *original,
+ TVector<Argument> *replacedArguments,
+ VariableReplacementMap *argumentMapOut)
+{
+ TFunction *substituteFunction =
+ new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
+ &original->getReturnType(), original->isKnownToNotHaveSideEffects());
+
+ size_t nextReplacedArg = 0;
+ for (size_t paramIndex = 0; paramIndex < original->getParamCount(); ++paramIndex)
+ {
+ const TVariable *originalParam = original->getParam(paramIndex);
+
+ if (nextReplacedArg >= replacedArguments->size() ||
+ paramIndex != (*replacedArguments)[nextReplacedArg].argumentIndex)
+ {
+ TVariable *substituteArgument =
+ new TVariable(symbolTable, originalParam->name(), &originalParam->getType(),
+ originalParam->symbolType());
+ // Not replaced, add an identical parameter.
+ substituteFunction->addParameter(substituteArgument);
+ (*argumentMapOut)[originalParam] = new TIntermSymbol(substituteArgument);
+ }
+ else
+ {
+ TIntermTyped *substituteArgument = (*replacedArguments)[nextReplacedArg].argument;
+ (*argumentMapOut)[originalParam] = substituteArgument;
+
+ // Iterate over indices of the argument and create a new parameter for every non-const
+ // index (which may be an expression). Replace the symbol in the argument with a
+ // variable of the index type. This is later used to replace the parameter in the
+ // function body.
+ while (substituteArgument->getAsBinaryNode())
+ {
+ TIntermBinary *asBinary = substituteArgument->getAsBinaryNode();
+ if (asBinary->getOp() == EOpIndexIndirect)
+ {
+ TIntermTyped *index = asBinary->getRight();
+ TType *indexType = new TType(index->getType());
+ indexType->setQualifier(EvqParamIn);
+
+ TVariable *param = new TVariable(symbolTable, kEmptyImmutableString, indexType,
+ SymbolType::AngleInternal);
+ substituteFunction->addParameter(param);
+
+ // The argument now uses the function parameters as indices.
+ asBinary->replaceChildNode(asBinary->getRight(), new TIntermSymbol(param));
+ }
+ substituteArgument = asBinary->getLeft();
+ }
+
+ ++nextReplacedArg;
+ }
+ }
+
+ return substituteFunction;
+}
+
+class MonomorphizeTraverser final : public TIntermTraverser
+{
+ public:
+ explicit MonomorphizeTraverser(TCompiler *compiler,
+ TSymbolTable *symbolTable,
+ const ShCompileOptions &compileOptions,
+ UnsupportedFunctionArgsBitSet unsupportedFunctionArgs,
+ FunctionMap *functionMap)
+ : TIntermTraverser(true, false, false, symbolTable),
+ mCompiler(compiler),
+ mCompileOptions(compileOptions),
+ mUnsupportedFunctionArgs(unsupportedFunctionArgs),
+ mFunctionMap(functionMap)
+ {}
+
+ bool visitAggregate(Visit visit, TIntermAggregate *node) override
+ {
+ if (node->getOp() != EOpCallFunctionInAST)
+ {
+ return true;
+ }
+
+ const TFunction *function = node->getFunction();
+ ASSERT(function && mFunctionMap->find(function) != mFunctionMap->end());
+
+ FunctionData &data = (*mFunctionMap)[function];
+
+ TIntermFunctionDefinition *monomorphized =
+ processFunctionCall(node, data.originalDefinition, &data.isOriginalUsed);
+ if (monomorphized)
+ {
+ data.monomorphizedDefinitions.push_back(monomorphized);
+ }
+
+ return true;
+ }
+
+ bool getAnyMonomorphized() const { return mAnyMonomorphized; }
+
+ private:
+ bool isUnsupportedArgument(TIntermTyped *callArgument, const TVariable *funcArgument) const
+ {
+ // Only interested in opaque uniforms and structs that contain samplers.
+ const bool isOpaqueType = IsOpaqueType(funcArgument->getType().getBasicType());
+ const bool isStructContainingSamplers =
+ funcArgument->getType().isStructureContainingSamplers();
+ if (!isOpaqueType && !isStructContainingSamplers)
+ {
+ return false;
+ }
+
+ // If not uniform (the variable was itself a function parameter), don't process it in
+ // this pass, as we don't know which actual uniform it corresponds to.
+ bool isSamplerInStruct = false;
+ const TVariable *uniform = GetBaseUniform(callArgument, &isSamplerInStruct);
+ if (uniform == nullptr)
+ {
+ return false;
+ }
+
+ const TType &type = uniform->getType();
+
+ if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::StructContainingSamplers])
+ {
+ // Monomorphize if the parameter is a structure that contains samplers (so in
+ // RewriteStructSamplers we don't need to rewrite the functions to accept multiple
+ // parameters split from the struct).
+ if (isStructContainingSamplers)
+ {
+ return true;
+ }
+ }
+
+ if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage])
+ {
+ // Monomorphize if:
+ //
+ // - The opaque uniform is a sampler in a struct (which can create an array-of-array
+ // situation), and the function expects an array of samplers, or
+ //
+ // - The opaque uniform is an array of array of sampler or image, and it's partially
+ // subscripted (i.e. the function itself expects an array)
+ //
+ const bool isParameterArrayOfOpaqueType = funcArgument->getType().isArray();
+ const bool isArrayOfArrayOfSamplerOrImage =
+ (type.isSampler() || type.isImage()) && type.isArrayOfArrays();
+ if (isSamplerInStruct && isParameterArrayOfOpaqueType)
+ {
+ return true;
+ }
+ if (isArrayOfArrayOfSamplerOrImage && isParameterArrayOfOpaqueType)
+ {
+ return true;
+ }
+ }
+
+ if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::AtomicCounter])
+ {
+ if (type.isAtomicCounter())
+ {
+ return true;
+ }
+ }
+
+ if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::SamplerCubeEmulation])
+ {
+ // Monomorphize if the opaque uniform is a samplerCube and ES2's cube sampling emulation
+ // is requested.
+ if (type.isSamplerCube() && mCompileOptions.emulateSeamfulCubeMapSampling)
+ {
+ return true;
+ }
+ }
+
+ if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::Image])
+ {
+ if (type.isImage())
+ {
+ return true;
+ }
+ }
+
+ if (mUnsupportedFunctionArgs[UnsupportedFunctionArgs::PixelLocalStorage])
+ {
+ if (type.isPixelLocal())
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ TIntermFunctionDefinition *processFunctionCall(TIntermAggregate *functionCall,
+ TIntermFunctionDefinition *originalDefinition,
+ bool *isOriginalUsedOut)
+ {
+ const TFunction *function = functionCall->getFunction();
+ const TIntermSequence &callArguments = *functionCall->getSequence();
+
+ TVector<Argument> replacedArguments;
+ TIntermSequence replacementIndices;
+
+ // Go through function call arguments, and see if any is used in an unsupported way.
+ for (size_t argIndex = 0; argIndex < callArguments.size(); ++argIndex)
+ {
+ TIntermTyped *callArgument = callArguments[argIndex]->getAsTyped();
+ const TVariable *funcArgument = function->getParam(argIndex);
+ if (isUnsupportedArgument(callArgument, funcArgument))
+ {
+ // Copy the argument and extract the side effects.
+ TIntermTyped *argument =
+ ExtractSideEffects(mSymbolTable, callArgument, &replacementIndices);
+
+ replacedArguments.push_back({argIndex, argument});
+ }
+ }
+
+ if (replacedArguments.empty())
+ {
+ *isOriginalUsedOut = true;
+ return nullptr;
+ }
+
+ mAnyMonomorphized = true;
+
+ insertStatementsInParentBlock(replacementIndices);
+
+ // Create the arguments for the substitute function call. Done before monomorphizing the
+ // function, which transforms the arguments to what needs to be replaced in the function
+ // body.
+ TIntermSequence newCallArgs;
+ CreateMonomorphizedFunctionCallArgs(callArguments, replacedArguments, &newCallArgs);
+
+ // Duplicate the function and substitute the replaced arguments with only the non-const
+ // indices. Additionally, substitute the non-const indices of arguments with the new
+ // function parameters.
+ VariableReplacementMap argumentMap;
+ const TFunction *monomorphized =
+ MonomorphizeFunction(mSymbolTable, function, &replacedArguments, &argumentMap);
+
+ // Replace this function call with a call to the new one.
+ queueReplacement(TIntermAggregate::CreateFunctionCall(*monomorphized, &newCallArgs),
+ OriginalNode::IS_DROPPED);
+
+ // Create a new function definition, with the body of the old function but with the replaced
+ // parameters substituted with the calling expressions.
+ TIntermFunctionPrototype *substitutePrototype = new TIntermFunctionPrototype(monomorphized);
+ TIntermBlock *substituteBlock = originalDefinition->getBody()->deepCopy();
+ GetDeclaratorReplacements(mSymbolTable, substituteBlock, &argumentMap);
+ bool valid = ReplaceVariables(mCompiler, substituteBlock, argumentMap);
+ ASSERT(valid);
+
+ return new TIntermFunctionDefinition(substitutePrototype, substituteBlock);
+ }
+
+ TCompiler *mCompiler;
+ const ShCompileOptions &mCompileOptions;
+ UnsupportedFunctionArgsBitSet mUnsupportedFunctionArgs;
+ bool mAnyMonomorphized = false;
+
+ // Map of original to monomorphized functions.
+ FunctionMap *mFunctionMap;
+};
+
+class UpdateFunctionsDefinitionsTraverser final : public TIntermTraverser
+{
+ public:
+ explicit UpdateFunctionsDefinitionsTraverser(TSymbolTable *symbolTable,
+ const FunctionMap &functionMap)
+ : TIntermTraverser(true, false, false, symbolTable), mFunctionMap(functionMap)
+ {}
+
+ void visitFunctionPrototype(TIntermFunctionPrototype *node) override
+ {
+ const bool isInFunctionDefinition = getParentNode()->getAsFunctionDefinition() != nullptr;
+ if (isInFunctionDefinition)
+ {
+ return;
+ }
+
+ // Add to and possibly replace the function prototype with replacement prototypes.
+ const TFunction *function = node->getFunction();
+ ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
+
+ const FunctionData &data = mFunctionMap.at(function);
+
+ // If nothing to do, leave it be.
+ if (data.monomorphizedDefinitions.empty())
+ {
+ ASSERT(data.isOriginalUsed);
+ return;
+ }
+
+ // Replace the prototype with itself (if function is still used) as well as any
+ // monomorphized versions.
+ TIntermSequence replacement;
+ if (data.isOriginalUsed)
+ {
+ replacement.push_back(node);
+ }
+ for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
+ {
+ replacement.push_back(new TIntermFunctionPrototype(
+ monomorphizedDefinition->getFunctionPrototype()->getFunction()));
+ }
+ mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
+ std::move(replacement));
+ }
+
+ bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
+ {
+ // Add to and possibly replace the function definition with replacement definitions.
+ const TFunction *function = node->getFunction();
+ ASSERT(function && mFunctionMap.find(function) != mFunctionMap.end());
+
+ const FunctionData &data = mFunctionMap.at(function);
+
+ // If nothing to do, leave it be.
+ if (data.monomorphizedDefinitions.empty())
+ {
+ ASSERT(data.isOriginalUsed || function->name() == "main");
+ return false;
+ }
+
+ // Replace the definition with itself (if function is still used) as well as any
+ // monomorphized versions.
+ TIntermSequence replacement;
+ if (data.isOriginalUsed)
+ {
+ replacement.push_back(node);
+ }
+ for (TIntermFunctionDefinition *monomorphizedDefinition : data.monomorphizedDefinitions)
+ {
+ replacement.push_back(monomorphizedDefinition);
+ }
+ mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
+ std::move(replacement));
+
+ return false;
+ }
+
+ private:
+ const FunctionMap &mFunctionMap;
+};
+
+void SortDeclarations(TIntermBlock *root)
+{
+ TIntermSequence *original = root->getSequence();
+
+ TIntermSequence replacement;
+ TIntermSequence functionDefs;
+
+ // Accumulate non-function-definition declarations in |replacement| and function definitions in
+ // |functionDefs|.
+ for (TIntermNode *node : *original)
+ {
+ if (node->getAsFunctionDefinition() || node->getAsFunctionPrototypeNode())
+ {
+ functionDefs.push_back(node);
+ }
+ else
+ {
+ replacement.push_back(node);
+ }
+ }
+
+ // Append function definitions to |replacement|.
+ replacement.insert(replacement.end(), functionDefs.begin(), functionDefs.end());
+
+ // Replace root's sequence with |replacement|.
+ root->replaceAllChildren(replacement);
+}
+
+bool MonomorphizeUnsupportedFunctionsImpl(TCompiler *compiler,
+ TIntermBlock *root,
+ TSymbolTable *symbolTable,
+ const ShCompileOptions &compileOptions,
+ UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
+{
+ // First, sort out the declarations such that all non-function declarations are placed before
+ // function definitions. This way when the function is replaced with one that references said
+ // declarations (i.e. uniforms), the uniform declaration is already present above it.
+ SortDeclarations(root);
+
+ while (true)
+ {
+ FunctionMap functionMap;
+ InitializeFunctionMap(root, &functionMap);
+
+ MonomorphizeTraverser monomorphizer(compiler, symbolTable, compileOptions,
+ unsupportedFunctionArgs, &functionMap);
+ root->traverse(&monomorphizer);
+
+ if (!monomorphizer.getAnyMonomorphized())
+ {
+ break;
+ }
+
+ if (!monomorphizer.updateTree(compiler, root))
+ {
+ return false;
+ }
+
+ UpdateFunctionsDefinitionsTraverser functionUpdater(symbolTable, functionMap);
+ root->traverse(&functionUpdater);
+
+ if (!functionUpdater.updateTree(compiler, root))
+ {
+ return false;
+ }
+ }
+
+ return true;
+}
+} // anonymous namespace
+
+bool MonomorphizeUnsupportedFunctions(TCompiler *compiler,
+ TIntermBlock *root,
+ TSymbolTable *symbolTable,
+ const ShCompileOptions &compileOptions,
+ UnsupportedFunctionArgsBitSet unsupportedFunctionArgs)
+{
+ // This function actually applies multiple transformation, and the AST may not be valid until
+ // the transformations are entirely done. Some validation is momentarily disabled.
+ bool enableValidateFunctionCall = compiler->disableValidateFunctionCall();
+
+ bool result = MonomorphizeUnsupportedFunctionsImpl(compiler, root, symbolTable, compileOptions,
+ unsupportedFunctionArgs);
+
+ compiler->restoreValidateFunctionCall(enableValidateFunctionCall);
+ return result && compiler->validateAST(root);
+}
+} // namespace sh