fold: sqrt(x * x * y) -> fabs(x) * sqrt(y)
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index 9fac7ef540eeea799a6232d901934e70b0bbc173..c3e2f3aec0065d02f85edbacbaf8e15cca83cada 100644 (file)
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Target/TargetLibraryInfo.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 static cl::opt<bool>
     ColdErrorCalls("error-reporting-is-cold", cl::init(true), cl::Hidden,
@@ -1254,6 +1256,85 @@ Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
   return Ret;
 }
 
+Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+  
+  Value *Ret = nullptr;
+  if (UnsafeFPShrink && Callee->getName() == "sqrt" &&
+      TLI->has(LibFunc::sqrtf)) {
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  }
+
+  // FIXME: For finer-grain optimization, we need intrinsics to have the same
+  // fast-math flag decorations that are applied to FP instructions. For now,
+  // we have to rely on the function-level unsafe-fp-math attribute to do this
+  // optimization because there's no other way to express that the sqrt can be
+  // reassociated.
+  Function *F = CI->getParent()->getParent();
+  if (F->hasFnAttribute("unsafe-fp-math")) {
+    // Check for unsafe-fp-math = true.
+    Attribute Attr = F->getFnAttribute("unsafe-fp-math");
+    if (Attr.getValueAsString() != "true")
+      return Ret;
+  }
+  Value *Op = CI->getArgOperand(0);
+  if (Instruction *I = dyn_cast<Instruction>(Op)) {
+    if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
+      // We're looking for a repeated factor in a multiplication tree,
+      // so we can do this fold: sqrt(x * x) -> fabs(x);
+      // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
+      Value *Op0 = I->getOperand(0);
+      Value *Op1 = I->getOperand(1);
+      Value *RepeatOp = nullptr;
+      Value *OtherOp = nullptr;
+      if (Op0 == Op1) {
+        // Simple match: the operands of the multiply are identical.
+        RepeatOp = Op0;
+      } else {
+        // Look for a more complicated pattern: one of the operands is itself
+        // a multiply, so search for a common factor in that multiply.
+        // Note: We don't bother looking any deeper than this first level or for
+        // variations of this pattern because instcombine's visitFMUL and/or the
+        // reassociation pass should give us this form.
+        Value *OtherMul0, *OtherMul1;
+        if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
+          // Pattern: sqrt((x * y) * z)
+          if (OtherMul0 == OtherMul1) {
+            // Matched: sqrt((x * x) * z)
+            RepeatOp = OtherMul0;
+            OtherOp = Op1;
+          }
+        }
+      }
+      if (RepeatOp) {
+        // Fast math flags for any created instructions should match the sqrt
+        // and multiply.
+        // FIXME: We're not checking the sqrt because it doesn't have
+        // fast-math-flags (see earlier comment).
+        IRBuilder<true, ConstantFolder,
+          IRBuilderDefaultInserter<true> >::FastMathFlagGuard Guard(B);
+        B.SetFastMathFlags(I->getFastMathFlags());
+        // If we found a repeated factor, hoist it out of the square root and
+        // replace it with the fabs of that factor.
+        Module *M = Callee->getParent();
+        Type *ArgType = Op->getType();
+        Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
+        Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
+        if (OtherOp) {
+          // If we found a non-repeated factor, we still need to get its square
+          // root. We then multiply that by the value that was simplified out
+          // of the square root calculation.
+          Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
+          Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
+          return B.CreateFMul(FabsCall, SqrtCall);
+        }
+        return FabsCall;
+      }
+    }
+  }
+  return Ret;
+}
+
 static bool isTrigLibCall(CallInst *CI);
 static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
                              bool UseFloat, Value *&Sin, Value *&Cos,
@@ -1919,6 +2000,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
       return optimizeExp2(CI, Builder);
     case Intrinsic::fabs:
       return optimizeFabs(CI, Builder);
+    case Intrinsic::sqrt:
+      return optimizeSqrt(CI, Builder);
     default:
       return nullptr;
     }
@@ -1995,6 +2078,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
     case LibFunc::fabs:
     case LibFunc::fabsl:
       return optimizeFabs(CI, Builder);
+    case LibFunc::sqrtf:
+    case LibFunc::sqrt:
+    case LibFunc::sqrtl:
+      return optimizeSqrt(CI, Builder);
     case LibFunc::ffs:
     case LibFunc::ffsl:
     case LibFunc::ffsll:
@@ -2055,7 +2142,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
     case LibFunc::logb:
     case LibFunc::sin:
     case LibFunc::sinh:
-    case LibFunc::sqrt:
     case LibFunc::tan:
     case LibFunc::tanh:
       if (UnsafeFPShrink && hasFloatVersion(FuncName))