1//
2// Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
7// return statement, this is done by replacing the main() function with another function that calls
8// the old main, like this:
9//
10// void main() { body }
11// =>
12// void main0() { body }
13// void main()
14// {
15// main0();
16// codeToRun
17// }
18//
19// This way the code will get run even if the return statement inside main is executed.
20//
21
22#include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
23
24#include "compiler/translator/IntermNode.h"
25#include "compiler/translator/StaticType.h"
26#include "compiler/translator/SymbolTable.h"
27#include "compiler/translator/tree_util/FindMain.h"
28#include "compiler/translator/tree_util/IntermNode_util.h"
29#include "compiler/translator/tree_util/IntermTraverse.h"
30
31namespace sh
32{
33
34namespace
35{
36
37constexpr const ImmutableString kMainString("main");
38
39class ContainsReturnTraverser : public TIntermTraverser
40{
41 public:
42 ContainsReturnTraverser() : TIntermTraverser(true, false, false), mContainsReturn(false) {}
43
44 bool visitBranch(Visit visit, TIntermBranch *node) override
45 {
46 if (node->getFlowOp() == EOpReturn)
47 {
48 mContainsReturn = true;
49 }
50 return false;
51 }
52
53 bool containsReturn() { return mContainsReturn; }
54
55 private:
56 bool mContainsReturn;
57};
58
59bool ContainsReturn(TIntermNode *node)
60{
61 ContainsReturnTraverser traverser;
62 node->traverse(&traverser);
63 return traverser.containsReturn();
64}
65
66void WrapMainAndAppend(TIntermBlock *root,
67 TIntermFunctionDefinition *main,
68 TIntermNode *codeToRun,
69 TSymbolTable *symbolTable)
70{
71 // Replace main() with main0() with the same body.
72 TFunction *oldMain =
73 new TFunction(symbolTable, kEmptyImmutableString, SymbolType::AngleInternal,
74 StaticType::GetBasic<EbtVoid>(), false);
75 TIntermFunctionDefinition *oldMainDefinition =
76 CreateInternalFunctionDefinitionNode(*oldMain, main->getBody());
77
78 bool replaced = root->replaceChildNode(main, oldMainDefinition);
79 ASSERT(replaced);
80
81 // void main()
82 TFunction *newMain = new TFunction(symbolTable, kMainString, SymbolType::UserDefined,
83 StaticType::GetBasic<EbtVoid>(), false);
84 TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(newMain);
85
86 // {
87 // main0();
88 // codeToRun
89 // }
90 TIntermBlock *newMainBody = new TIntermBlock();
91 TIntermAggregate *oldMainCall =
92 TIntermAggregate::CreateFunctionCall(*oldMain, new TIntermSequence());
93 newMainBody->appendStatement(oldMainCall);
94 newMainBody->appendStatement(codeToRun);
95
96 // Add the new main() to the root node.
97 TIntermFunctionDefinition *newMainDefinition =
98 new TIntermFunctionDefinition(newMainProto, newMainBody);
99 root->appendStatement(newMainDefinition);
100}
101
102} // anonymous namespace
103
104void RunAtTheEndOfShader(TIntermBlock *root, TIntermNode *codeToRun, TSymbolTable *symbolTable)
105{
106 TIntermFunctionDefinition *main = FindMain(root);
107 if (!ContainsReturn(main))
108 {
109 main->getBody()->appendStatement(codeToRun);
110 return;
111 }
112
113 WrapMainAndAppend(root, main, codeToRun, symbolTable);
114}
115
116} // namespace sh
117