1 | // |
2 | // Copyright (c) 2018 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
6 | // Implementation of the function RewriteAtomicFunctionExpressions. |
7 | // See the header for more details. |
8 | |
9 | #include "RewriteAtomicFunctionExpressions.h" |
10 | |
11 | #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" |
12 | #include "compiler/translator/tree_util/IntermNode_util.h" |
13 | #include "compiler/translator/tree_util/IntermTraverse.h" |
14 | #include "compiler/translator/util.h" |
15 | |
16 | namespace sh |
17 | { |
18 | namespace |
19 | { |
20 | // Traverser that simplifies all the atomic function expressions into the ones that can be directly |
21 | // translated into HLSL. |
22 | // |
23 | // case 1 (only for atomicExchange and atomicCompSwap): |
24 | // original: |
25 | // atomicExchange(counter, newValue); |
26 | // new: |
27 | // tempValue = atomicExchange(counter, newValue); |
28 | // |
29 | // case 2 (atomic function, temporary variable required): |
30 | // original: |
31 | // value = atomicAdd(counter, 1) * otherValue; |
32 | // someArray[atomicAdd(counter, 1)] = someOtherValue; |
33 | // new: |
34 | // value = ((tempValue = atomicAdd(counter, 1)), tempValue) * otherValue; |
35 | // someArray[((tempValue = atomicAdd(counter, 1)), tempValue)] = someOtherValue; |
36 | // |
// case 3 (atomic function used directly to initialize a variable):
38 | // original: |
39 | // int value = atomicAdd(counter, 1); |
40 | // new: |
41 | // tempValue = atomicAdd(counter, 1); |
42 | // int value = tempValue; |
43 | // |
44 | class RewriteAtomicFunctionExpressionsTraverser : public TIntermTraverser |
45 | { |
46 | public: |
47 | RewriteAtomicFunctionExpressionsTraverser(TSymbolTable *symbolTable, int shaderVersion); |
48 | |
49 | bool visitAggregate(Visit visit, TIntermAggregate *node) override; |
50 | bool visitBlock(Visit visit, TIntermBlock *node) override; |
51 | |
52 | private: |
53 | static bool IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate *node, |
54 | TIntermNode *parentNode); |
55 | static bool IsAtomicFunctionInsideExpression(TIntermAggregate *node, TIntermNode *parentNode); |
56 | |
57 | void rewriteAtomicFunctionCallNode(TIntermAggregate *oldAtomicFunctionNode); |
58 | |
59 | const TVariable *getTempVariable(const TType *type); |
60 | |
61 | int mShaderVersion; |
62 | TIntermSequence mTempVariables; |
63 | }; |
64 | |
65 | RewriteAtomicFunctionExpressionsTraverser::RewriteAtomicFunctionExpressionsTraverser( |
66 | TSymbolTable *symbolTable, |
67 | int shaderVersion) |
68 | : TIntermTraverser(false, false, true, symbolTable), mShaderVersion(shaderVersion) |
69 | {} |
70 | |
71 | void RewriteAtomicFunctionExpressionsTraverser::rewriteAtomicFunctionCallNode( |
72 | TIntermAggregate *oldAtomicFunctionNode) |
73 | { |
74 | ASSERT(oldAtomicFunctionNode); |
75 | |
76 | const TVariable *returnVariable = getTempVariable(&oldAtomicFunctionNode->getType()); |
77 | |
78 | TIntermBinary *rewrittenNode = new TIntermBinary( |
79 | TOperator::EOpAssign, CreateTempSymbolNode(returnVariable), oldAtomicFunctionNode); |
80 | |
81 | auto *parentNode = getParentNode(); |
82 | |
83 | auto *parentBinary = parentNode->getAsBinaryNode(); |
84 | if (parentBinary && parentBinary->getOp() == EOpInitialize) |
85 | { |
86 | insertStatementInParentBlock(rewrittenNode); |
87 | queueReplacement(CreateTempSymbolNode(returnVariable), OriginalNode::IS_DROPPED); |
88 | } |
89 | else |
90 | { |
91 | // As all atomic function assignment will be converted to the last argument of an |
92 | // interlocked function, if we need the return value, assignment needs to be wrapped with |
93 | // the comma operator and the temporary variables. |
94 | if (!parentNode->getAsBlock()) |
95 | { |
96 | rewrittenNode = TIntermBinary::CreateComma( |
97 | rewrittenNode, new TIntermSymbol(returnVariable), mShaderVersion); |
98 | } |
99 | |
100 | queueReplacement(rewrittenNode, OriginalNode::IS_DROPPED); |
101 | } |
102 | } |
103 | |
104 | const TVariable *RewriteAtomicFunctionExpressionsTraverser::getTempVariable(const TType *type) |
105 | { |
106 | TIntermDeclaration *variableDeclaration; |
107 | TVariable *returnVariable = |
108 | DeclareTempVariable(mSymbolTable, type, EvqTemporary, &variableDeclaration); |
109 | mTempVariables.push_back(variableDeclaration); |
110 | return returnVariable; |
111 | } |
112 | |
113 | bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicExchangeOrCompSwapNoReturnValue( |
114 | TIntermAggregate *node, |
115 | TIntermNode *parentNode) |
116 | { |
117 | ASSERT(node); |
118 | return (node->getOp() == EOpAtomicExchange || node->getOp() == EOpAtomicCompSwap) && |
119 | parentNode && parentNode->getAsBlock(); |
120 | } |
121 | |
122 | bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicFunctionInsideExpression( |
123 | TIntermAggregate *node, |
124 | TIntermNode *parentNode) |
125 | { |
126 | ASSERT(node); |
127 | // We only need to handle atomic functions with a parent that it is not block nodes. If the |
128 | // parent node is block, it means that the atomic function is not inside an expression. |
129 | if (!IsAtomicFunction(node->getOp()) || parentNode->getAsBlock()) |
130 | { |
131 | return false; |
132 | } |
133 | |
134 | auto *parentAsBinary = parentNode->getAsBinaryNode(); |
135 | // Assignments are handled in OutputHLSL |
136 | return !parentAsBinary || parentAsBinary->getOp() != EOpAssign; |
137 | } |
138 | |
139 | bool RewriteAtomicFunctionExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) |
140 | { |
141 | ASSERT(visit == PostVisit); |
142 | // Skip atomic memory functions for SSBO. They will be processed in the OutputHLSL traverser. |
143 | if (IsAtomicFunction(node->getOp()) && |
144 | IsInShaderStorageBlock((*node->getSequence())[0]->getAsTyped())) |
145 | { |
146 | return false; |
147 | } |
148 | |
149 | TIntermNode *parentNode = getParentNode(); |
150 | if (IsAtomicExchangeOrCompSwapNoReturnValue(node, parentNode) || |
151 | IsAtomicFunctionInsideExpression(node, parentNode)) |
152 | { |
153 | rewriteAtomicFunctionCallNode(node); |
154 | } |
155 | |
156 | return true; |
157 | } |
158 | |
159 | bool RewriteAtomicFunctionExpressionsTraverser::visitBlock(Visit visit, TIntermBlock *node) |
160 | { |
161 | ASSERT(visit == PostVisit); |
162 | |
163 | if (!mTempVariables.empty() && getParentNode()->getAsFunctionDefinition()) |
164 | { |
165 | insertStatementsInBlockAtPosition(node, 0, mTempVariables, TIntermSequence()); |
166 | mTempVariables.clear(); |
167 | } |
168 | |
169 | return true; |
170 | } |
171 | |
172 | } // anonymous namespace |
173 | |
174 | void RewriteAtomicFunctionExpressions(TIntermNode *root, |
175 | TSymbolTable *symbolTable, |
176 | int shaderVersion) |
177 | { |
178 | RewriteAtomicFunctionExpressionsTraverser traverser(symbolTable, shaderVersion); |
179 | traverser.traverse(root); |
180 | traverser.updateTree(); |
181 | } |
182 | } // namespace sh |