// // Copyright 2002 The ANGLE Project Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // // The ArrayReturnValueToOutParameter function changes return values of an array type to out // parameters in function definitions, prototypes, and call sites. #include "compiler/translator/tree_ops/d3d/ArrayReturnValueToOutParameter.h" #include #include "compiler/translator/StaticType.h" #include "compiler/translator/SymbolTable.h" #include "compiler/translator/tree_util/IntermNode_util.h" #include "compiler/translator/tree_util/IntermTraverse.h" namespace sh { namespace { constexpr const ImmutableString kReturnValueVariableName("angle_return"); class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser { public: [[nodiscard]] static bool apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable); private: ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable); void visitFunctionPrototype(TIntermFunctionPrototype *node) override; bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitBranch(Visit visit, TIntermBranch *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override; TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall, TIntermTyped *returnValueTarget); // Set when traversal is inside a function with array return value. TIntermFunctionDefinition *mFunctionWithArrayReturnValue; struct ChangedFunction { const TVariable *returnValueVariable; const TFunction *func; }; // Map from function symbol ids to the changed function. std::map mChangedFunctions; }; TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall( TIntermAggregate *originalCall, TIntermTyped *returnValueTarget) { TIntermSequence replacementArguments; TIntermSequence *originalArguments = originalCall->getSequence(); for (auto &arg : *originalArguments) { replacementArguments.push_back(arg); } replacementArguments.push_back(returnValueTarget); ASSERT(originalCall->getFunction()); const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId(); TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall( *mChangedFunctions[originalId.get()].func, &replacementArguments); replacementCall->setLine(originalCall->getLine()); return replacementCall; } bool ArrayReturnValueToOutParameterTraverser::apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable) { ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable); root->traverse(&arrayReturnValueToOutParam); return arrayReturnValueToOutParam.updateTree(compiler, root); } ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser( TSymbolTable *symbolTable) : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr) {} bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition( Visit visit, TIntermFunctionDefinition *node) { if (node->getFunctionPrototype()->isArray() && visit == PreVisit) { // Replacing the function header is done on visitFunctionPrototype(). mFunctionWithArrayReturnValue = node; } if (visit == PostVisit) { mFunctionWithArrayReturnValue = nullptr; } return true; } void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node) { if (node->isArray()) { // Replace the whole prototype node with another node that has the out parameter // added. Also set the function to return void. const TSymbolUniqueId &functionId = node->getFunction()->uniqueId(); if (mChangedFunctions.find(functionId.get()) == mChangedFunctions.end()) { TType *returnValueVariableType = new TType(node->getType()); returnValueVariableType->setQualifier(EvqParamOut); ChangedFunction changedFunction; changedFunction.returnValueVariable = new TVariable(mSymbolTable, kReturnValueVariableName, returnValueVariableType, SymbolType::AngleInternal); TFunction *func = new TFunction(mSymbolTable, node->getFunction()->name(), node->getFunction()->symbolType(), StaticType::GetBasic(), false); for (size_t i = 0; i < node->getFunction()->getParamCount(); ++i) { func->addParameter(node->getFunction()->getParam(i)); } func->addParameter(changedFunction.returnValueVariable); changedFunction.func = func; mChangedFunctions[functionId.get()] = changedFunction; } TIntermFunctionPrototype *replacement = new TIntermFunctionPrototype(mChangedFunctions[functionId.get()].func); replacement->setLine(node->getLine()); queueReplacement(replacement, OriginalNode::IS_DROPPED); } } bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node) { ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction); if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST) { // Handle call sites where the returned array is not assigned. // Examples where f() is a function returning an array: // 1. f(); // 2. another_array == f(); // 3. another_function(f()); // 4. return f(); // Cases 2 to 4 are already converted to simpler cases by // SeparateExpressionsReturningArrays, so we only need to worry about the case where a // function call returning an array forms an expression by itself. TIntermBlock *parentBlock = getParentNode()->getAsBlock(); if (parentBlock) { // replace // f(); // with // type s0[size]; f(s0); TIntermSequence replacements; // type s0[size]; TIntermDeclaration *returnValueDeclaration = nullptr; TVariable *returnValue = DeclareTempVariable(mSymbolTable, new TType(node->getType()), EvqTemporary, &returnValueDeclaration); replacements.push_back(returnValueDeclaration); // f(s0); TIntermSymbol *returnValueSymbol = CreateTempSymbolNode(returnValue); replacements.push_back(createReplacementCall(node, returnValueSymbol)); mMultiReplacements.emplace_back(parentBlock, node, std::move(replacements)); } return false; } return true; } bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node) { if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn) { // Instead of returning a value, assign to the out parameter and then return. TIntermSequence replacements; TIntermTyped *expression = node->getExpression(); ASSERT(expression != nullptr); const TSymbolUniqueId &functionId = mFunctionWithArrayReturnValue->getFunction()->uniqueId(); ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end()); TIntermSymbol *returnValueSymbol = new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable); TIntermBinary *replacementAssignment = new TIntermBinary(EOpAssign, returnValueSymbol, expression); replacementAssignment->setLine(expression->getLine()); replacements.push_back(replacementAssignment); TIntermBranch *replacementBranch = new TIntermBranch(EOpReturn, nullptr); replacementBranch->setLine(node->getLine()); replacements.push_back(replacementBranch); mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, std::move(replacements)); } return false; } bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node) { if (node->getOp() == EOpAssign && node->getLeft()->isArray()) { TIntermAggregate *rightAgg = node->getRight()->getAsAggregate(); ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction); if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST) { TIntermAggregate *replacementCall = createReplacementCall(rightAgg, node->getLeft()); queueReplacement(replacementCall, OriginalNode::IS_DROPPED); } } return false; } } // namespace bool ArrayReturnValueToOutParameter(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable) { return ArrayReturnValueToOutParameterTraverser::apply(compiler, root, symbolTable); } } // namespace sh