summaryrefslogtreecommitdiffstats
path: root/gfx/angle/checkout/src/compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'gfx/angle/checkout/src/compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.cpp')
-rw-r--r--gfx/angle/checkout/src/compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.cpp152
1 files changed, 152 insertions, 0 deletions
diff --git a/gfx/angle/checkout/src/compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.cpp b/gfx/angle/checkout/src/compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.cpp
new file mode 100644
index 0000000000..e873db56cd
--- /dev/null
+++ b/gfx/angle/checkout/src/compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.cpp
@@ -0,0 +1,152 @@
+//
+// Copyright 2016 The ANGLE Project Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Implementation of the integer pow expressions HLSL bug workaround.
+// See header for more info.
+
+#include "compiler/translator/tree_ops/d3d/ExpandIntegerPowExpressions.h"
+
+#include <cmath>
+#include <cstdlib>
+
+#include "compiler/translator/tree_util/IntermNode_util.h"
+#include "compiler/translator/tree_util/IntermTraverse.h"
+
+namespace sh
+{
+
+namespace
+{
+
+class Traverser : public TIntermTraverser
+{
+ public:
+ [[nodiscard]] static bool Apply(TCompiler *compiler,
+ TIntermNode *root,
+ TSymbolTable *symbolTable);
+
+ private:
+ Traverser(TSymbolTable *symbolTable);
+ bool visitAggregate(Visit visit, TIntermAggregate *node) override;
+ void nextIteration();
+
+ bool mFound = false;
+};
+
+// static
+bool Traverser::Apply(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
+{
+ Traverser traverser(symbolTable);
+ do
+ {
+ traverser.nextIteration();
+ root->traverse(&traverser);
+ if (traverser.mFound)
+ {
+ if (!traverser.updateTree(compiler, root))
+ {
+ return false;
+ }
+ }
+ } while (traverser.mFound);
+
+ return true;
+}
+
+Traverser::Traverser(TSymbolTable *symbolTable) : TIntermTraverser(true, false, false, symbolTable)
+{}
+
+void Traverser::nextIteration()
+{
+ mFound = false;
+}
+
+bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
+{
+ if (mFound)
+ {
+ return false;
+ }
+
+ // Test 0: skip non-pow operators.
+ if (node->getOp() != EOpPow)
+ {
+ return true;
+ }
+
+ const TIntermSequence *sequence = node->getSequence();
+ ASSERT(sequence->size() == 2u);
+ const TIntermConstantUnion *constantExponent = sequence->at(1)->getAsConstantUnion();
+
+ // Test 1: check for a single constant.
+ if (!constantExponent || constantExponent->getNominalSize() != 1)
+ {
+ return true;
+ }
+
+ float exponentValue = constantExponent->getConstantValue()->getFConst();
+
+ // Test 2: exponentValue is in the problematic range.
+ if (exponentValue < -5.0f || exponentValue > 9.0f)
+ {
+ return true;
+ }
+
+ // Test 3: exponentValue is integer or pretty close to an integer.
+ if (std::abs(exponentValue - std::round(exponentValue)) > 0.0001f)
+ {
+ return true;
+ }
+
+ // Test 4: skip -1, 0, and 1
+ int exponent = static_cast<int>(std::round(exponentValue));
+ int n = std::abs(exponent);
+ if (n < 2)
+ {
+ return true;
+ }
+
+ // Potential problem case detected, apply workaround.
+
+ TIntermTyped *lhs = sequence->at(0)->getAsTyped();
+ ASSERT(lhs);
+
+ TIntermDeclaration *lhsVariableDeclaration = nullptr;
+ TVariable *lhsVariable =
+ DeclareTempVariable(mSymbolTable, lhs, EvqTemporary, &lhsVariableDeclaration);
+ insertStatementInParentBlock(lhsVariableDeclaration);
+
+ // Create a chain of n-1 multiples.
+ TIntermTyped *current = CreateTempSymbolNode(lhsVariable);
+ for (int i = 1; i < n; ++i)
+ {
+ TIntermBinary *mul = new TIntermBinary(EOpMul, current, CreateTempSymbolNode(lhsVariable));
+ mul->setLine(node->getLine());
+ current = mul;
+ }
+
+ // For negative pow, compute the reciprocal of the positive pow.
+ if (exponent < 0)
+ {
+ TConstantUnion *oneVal = new TConstantUnion();
+ oneVal->setFConst(1.0f);
+ TIntermConstantUnion *oneNode = new TIntermConstantUnion(oneVal, node->getType());
+ TIntermBinary *div = new TIntermBinary(EOpDiv, oneNode, current);
+ current = div;
+ }
+
+ queueReplacement(current, OriginalNode::IS_DROPPED);
+ mFound = true;
+ return false;
+}
+
+} // anonymous namespace
+
+bool ExpandIntegerPowExpressions(TCompiler *compiler, TIntermNode *root, TSymbolTable *symbolTable)
+{
+ return Traverser::Apply(compiler, root, symbolTable);
+}
+
+} // namespace sh