1//
2// Copyright (c) 2018 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Implementation of the function RewriteAtomicFunctionExpressions.
7// See the header for more details.
8
9#include "RewriteAtomicFunctionExpressions.h"
10
11#include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
12#include "compiler/translator/tree_util/IntermNode_util.h"
13#include "compiler/translator/tree_util/IntermTraverse.h"
14#include "compiler/translator/util.h"
15
16namespace sh
17{
18namespace
19{
20// Traverser that simplifies all the atomic function expressions into the ones that can be directly
21// translated into HLSL.
22//
23// case 1 (only for atomicExchange and atomicCompSwap):
24// original:
25// atomicExchange(counter, newValue);
26// new:
27// tempValue = atomicExchange(counter, newValue);
28//
29// case 2 (atomic function, temporary variable required):
30// original:
31// value = atomicAdd(counter, 1) * otherValue;
32// someArray[atomicAdd(counter, 1)] = someOtherValue;
33// new:
34// value = ((tempValue = atomicAdd(counter, 1)), tempValue) * otherValue;
35// someArray[((tempValue = atomicAdd(counter, 1)), tempValue)] = someOtherValue;
36//
37// case 3 (atomic function used directly initialize a variable):
38// original:
39// int value = atomicAdd(counter, 1);
40// new:
41// tempValue = atomicAdd(counter, 1);
42// int value = tempValue;
43//
44class RewriteAtomicFunctionExpressionsTraverser : public TIntermTraverser
45{
46 public:
47 RewriteAtomicFunctionExpressionsTraverser(TSymbolTable *symbolTable, int shaderVersion);
48
49 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
50 bool visitBlock(Visit visit, TIntermBlock *node) override;
51
52 private:
53 static bool IsAtomicExchangeOrCompSwapNoReturnValue(TIntermAggregate *node,
54 TIntermNode *parentNode);
55 static bool IsAtomicFunctionInsideExpression(TIntermAggregate *node, TIntermNode *parentNode);
56
57 void rewriteAtomicFunctionCallNode(TIntermAggregate *oldAtomicFunctionNode);
58
59 const TVariable *getTempVariable(const TType *type);
60
61 int mShaderVersion;
62 TIntermSequence mTempVariables;
63};
64
65RewriteAtomicFunctionExpressionsTraverser::RewriteAtomicFunctionExpressionsTraverser(
66 TSymbolTable *symbolTable,
67 int shaderVersion)
68 : TIntermTraverser(false, false, true, symbolTable), mShaderVersion(shaderVersion)
69{}
70
71void RewriteAtomicFunctionExpressionsTraverser::rewriteAtomicFunctionCallNode(
72 TIntermAggregate *oldAtomicFunctionNode)
73{
74 ASSERT(oldAtomicFunctionNode);
75
76 const TVariable *returnVariable = getTempVariable(&oldAtomicFunctionNode->getType());
77
78 TIntermBinary *rewrittenNode = new TIntermBinary(
79 TOperator::EOpAssign, CreateTempSymbolNode(returnVariable), oldAtomicFunctionNode);
80
81 auto *parentNode = getParentNode();
82
83 auto *parentBinary = parentNode->getAsBinaryNode();
84 if (parentBinary && parentBinary->getOp() == EOpInitialize)
85 {
86 insertStatementInParentBlock(rewrittenNode);
87 queueReplacement(CreateTempSymbolNode(returnVariable), OriginalNode::IS_DROPPED);
88 }
89 else
90 {
91 // As all atomic function assignment will be converted to the last argument of an
92 // interlocked function, if we need the return value, assignment needs to be wrapped with
93 // the comma operator and the temporary variables.
94 if (!parentNode->getAsBlock())
95 {
96 rewrittenNode = TIntermBinary::CreateComma(
97 rewrittenNode, new TIntermSymbol(returnVariable), mShaderVersion);
98 }
99
100 queueReplacement(rewrittenNode, OriginalNode::IS_DROPPED);
101 }
102}
103
104const TVariable *RewriteAtomicFunctionExpressionsTraverser::getTempVariable(const TType *type)
105{
106 TIntermDeclaration *variableDeclaration;
107 TVariable *returnVariable =
108 DeclareTempVariable(mSymbolTable, type, EvqTemporary, &variableDeclaration);
109 mTempVariables.push_back(variableDeclaration);
110 return returnVariable;
111}
112
113bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicExchangeOrCompSwapNoReturnValue(
114 TIntermAggregate *node,
115 TIntermNode *parentNode)
116{
117 ASSERT(node);
118 return (node->getOp() == EOpAtomicExchange || node->getOp() == EOpAtomicCompSwap) &&
119 parentNode && parentNode->getAsBlock();
120}
121
122bool RewriteAtomicFunctionExpressionsTraverser::IsAtomicFunctionInsideExpression(
123 TIntermAggregate *node,
124 TIntermNode *parentNode)
125{
126 ASSERT(node);
127 // We only need to handle atomic functions with a parent that it is not block nodes. If the
128 // parent node is block, it means that the atomic function is not inside an expression.
129 if (!IsAtomicFunction(node->getOp()) || parentNode->getAsBlock())
130 {
131 return false;
132 }
133
134 auto *parentAsBinary = parentNode->getAsBinaryNode();
135 // Assignments are handled in OutputHLSL
136 return !parentAsBinary || parentAsBinary->getOp() != EOpAssign;
137}
138
139bool RewriteAtomicFunctionExpressionsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
140{
141 ASSERT(visit == PostVisit);
142 // Skip atomic memory functions for SSBO. They will be processed in the OutputHLSL traverser.
143 if (IsAtomicFunction(node->getOp()) &&
144 IsInShaderStorageBlock((*node->getSequence())[0]->getAsTyped()))
145 {
146 return false;
147 }
148
149 TIntermNode *parentNode = getParentNode();
150 if (IsAtomicExchangeOrCompSwapNoReturnValue(node, parentNode) ||
151 IsAtomicFunctionInsideExpression(node, parentNode))
152 {
153 rewriteAtomicFunctionCallNode(node);
154 }
155
156 return true;
157}
158
159bool RewriteAtomicFunctionExpressionsTraverser::visitBlock(Visit visit, TIntermBlock *node)
160{
161 ASSERT(visit == PostVisit);
162
163 if (!mTempVariables.empty() && getParentNode()->getAsFunctionDefinition())
164 {
165 insertStatementsInBlockAtPosition(node, 0, mTempVariables, TIntermSequence());
166 mTempVariables.clear();
167 }
168
169 return true;
170}
171
172} // anonymous namespace
173
174void RewriteAtomicFunctionExpressions(TIntermNode *root,
175 TSymbolTable *symbolTable,
176 int shaderVersion)
177{
178 RewriteAtomicFunctionExpressionsTraverser traverser(symbolTable, shaderVersion);
179 traverser.traverse(root);
180 traverser.updateTree();
181}
182} // namespace sh