summaryrefslogtreecommitdiffstats
path: root/gfx/angle/checkout/src/compiler/translator/tree_util/RunAtTheEndOfShader.cpp
blob: daf99185ab2b15f276a299c600b4c9b3fc51ebb8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//
// Copyright 2017 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
// return statement, this is done by replacing the main() function with another function that calls
// the old main, like this:
//
// void main() { body }
// =>
// void main0() { body }
// void main()
// {
//     main0();
//     codeToRun
// }
//
// This way the code will get run even if the return statement inside main is executed.
//
// This is done if main ends in an unconditional |discard| as well, to help with SPIR-V generation
// that expects no dead-code to be present after branches in a block.  To avoid bugs when |discard|
// is wrapped in unconditional blocks, any |discard| in main() is used as a signal to wrap it.
//

#include "compiler/translator/tree_util/RunAtTheEndOfShader.h"

#include "compiler/translator/Compiler.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/FindMain.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"

namespace sh
{

namespace
{

// Name given to the replacement main() function that WrapMainAndAppend() creates.
constexpr const ImmutableString kMainString("main");

// Traverser that remembers whether any branch node it visited was a |return| (EOpReturn) or a
// |discard| (EOpKill).  The result is queried through containsReturnOrDiscard() once the
// traversal is done.
class ContainsReturnOrDiscardTraverser : public TIntermTraverser
{
  public:
    // NOTE(review): the three base-class flags presumably select pre-visit only -- confirm
    // against the TIntermTraverser constructor.
    ContainsReturnOrDiscardTraverser() : TIntermTraverser(true, false, false) {}

    bool visitBranch(Visit visit, TIntermBranch *node) override
    {
        const auto flowOp            = node->getFlowOp();
        const bool isReturnOrDiscard = (flowOp == EOpReturn) || (flowOp == EOpKill);
        if (isReturnOrDiscard)
        {
            // The flag is sticky: once set it is never cleared by later branches.
            mContainsReturnOrDiscard = true;
        }

        // NOTE(review): returning false presumably tells the traverser not to descend into the
        // branch's children -- confirm against TIntermTraverser.
        return false;
    }

    // True if at least one |return| or |discard| branch has been visited so far.
    bool containsReturnOrDiscard() { return mContainsReturnOrDiscard; }

  private:
    bool mContainsReturnOrDiscard = false;
};

// Runs a ContainsReturnOrDiscardTraverser over |node| and reports whether it encountered a
// |return| or |discard| branch.  |node| is dereferenced unconditionally, so it must not be null.
bool ContainsReturnOrDiscard(TIntermNode *node)
{
    ContainsReturnOrDiscardTraverser finder;
    node->traverse(&finder);

    const bool found = finder.containsReturnOrDiscard();
    return found;
}

// Moves the body of |main| into a new internal function, and appends a new main() to |root| that
// first calls that internal function and then executes |codeToRun|:
//
//     void main0() { <original body of main> }
//     void main()
//     {
//         main0();
//         codeToRun
//     }
//
// "main0" is only illustrative: the internal function is created with an empty name and
// SymbolType::AngleInternal.
//
// root:        block in which |main| is replaced and to which the new main() is appended.
// main:        the existing definition of main(); it must be a child of |root|.
// codeToRun:   statement appended after the call to the original body.
// symbolTable: symbol table passed to the two newly created TFunction objects.
void WrapMainAndAppend(TIntermBlock *root,
                       TIntermFunctionDefinition *main,
                       TIntermNode *codeToRun,
                       TSymbolTable *symbolTable)
{
    // Internal void function that takes over the body of the original main().
    TFunction *wrappedMain =
        new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
                      StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
    TIntermFunctionDefinition *wrappedMainDefinition =
        CreateInternalFunctionDefinitionNode(*wrappedMain, main->getBody());

    // Put the internal function where the original main() definition used to be in the root.
    const bool didReplace = root->replaceChildNode(main, wrappedMainDefinition);
    ASSERT(didReplace);

    // The replacement entry point: void main()
    TFunction *mainWrapper = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
                                           StaticType::GetBasic<EbtVoid, EbpUndefined>(), false);
    TIntermFunctionPrototype *mainWrapperProto = new TIntermFunctionPrototype(mainWrapper);

    // Its body: a call to the internal function (no arguments) followed by |codeToRun|.
    TIntermBlock *mainWrapperBody = new TIntermBlock();
    TIntermSequence noArguments;
    mainWrapperBody->appendStatement(
        TIntermAggregate::CreateFunctionCall(*wrappedMain, &noArguments));
    mainWrapperBody->appendStatement(codeToRun);

    // The new main() goes at the end of the root, i.e. after the internal function's definition.
    root->appendStatement(new TIntermFunctionDefinition(mainWrapperProto, mainWrapperBody));
}

}  // anonymous namespace

// Arranges for |codeToRun| to execute at the end of the shader's main().
//
// If main() contains no |return| or |discard|, |codeToRun| is simply appended to its body.
// Otherwise main() is wrapped (see WrapMainAndAppend) so that |codeToRun| runs after the original
// body regardless of how that body exits.
//
// Returns the result of compiler->validateAST(root) on the modified tree.
//
// NOTE(review): the result of FindMain() is used without a null check, so this assumes |root|
// always contains a main() definition -- confirm against callers.
bool RunAtTheEndOfShader(TCompiler *compiler,
                         TIntermBlock *root,
                         TIntermNode *codeToRun,
                         TSymbolTable *symbolTable)
{
    TIntermFunctionDefinition *mainDefinition = FindMain(root);

    const bool needsWrapping = ContainsReturnOrDiscard(mainDefinition);
    if (!needsWrapping)
    {
        // No early exit possible: the end of the body is always reached.
        mainDefinition->getBody()->appendStatement(codeToRun);
    }
    else
    {
        WrapMainAndAppend(root, mainDefinition, codeToRun, symbolTable);
    }

    return compiler->validateAST(root);
}

}  // namespace sh