1 | // |
2 | // Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
6 | // Scalarize vector and matrix constructor args, so that vectors built from components don't have |
7 | // matrix arguments, and matrices built from components don't have vector arguments. This avoids |
8 | // driver bugs around vector and matrix constructors. |
9 | // |
10 | |
11 | #include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h" |
12 | #include "common/debug.h" |
13 | |
14 | #include <algorithm> |
15 | |
16 | #include "angle_gl.h" |
17 | #include "common/angleutils.h" |
18 | #include "compiler/translator/tree_util/IntermNodePatternMatcher.h" |
19 | #include "compiler/translator/tree_util/IntermNode_util.h" |
20 | #include "compiler/translator/tree_util/IntermTraverse.h" |
21 | |
22 | namespace sh |
23 | { |
24 | |
25 | namespace |
26 | { |
27 | |
28 | TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index) |
29 | { |
30 | return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index)); |
31 | } |
32 | |
33 | TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermSymbol *symbolNode, int colIndex, int rowIndex) |
34 | { |
35 | TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex); |
36 | |
37 | return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex)); |
38 | } |
39 | |
40 | class ScalarizeArgsTraverser : public TIntermTraverser |
41 | { |
42 | public: |
43 | ScalarizeArgsTraverser(sh::GLenum shaderType, |
44 | bool fragmentPrecisionHigh, |
45 | TSymbolTable *symbolTable) |
46 | : TIntermTraverser(true, false, false, symbolTable), |
47 | mShaderType(shaderType), |
48 | mFragmentPrecisionHigh(fragmentPrecisionHigh), |
49 | mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor) |
50 | {} |
51 | |
52 | protected: |
53 | bool visitAggregate(Visit visit, TIntermAggregate *node) override; |
54 | bool visitBlock(Visit visit, TIntermBlock *node) override; |
55 | |
56 | private: |
57 | void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix); |
58 | |
59 | // If we have the following code: |
60 | // mat4 m(0); |
61 | // vec4 v(1, m); |
62 | // We will rewrite to: |
63 | // mat4 m(0); |
64 | // mat4 s0 = m; |
65 | // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]); |
66 | // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This |
67 | // way the possible side effects of the constructor argument will only be evaluated once. |
68 | TVariable *createTempVariable(TIntermTyped *original); |
69 | |
70 | std::vector<TIntermSequence> mBlockStack; |
71 | |
72 | sh::GLenum mShaderType; |
73 | bool mFragmentPrecisionHigh; |
74 | |
75 | IntermNodePatternMatcher mNodesToScalarize; |
76 | }; |
77 | |
78 | bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node) |
79 | { |
80 | ASSERT(visit == PreVisit); |
81 | if (mNodesToScalarize.match(node, getParentNode())) |
82 | { |
83 | if (node->getType().isVector()) |
84 | { |
85 | scalarizeArgs(node, false, true); |
86 | } |
87 | else |
88 | { |
89 | ASSERT(node->getType().isMatrix()); |
90 | scalarizeArgs(node, true, false); |
91 | } |
92 | } |
93 | return true; |
94 | } |
95 | |
96 | bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node) |
97 | { |
98 | mBlockStack.push_back(TIntermSequence()); |
99 | { |
100 | for (TIntermNode *child : *node->getSequence()) |
101 | { |
102 | ASSERT(child != nullptr); |
103 | child->traverse(this); |
104 | mBlockStack.back().push_back(child); |
105 | } |
106 | } |
107 | if (mBlockStack.back().size() > node->getSequence()->size()) |
108 | { |
109 | node->getSequence()->clear(); |
110 | *(node->getSequence()) = mBlockStack.back(); |
111 | } |
112 | mBlockStack.pop_back(); |
113 | return false; |
114 | } |
115 | |
116 | void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate, |
117 | bool scalarizeVector, |
118 | bool scalarizeMatrix) |
119 | { |
120 | ASSERT(aggregate); |
121 | ASSERT(!aggregate->isArray()); |
122 | int size = static_cast<int>(aggregate->getType().getObjectSize()); |
123 | TIntermSequence *sequence = aggregate->getSequence(); |
124 | TIntermSequence originalArgs(*sequence); |
125 | sequence->clear(); |
126 | for (TIntermNode *originalArgNode : originalArgs) |
127 | { |
128 | ASSERT(size > 0); |
129 | TIntermTyped *originalArg = originalArgNode->getAsTyped(); |
130 | ASSERT(originalArg); |
131 | TVariable *argVariable = createTempVariable(originalArg); |
132 | if (originalArg->isScalar()) |
133 | { |
134 | sequence->push_back(CreateTempSymbolNode(argVariable)); |
135 | size--; |
136 | } |
137 | else if (originalArg->isVector()) |
138 | { |
139 | if (scalarizeVector) |
140 | { |
141 | int repeat = std::min(size, originalArg->getNominalSize()); |
142 | size -= repeat; |
143 | for (int index = 0; index < repeat; ++index) |
144 | { |
145 | TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable); |
146 | TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index); |
147 | sequence->push_back(newNode); |
148 | } |
149 | } |
150 | else |
151 | { |
152 | TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable); |
153 | sequence->push_back(symbolNode); |
154 | size -= originalArg->getNominalSize(); |
155 | } |
156 | } |
157 | else |
158 | { |
159 | ASSERT(originalArg->isMatrix()); |
160 | if (scalarizeMatrix) |
161 | { |
162 | int colIndex = 0, rowIndex = 0; |
163 | int repeat = std::min(size, originalArg->getCols() * originalArg->getRows()); |
164 | size -= repeat; |
165 | while (repeat > 0) |
166 | { |
167 | TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable); |
168 | TIntermBinary *newNode = |
169 | ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex); |
170 | sequence->push_back(newNode); |
171 | rowIndex++; |
172 | if (rowIndex >= originalArg->getRows()) |
173 | { |
174 | rowIndex = 0; |
175 | colIndex++; |
176 | } |
177 | repeat--; |
178 | } |
179 | } |
180 | else |
181 | { |
182 | TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable); |
183 | sequence->push_back(symbolNode); |
184 | size -= originalArg->getCols() * originalArg->getRows(); |
185 | } |
186 | } |
187 | } |
188 | } |
189 | |
190 | TVariable *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original) |
191 | { |
192 | ASSERT(original); |
193 | |
194 | TType *type = new TType(original->getType()); |
195 | type->setQualifier(EvqTemporary); |
196 | if (mShaderType == GL_FRAGMENT_SHADER && type->getBasicType() == EbtFloat && |
197 | type->getPrecision() == EbpUndefined) |
198 | { |
199 | // We use the highest available precision for the temporary variable |
200 | // to avoid computing the actual precision using the rules defined |
201 | // in GLSL ES 1.0 Section 4.5.2. |
202 | type->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium); |
203 | } |
204 | |
205 | TVariable *variable = CreateTempVariable(mSymbolTable, type); |
206 | |
207 | ASSERT(mBlockStack.size() > 0); |
208 | TIntermSequence &sequence = mBlockStack.back(); |
209 | TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original); |
210 | sequence.push_back(declaration); |
211 | |
212 | return variable; |
213 | } |
214 | |
215 | } // namespace |
216 | |
217 | void ScalarizeVecAndMatConstructorArgs(TIntermBlock *root, |
218 | sh::GLenum shaderType, |
219 | bool fragmentPrecisionHigh, |
220 | TSymbolTable *symbolTable) |
221 | { |
222 | ScalarizeArgsTraverser scalarizer(shaderType, fragmentPrecisionHigh, symbolTable); |
223 | root->traverse(&scalarizer); |
224 | } |
225 | |
226 | } // namespace sh |
227 | |