1//
2// Copyright (c) 2002-2014 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Scalarize vector and matrix constructor args, so that vectors built from components don't have
7// matrix arguments, and matrices built from components don't have vector arguments. This avoids
8// driver bugs around vector and matrix constructors.
9//
10
#include "compiler/translator/tree_ops/ScalarizeVecAndMatConstructorArgs.h"
#include "common/debug.h"

#include <algorithm>
#include <utility>

#include "angle_gl.h"
#include "common/angleutils.h"
#include "compiler/translator/tree_util/IntermNodePatternMatcher.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
21
22namespace sh
23{
24
25namespace
26{
27
28TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
29{
30 return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
31}
32
33TIntermBinary *ConstructMatrixIndexBinaryNode(TIntermSymbol *symbolNode, int colIndex, int rowIndex)
34{
35 TIntermBinary *colVectorNode = ConstructVectorIndexBinaryNode(symbolNode, colIndex);
36
37 return new TIntermBinary(EOpIndexDirect, colVectorNode, CreateIndexNode(rowIndex));
38}
39
40class ScalarizeArgsTraverser : public TIntermTraverser
41{
42 public:
43 ScalarizeArgsTraverser(sh::GLenum shaderType,
44 bool fragmentPrecisionHigh,
45 TSymbolTable *symbolTable)
46 : TIntermTraverser(true, false, false, symbolTable),
47 mShaderType(shaderType),
48 mFragmentPrecisionHigh(fragmentPrecisionHigh),
49 mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
50 {}
51
52 protected:
53 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
54 bool visitBlock(Visit visit, TIntermBlock *node) override;
55
56 private:
57 void scalarizeArgs(TIntermAggregate *aggregate, bool scalarizeVector, bool scalarizeMatrix);
58
59 // If we have the following code:
60 // mat4 m(0);
61 // vec4 v(1, m);
62 // We will rewrite to:
63 // mat4 m(0);
64 // mat4 s0 = m;
65 // vec4 v(1, s0[0][0], s0[0][1], s0[0][2]);
66 // This function is to create nodes for "mat4 s0 = m;" and insert it to the code sequence. This
67 // way the possible side effects of the constructor argument will only be evaluated once.
68 TVariable *createTempVariable(TIntermTyped *original);
69
70 std::vector<TIntermSequence> mBlockStack;
71
72 sh::GLenum mShaderType;
73 bool mFragmentPrecisionHigh;
74
75 IntermNodePatternMatcher mNodesToScalarize;
76};
77
78bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
79{
80 ASSERT(visit == PreVisit);
81 if (mNodesToScalarize.match(node, getParentNode()))
82 {
83 if (node->getType().isVector())
84 {
85 scalarizeArgs(node, false, true);
86 }
87 else
88 {
89 ASSERT(node->getType().isMatrix());
90 scalarizeArgs(node, true, false);
91 }
92 }
93 return true;
94}
95
96bool ScalarizeArgsTraverser::visitBlock(Visit visit, TIntermBlock *node)
97{
98 mBlockStack.push_back(TIntermSequence());
99 {
100 for (TIntermNode *child : *node->getSequence())
101 {
102 ASSERT(child != nullptr);
103 child->traverse(this);
104 mBlockStack.back().push_back(child);
105 }
106 }
107 if (mBlockStack.back().size() > node->getSequence()->size())
108 {
109 node->getSequence()->clear();
110 *(node->getSequence()) = mBlockStack.back();
111 }
112 mBlockStack.pop_back();
113 return false;
114}
115
116void ScalarizeArgsTraverser::scalarizeArgs(TIntermAggregate *aggregate,
117 bool scalarizeVector,
118 bool scalarizeMatrix)
119{
120 ASSERT(aggregate);
121 ASSERT(!aggregate->isArray());
122 int size = static_cast<int>(aggregate->getType().getObjectSize());
123 TIntermSequence *sequence = aggregate->getSequence();
124 TIntermSequence originalArgs(*sequence);
125 sequence->clear();
126 for (TIntermNode *originalArgNode : originalArgs)
127 {
128 ASSERT(size > 0);
129 TIntermTyped *originalArg = originalArgNode->getAsTyped();
130 ASSERT(originalArg);
131 TVariable *argVariable = createTempVariable(originalArg);
132 if (originalArg->isScalar())
133 {
134 sequence->push_back(CreateTempSymbolNode(argVariable));
135 size--;
136 }
137 else if (originalArg->isVector())
138 {
139 if (scalarizeVector)
140 {
141 int repeat = std::min(size, originalArg->getNominalSize());
142 size -= repeat;
143 for (int index = 0; index < repeat; ++index)
144 {
145 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
146 TIntermBinary *newNode = ConstructVectorIndexBinaryNode(symbolNode, index);
147 sequence->push_back(newNode);
148 }
149 }
150 else
151 {
152 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
153 sequence->push_back(symbolNode);
154 size -= originalArg->getNominalSize();
155 }
156 }
157 else
158 {
159 ASSERT(originalArg->isMatrix());
160 if (scalarizeMatrix)
161 {
162 int colIndex = 0, rowIndex = 0;
163 int repeat = std::min(size, originalArg->getCols() * originalArg->getRows());
164 size -= repeat;
165 while (repeat > 0)
166 {
167 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
168 TIntermBinary *newNode =
169 ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
170 sequence->push_back(newNode);
171 rowIndex++;
172 if (rowIndex >= originalArg->getRows())
173 {
174 rowIndex = 0;
175 colIndex++;
176 }
177 repeat--;
178 }
179 }
180 else
181 {
182 TIntermSymbol *symbolNode = CreateTempSymbolNode(argVariable);
183 sequence->push_back(symbolNode);
184 size -= originalArg->getCols() * originalArg->getRows();
185 }
186 }
187 }
188}
189
190TVariable *ScalarizeArgsTraverser::createTempVariable(TIntermTyped *original)
191{
192 ASSERT(original);
193
194 TType *type = new TType(original->getType());
195 type->setQualifier(EvqTemporary);
196 if (mShaderType == GL_FRAGMENT_SHADER && type->getBasicType() == EbtFloat &&
197 type->getPrecision() == EbpUndefined)
198 {
199 // We use the highest available precision for the temporary variable
200 // to avoid computing the actual precision using the rules defined
201 // in GLSL ES 1.0 Section 4.5.2.
202 type->setPrecision(mFragmentPrecisionHigh ? EbpHigh : EbpMedium);
203 }
204
205 TVariable *variable = CreateTempVariable(mSymbolTable, type);
206
207 ASSERT(mBlockStack.size() > 0);
208 TIntermSequence &sequence = mBlockStack.back();
209 TIntermDeclaration *declaration = CreateTempInitDeclarationNode(variable, original);
210 sequence.push_back(declaration);
211
212 return variable;
213}
214
215} // namespace
216
217void ScalarizeVecAndMatConstructorArgs(TIntermBlock *root,
218 sh::GLenum shaderType,
219 bool fragmentPrecisionHigh,
220 TSymbolTable *symbolTable)
221{
222 ScalarizeArgsTraverser scalarizer(shaderType, fragmentPrecisionHigh, symbolTable);
223 root->traverse(&scalarizer);
224}
225
226} // namespace sh
227