1//
2// Copyright (c) 2016 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6// Implementation of the integer pow expressions HLSL bug workaround.
7// See header for more info.
8
9#include "compiler/translator/tree_ops/ExpandIntegerPowExpressions.h"
10
11#include <cmath>
12#include <cstdlib>
13
14#include "compiler/translator/tree_util/IntermNode_util.h"
15#include "compiler/translator/tree_util/IntermTraverse.h"
16
17namespace sh
18{
19
20namespace
21{
22
23class Traverser : public TIntermTraverser
24{
25 public:
26 static void Apply(TIntermNode *root, TSymbolTable *symbolTable);
27
28 private:
29 Traverser(TSymbolTable *symbolTable);
30 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
31 void nextIteration();
32
33 bool mFound = false;
34};
35
36// static
37void Traverser::Apply(TIntermNode *root, TSymbolTable *symbolTable)
38{
39 Traverser traverser(symbolTable);
40 do
41 {
42 traverser.nextIteration();
43 root->traverse(&traverser);
44 if (traverser.mFound)
45 {
46 traverser.updateTree();
47 }
48 } while (traverser.mFound);
49}
50
51Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
52{}
53
54void Traverser::nextIteration()
55{
56 mFound = false;
57}
58
59bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
60{
61 if (mFound)
62 {
63 return false;
64 }
65
66 // Test 0: skip non-pow operators.
67 if (node->getOp() != EOpPow)
68 {
69 return true;
70 }
71
72 const TIntermSequence *sequence = node->getSequence();
73 ASSERT(sequence->size() == 2u);
74 const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
75
76 // Test 1: check for a single constant.
77 if (!constantExponent || constantExponent->getNominalSize() != 1)
78 {
79 return true;
80 }
81
82 ASSERT(constantExponent->getBasicType() == EbtFloat);
83 float exponentValue = constantExponent->getConstantValue()->getFConst();
84
85 // Test 2: exponentValue is in the problematic range.
86 if (exponentValue < -5.0f || exponentValue > 9.0f)
87 {
88 return true;
89 }
90
91 // Test 3: exponentValue is integer or pretty close to an integer.
92 if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
93 {
94 return true;
95 }
96
97 // Test 4: skip -1, 0, and 1
98 int exponent = static_cast<int>(std::round(exponentValue));
99 int n = std::abs(exponent);
100 if (n < 2)
101 {
102 return true;
103 }
104
105 // Potential problem case detected, apply workaround.
106
107 TIntermTyped *lhs = sequence->at(0)->getAsTyped();
108 ASSERT(lhs);
109
110 TIntermDeclaration *lhsVariableDeclaration = nullptr;
111 TVariable *lhsVariable =
112 DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
113 insertStatementInParentBlock(lhsVariableDeclaration);
114
115 // Create a chain of n-1 multiples.
116 TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
117 for (int i = 1; i < n; ++i)
118 {
119 TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
120 mul->setLine(node->getLine());
121 current = mul;
122 }
123
124 // For negative pow, compute the reciprocal of the positive pow.
125 if (exponent < 0)
126 {
127 TConstantUnion *oneVal = new TConstantUnion();
128 oneVal->setFConst(1.0f);
129 TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
130 TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
131 current = div;
132 }
133
134 queueReplacement(current, OriginalNode::IS_DROPPED);
135 mFound = true;
136 return false;
137}
138
139} // anonymous namespace
140
141void ExpandIntegerPowExpressions(TIntermNode *root, TSymbolTable *symbolTable)
142{
143 Traverser::Apply(root, symbolTable);
144}
145
146} // namespace sh
147