1 | // |
2 | // Copyright (c) 2016 The ANGLE Project Authors. All rights reserved. |
3 | // Use of this source code is governed by a BSD-style license that can be |
4 | // found in the LICENSE file. |
5 | // |
6 | // Implementation of the integer pow expressions HLSL bug workaround. |
7 | // See header for more info. |
8 | |
9 | #include "compiler/translator/tree_ops/ExpandIntegerPowExpressions.h" |
10 | |
11 | #include <cmath> |
12 | #include <cstdlib> |
13 | |
14 | #include "compiler/translator/tree_util/IntermNode_util.h" |
15 | #include "compiler/translator/tree_util/IntermTraverse.h" |
16 | |
17 | namespace sh |
18 | { |
19 | |
20 | namespace |
21 | { |
22 | |
23 | class Traverser : public TIntermTraverser |
24 | { |
25 | public: |
26 | static void Apply(TIntermNode *root, TSymbolTable *symbolTable); |
27 | |
28 | private: |
29 | Traverser(TSymbolTable *symbolTable); |
30 | bool visitAggregate(Visit visit, TIntermAggregate *node) override; |
31 | void nextIteration(); |
32 | |
33 | bool mFound = false; |
34 | }; |
35 | |
36 | // static |
37 | void Traverser::Apply(TIntermNode *root, TSymbolTable *symbolTable) |
38 | { |
39 | Traverser traverser(symbolTable); |
40 | do |
41 | { |
42 | traverser.nextIteration(); |
43 | root->traverse(&traverser); |
44 | if (traverser.mFound) |
45 | { |
46 | traverser.updateTree(); |
47 | } |
48 | } while (traverser.mFound); |
49 | } |
50 | |
51 | Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable) |
52 | {} |
53 | |
54 | void Traverser::nextIteration() |
55 | { |
56 | mFound = false; |
57 | } |
58 | |
59 | bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node) |
60 | { |
61 | if (mFound) |
62 | { |
63 | return false; |
64 | } |
65 | |
66 | // Test 0: skip non-pow operators. |
67 | if (node->getOp() != EOpPow) |
68 | { |
69 | return true; |
70 | } |
71 | |
72 | const TIntermSequence *sequence = node->getSequence(); |
73 | ASSERT(sequence->size() == 2u); |
74 | const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion(); |
75 | |
76 | // Test 1: check for a single constant. |
77 | if (!constantExponent || constantExponent->getNominalSize() != 1) |
78 | { |
79 | return true; |
80 | } |
81 | |
82 | ASSERT(constantExponent->getBasicType() == EbtFloat); |
83 | float exponentValue = constantExponent->getConstantValue()->getFConst(); |
84 | |
85 | // Test 2: exponentValue is in the problematic range. |
86 | if (exponentValue < -5.0f || exponentValue > 9.0f) |
87 | { |
88 | return true; |
89 | } |
90 | |
91 | // Test 3: exponentValue is integer or pretty close to an integer. |
92 | if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f) |
93 | { |
94 | return true; |
95 | } |
96 | |
97 | // Test 4: skip -1, 0, and 1 |
98 | int exponent = static_cast<int>(std::round(exponentValue)); |
99 | int n = std::abs(exponent); |
100 | if (n < 2) |
101 | { |
102 | return true; |
103 | } |
104 | |
105 | // Potential problem case detected, apply workaround. |
106 | |
107 | TIntermTyped *lhs = sequence->at(0)->getAsTyped(); |
108 | ASSERT(lhs); |
109 | |
110 | TIntermDeclaration *lhsVariableDeclaration = nullptr; |
111 | TVariable *lhsVariable = |
112 | DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration); |
113 | insertStatementInParentBlock(lhsVariableDeclaration); |
114 | |
115 | // Create a chain of n-1 multiples. |
116 | TIntermTyped *current = CreateTempSymbolNode(lhsVariable); |
117 | for (int i = 1; i < n; ++i) |
118 | { |
119 | TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable)); |
120 | mul->setLine(node->getLine()); |
121 | current = mul; |
122 | } |
123 | |
124 | // For negative pow, compute the reciprocal of the positive pow. |
125 | if (exponent < 0) |
126 | { |
127 | TConstantUnion *oneVal = new TConstantUnion(); |
128 | oneVal->setFConst(1.0f); |
129 | TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType()); |
130 | TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current); |
131 | current = div; |
132 | } |
133 | |
134 | queueReplacement(current, OriginalNode::IS_DROPPED); |
135 | mFound = true; |
136 | return false; |
137 | } |
138 | |
139 | } // anonymous namespace |
140 | |
141 | void ExpandIntegerPowExpressions(TIntermNode *root, TSymbolTable *symbolTable) |
142 | { |
143 | Traverser::Apply(root, symbolTable); |
144 | } |
145 | |
146 | } // namespace sh |
147 | |