diff options
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_util/RunAtTheEndOfShader.cpp')
-rw-r--r-- | gfx/angle/checkout/src/compiler/translator/tree_util/RunAtTheEndOfShader.cpp | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_util/RunAtTheEndOfShader.cpp b/gfx/angle/checkout/src/compiler/translator/tree_util/RunAtTheEndOfShader.cpp new file mode 100644 index 0000000000..daf99185ab --- /dev/null +++ b/gfx/angle/checkout/src/compiler/translator/tree_util/RunAtTheEndOfShader.cpp @@ -0,0 +1,129 @@ +// +// Copyright 2017 The ANGLE Project Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a +// return statement, this is done by replacing the main() function with another function that calls +// the old main, like this: +// +// void main() { body } +// => +// void main0() { body } +// void main() +// { +// main0(); +// codeToRun +// } +// +// This way the code will get run even if the return statement inside main is executed. +// +// This is done if main ends in an unconditional |discard| as well, to help with SPIR-V generation +// that expects no dead-code to be present after branches in a block. To avoid bugs when |discard| +// is wrapped in unconditional blocks, any |discard| in main() is used as a signal to wrap it. +// + +#include "compiler/translator/tree_util/RunAtTheEndOfShader.h" + +#include "compiler/translator/Compiler.h" +#include "compiler/translator/IntermNode.h" +#include "compiler/translator/StaticType.h" +#include "compiler/translator/SymbolTable.h" +#include "compiler/translator/tree_util/FindMain.h" +#include "compiler/translator/tree_util/IntermNode_util.h" +#include "compiler/translator/tree_util/IntermTraverse.h" + +namespace sh +{ + +namespace +{ + +constexpr const ImmutableString kMainString("main"); + +class ContainsReturnOrDiscardTraverser : public TIntermTraverser +{ + public: + ContainsReturnOrDiscardTraverser() + : TIntermTraverser(true, false, false), mContainsReturnOrDiscard(false) + {} + + bool visitBranch(Visit visit, TIntermBranch *node) override + { + if (node->getFlowOp() == EOpReturn || node->getFlowOp() == EOpKill) + { + mContainsReturnOrDiscard = true; + } + return false; + } + + bool containsReturnOrDiscard() { return mContainsReturnOrDiscard; } + + private: + bool mContainsReturnOrDiscard; +}; + +bool ContainsReturnOrDiscard(TIntermNode *node) +{ + ContainsReturnOrDiscardTraverser traverser; + node->traverse(&traverser); + return traverser.containsReturnOrDiscard(); +} + +void WrapMainAndAppend(TIntermBlock *root, + TIntermFunctionDefinition *main, + TIntermNode *codeToRun, + TSymbolTable *symbolTable) +{ + // Replace main() with main0() with the same body. + TFunction *oldMain = + new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal, + StaticType::GetBasic<EbtVoid, EbpUndefined>(), false); + TIntermFunctionDefinition *oldMainDefinition = + CreateInternalFunctionDefinitionNode(*oldMain, main->getBody()); + + bool replaced = root->replaceChildNode(main, oldMainDefinition); + ASSERT(replaced); + + // void main() + TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined, + StaticType::GetBasic<EbtVoid, EbpUndefined>(), false); + TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain); + + // { + // main0(); + // codeToRun + // } + TIntermBlock *newMainBody = new TIntermBlock(); + TIntermSequence emptySequence; + TIntermAggregate *oldMainCall = TIntermAggregate::CreateFunctionCall(*oldMain, &emptySequence); + newMainBody->appendStatement(oldMainCall); + newMainBody->appendStatement(codeToRun); + + // Add the new main() to the root node. + TIntermFunctionDefinition *newMainDefinition = + new TIntermFunctionDefinition(newMainProto, newMainBody); + root->appendStatement(newMainDefinition); +} + +} // anonymous namespace + +bool RunAtTheEndOfShader(TCompiler *compiler, + TIntermBlock *root, + TIntermNode *codeToRun, + TSymbolTable *symbolTable) +{ + TIntermFunctionDefinition *main = FindMain(root); + if (ContainsReturnOrDiscard(main)) + { + WrapMainAndAppend(root, main, codeToRun, symbolTable); + } + else + { + main->getBody()->appendStatement(codeToRun); + } + + return compiler->validateAST(root); +} + +} // namespace sh |