1//
2// Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// IntermTraverse.h : base classes for AST traversers that walk the AST and
7// also have the ability to transform it by replacing nodes.
8
9#ifndef COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
10#define COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
11
12#include "compiler/translator/IntermNode.h"
13#include "compiler/translator/tree_util/Visit.h"
14
15namespace sh
16{
17
// Forward declarations so that this header can name these types without pulling in their
// full definitions.
class TSymbolUniqueId;
class TSymbolTable;
20
21// For traversing the tree. User should derive from this class overriding the visit functions,
22// and then pass an object of the subclass to a traverse method of a node.
23//
24// The traverse*() functions may also be overridden to do other bookkeeping on the tree to provide
25// contextual information to the visit functions, such as whether the node is the target of an
26// assignment. This is complex to maintain and so should only be done in special cases.
27//
28// When using this, just fill in the methods for nodes you want visited.
29// Return false from a pre-visit to skip visiting that node's subtree.
30//
31// See also how to write AST transformations documentation:
32// https://github.com/google/angle/blob/master/doc/WritingShaderASTTransformations.md
33class TIntermTraverser : angle::NonCopyable
34{
35 public:
36 POOL_ALLOCATOR_NEW_DELETE
37 TIntermTraverser(bool preVisit,
38 bool inVisit,
39 bool postVisit,
40 TSymbolTable *symbolTable = nullptr);
41 virtual ~TIntermTraverser();
42
43 virtual void visitSymbol(TIntermSymbol *node) {}
44 virtual void visitConstantUnion(TIntermConstantUnion *node) {}
45 virtual bool visitSwizzle(Visit visit, TIntermSwizzle *node) { return true; }
46 virtual bool visitBinary(Visit visit, TIntermBinary *node) { return true; }
47 virtual bool visitUnary(Visit visit, TIntermUnary *node) { return true; }
48 virtual bool visitTernary(Visit visit, TIntermTernary *node) { return true; }
49 virtual bool visitIfElse(Visit visit, TIntermIfElse *node) { return true; }
50 virtual bool visitSwitch(Visit visit, TIntermSwitch *node) { return true; }
51 virtual bool visitCase(Visit visit, TIntermCase *node) { return true; }
52 virtual void visitFunctionPrototype(TIntermFunctionPrototype *node) {}
53 virtual bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
54 {
55 return true;
56 }
57 virtual bool visitAggregate(Visit visit, TIntermAggregate *node) { return true; }
58 virtual bool visitBlock(Visit visit, TIntermBlock *node) { return true; }
59 virtual bool visitInvariantDeclaration(Visit visit, TIntermInvariantDeclaration *node)
60 {
61 return true;
62 }
63 virtual bool visitDeclaration(Visit visit, TIntermDeclaration *node) { return true; }
64 virtual bool visitLoop(Visit visit, TIntermLoop *node) { return true; }
65 virtual bool visitBranch(Visit visit, TIntermBranch *node) { return true; }
66 virtual void visitPreprocessorDirective(TIntermPreprocessorDirective *node) {}
67
68 // The traverse functions contain logic for iterating over the children of the node
69 // and calling the visit functions in the appropriate places. They also track some
70 // context that may be used by the visit functions.
71
72 // The generic traverse() function is used for nodes that don't need special handling.
73 // It's templated in order to avoid virtual function calls, this gains around 2% compiler
74 // performance.
75 template <typename T>
76 void traverse(T *node);
77
78 // Specialized traverse functions are implemented for node types where traversal logic may need
79 // to be overridden or where some special bookkeeping needs to be done.
80 virtual void traverseBinary(TIntermBinary *node);
81 virtual void traverseUnary(TIntermUnary *node);
82 virtual void traverseFunctionDefinition(TIntermFunctionDefinition *node);
83 virtual void traverseAggregate(TIntermAggregate *node);
84 virtual void traverseBlock(TIntermBlock *node);
85 virtual void traverseLoop(TIntermLoop *node);
86
87 int getMaxDepth() const { return mMaxDepth; }
88
89 // If traversers need to replace nodes, they can add the replacements in
90 // mReplacements/mMultiReplacements during traversal and the user of the traverser should call
91 // this function after traversal to perform them.
92 void updateTree();
93
94 protected:
95 void setMaxAllowedDepth(int depth);
96
97 // Should only be called from traverse*() functions
98 bool incrementDepth(TIntermNode *current)
99 {
100 mMaxDepth = std::max(mMaxDepth, static_cast<int>(mPath.size()));
101 mPath.push_back(current);
102 return mMaxDepth < mMaxAllowedDepth;
103 }
104
105 // Should only be called from traverse*() functions
106 void decrementDepth() { mPath.pop_back(); }
107
108 int getCurrentTraversalDepth() const { return static_cast<int>(mPath.size()) - 1; }
109
110 // RAII helper for incrementDepth/decrementDepth
111 class ScopedNodeInTraversalPath
112 {
113 public:
114 ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
115 : mTraverser(traverser)
116 {
117 mWithinDepthLimit = mTraverser->incrementDepth(current);
118 }
119 ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
120
121 bool isWithinDepthLimit() { return mWithinDepthLimit; }
122
123 private:
124 TIntermTraverser *mTraverser;
125 bool mWithinDepthLimit;
126 };
127 // Optimized traversal functions for leaf nodes directly access ScopedNodeInTraversalPath.
128 friend void TIntermSymbol::traverse(TIntermTraverser *);
129 friend void TIntermConstantUnion::traverse(TIntermTraverser *);
130 friend void TIntermFunctionPrototype::traverse(TIntermTraverser *);
131
132 TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
133
134 // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode()
135 TIntermNode *getAncestorNode(unsigned int n)
136 {
137 if (mPath.size() > n + 1u)
138 {
139 return mPath[mPath.size() - n - 2u];
140 }
141 return nullptr;
142 }
143
144 const TIntermBlock *getParentBlock() const;
145
146 void pushParentBlock(TIntermBlock *node);
147 void incrementParentBlockPos();
148 void popParentBlock();
149
150 // To replace a single node with multiple nodes in the parent aggregate. May be used with blocks
151 // but also with other nodes like declarations.
152 struct NodeReplaceWithMultipleEntry
153 {
154 NodeReplaceWithMultipleEntry(TIntermAggregateBase *parentIn,
155 TIntermNode *originalIn,
156 TIntermSequence replacementsIn)
157 : parent(parentIn), original(originalIn), replacements(std::move(replacementsIn))
158 {}
159
160 TIntermAggregateBase *parent;
161 TIntermNode *original;
162 TIntermSequence replacements;
163 };
164
165 // Helper to insert statements in the parent block of the node currently being traversed.
166 // The statements will be inserted before the node being traversed once updateTree is called.
167 // Should only be called during PreVisit or PostVisit if called from block nodes.
168 // Note that two insertions to the same position in the same block are not supported.
169 void insertStatementsInParentBlock(const TIntermSequence &insertions);
170
171 // Same as above, but supports simultaneous insertion of statements before and after the node
172 // currently being traversed.
173 void insertStatementsInParentBlock(const TIntermSequence &insertionsBefore,
174 const TIntermSequence &insertionsAfter);
175
176 // Helper to insert a single statement.
177 void insertStatementInParentBlock(TIntermNode *statement);
178
179 // Explicitly specify where to insert statements. The statements are inserted before and after
180 // the specified position. The statements will be inserted once updateTree is called. Note that
181 // two insertions to the same position in the same block are not supported.
182 void insertStatementsInBlockAtPosition(TIntermBlock *parent,
183 size_t position,
184 const TIntermSequence &insertionsBefore,
185 const TIntermSequence &insertionsAfter);
186
187 enum class OriginalNode
188 {
189 BECOMES_CHILD,
190 IS_DROPPED
191 };
192
193 void clearReplacementQueue();
194
195 // Replace the node currently being visited with replacement.
196 void queueReplacement(TIntermNode *replacement, OriginalNode originalStatus);
197 // Explicitly specify a node to replace with replacement.
198 void queueReplacementWithParent(TIntermNode *parent,
199 TIntermNode *original,
200 TIntermNode *replacement,
201 OriginalNode originalStatus);
202
203 const bool preVisit;
204 const bool inVisit;
205 const bool postVisit;
206
207 int mMaxDepth;
208 int mMaxAllowedDepth;
209
210 bool mInGlobalScope;
211
212 // During traversing, save all the changes that need to happen into
213 // mReplacements/mMultiReplacements, then do them by calling updateTree().
214 // Multi replacements are processed after single replacements.
215 std::vector<NodeReplaceWithMultipleEntry> mMultiReplacements;
216
217 TSymbolTable *mSymbolTable;
218
219 private:
220 // To insert multiple nodes into the parent block.
221 struct NodeInsertMultipleEntry
222 {
223 NodeInsertMultipleEntry(TIntermBlock *_parent,
224 TIntermSequence::size_type _position,
225 TIntermSequence _insertionsBefore,
226 TIntermSequence _insertionsAfter)
227 : parent(_parent),
228 position(_position),
229 insertionsBefore(_insertionsBefore),
230 insertionsAfter(_insertionsAfter)
231 {}
232
233 TIntermBlock *parent;
234 TIntermSequence::size_type position;
235 TIntermSequence insertionsBefore;
236 TIntermSequence insertionsAfter;
237 };
238
239 static bool CompareInsertion(const NodeInsertMultipleEntry &a,
240 const NodeInsertMultipleEntry &b);
241
242 // To replace a single node with another on the parent node
243 struct NodeUpdateEntry
244 {
245 NodeUpdateEntry(TIntermNode *_parent,
246 TIntermNode *_original,
247 TIntermNode *_replacement,
248 bool _originalBecomesChildOfReplacement)
249 : parent(_parent),
250 original(_original),
251 replacement(_replacement),
252 originalBecomesChildOfReplacement(_originalBecomesChildOfReplacement)
253 {}
254
255 TIntermNode *parent;
256 TIntermNode *original;
257 TIntermNode *replacement;
258 bool originalBecomesChildOfReplacement;
259 };
260
261 struct ParentBlock
262 {
263 ParentBlock(TIntermBlock *nodeIn, TIntermSequence::size_type posIn)
264 : node(nodeIn), pos(posIn)
265 {}
266
267 TIntermBlock *node;
268 TIntermSequence::size_type pos;
269 };
270
271 std::vector<NodeInsertMultipleEntry> mInsertions;
272 std::vector<NodeUpdateEntry> mReplacements;
273
274 // All the nodes from root to the current node during traversing.
275 TVector<TIntermNode *> mPath;
276
277 // All the code blocks from the root to the current node's parent during traversal.
278 std::vector<ParentBlock> mParentBlockStack;
279};
280
281// Traverser parent class that tracks where a node is a destination of a write operation and so is
282// required to be an l-value.
283class TLValueTrackingTraverser : public TIntermTraverser
284{
285 public:
286 TLValueTrackingTraverser(bool preVisit,
287 bool inVisit,
288 bool postVisit,
289 TSymbolTable *symbolTable);
290 virtual ~TLValueTrackingTraverser() {}
291
292 void traverseBinary(TIntermBinary *node) final;
293 void traverseUnary(TIntermUnary *node) final;
294 void traverseAggregate(TIntermAggregate *node) final;
295
296 protected:
297 bool isLValueRequiredHere() const
298 {
299 return mOperatorRequiresLValue || mInFunctionCallOutParameter;
300 }
301
302 private:
303 // Track whether an l-value is required in the node that is currently being traversed by the
304 // surrounding operator.
305 // Use isLValueRequiredHere to check all conditions which require an l-value.
306 void setOperatorRequiresLValue(bool lValueRequired)
307 {
308 mOperatorRequiresLValue = lValueRequired;
309 }
310 bool operatorRequiresLValue() const { return mOperatorRequiresLValue; }
311
312 // Track whether an l-value is required inside a function call.
313 void setInFunctionCallOutParameter(bool inOutParameter);
314 bool isInFunctionCallOutParameter() const;
315
316 bool mOperatorRequiresLValue;
317 bool mInFunctionCallOutParameter;
318};
319
320} // namespace sh
321
322#endif // COMPILER_TRANSLATOR_TREEUTIL_INTERMTRAVERSE_H_
323