Merging r258325:
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index ea49131e9377294944c13c070c2e96c43f0b504e..908b4bb6a654e99bb6ccdf4774873a93a9833d5f 100644 (file)
@@ -994,6 +994,10 @@ Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
   Value *V = valueHasFloatPrecision(CI->getArgOperand(0));
   if (V == nullptr)
     return nullptr;
+  
+  // Propagate fast-math flags from the existing call to the new call.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(CI->getFastMathFlags());
 
   // floor((double)floatval) -> (double)floorf(floatval)
   if (Callee->isIntrinsic()) {
@@ -1029,6 +1033,10 @@ Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {
   if (V2 == nullptr)
     return nullptr;
 
+  // Propagate fast-math flags from the existing call to the new call.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(CI->getFastMathFlags());
+
   // fmin((double)floatval1, (double)floatval2)
   //                      -> (double)fminf(floatval1, floatval2)
   // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP().
@@ -1119,29 +1127,26 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                                   Callee->getAttributes());
   }
 
+  // FIXME: Use instruction-level FMF.
   bool UnsafeFPMath = canUseUnsafeFPMath(CI->getParent()->getParent());
 
-  // pow(exp(x), y) -> exp(x*y)
+  // pow(exp(x), y) -> exp(x * y)
   // pow(exp2(x), y) -> exp2(x * y)
-  // We enable these only under fast-math. Besides rounding
-  // differences the transformation changes overflow and
-  // underflow behavior quite dramatically.
+  // We enable these only with fast-math. Besides rounding differences, the
+  // transformation changes overflow and underflow behavior quite dramatically.
   // Example: x = 1000, y = 0.001.
   // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).
-  if (UnsafeFPMath) {
-    if (auto *OpC = dyn_cast<CallInst>(Op1)) {
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (OpC && OpC->hasUnsafeAlgebra() && CI->hasUnsafeAlgebra()) {
+    LibFunc::Func Func;
+    Function *OpCCallee = OpC->getCalledFunction();
+    if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) &&
+        TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2)) {
       IRBuilder<>::FastMathFlagGuard Guard(B);
-      FastMathFlags FMF;
-      FMF.setUnsafeAlgebra();
-      B.SetFastMathFlags(FMF);
-
-      LibFunc::Func Func;
-      Function *OpCCallee = OpC->getCalledFunction();
-      if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) &&
-          TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2))
-        return EmitUnaryFloatFnCall(
-            B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"),
-            OpCCallee->getName(), B, OpCCallee->getAttributes());
+      B.setFastMathFlags(CI->getFastMathFlags());
+      Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul");
+      return EmitUnaryFloatFnCall(FMul, OpCCallee->getName(), B,
+                                  OpCCallee->getAttributes());
     }
   }
 
@@ -1159,9 +1164,12 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                       LibFunc::fabsl)) {
 
     // In -ffast-math, pow(x, 0.5) -> sqrt(x).
-    if (UnsafeFPMath)
+    if (CI->hasUnsafeAlgebra()) {
+      IRBuilder<>::FastMathFlagGuard Guard(B);
+      B.setFastMathFlags(CI->getFastMathFlags());
       return EmitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B,
                                   Callee->getAttributes());
+    }
 
     // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))).
     // This is faster than calling pow, and still handles negative zero
@@ -1293,12 +1301,9 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
   // function, do that first.
   Function *Callee = CI->getCalledFunction();
   StringRef Name = Callee->getName();
-  if ((Name == "fmin" && hasFloatVersion(Name)) ||
-      (Name == "fmax" && hasFloatVersion(Name))) {
-    Value *Ret = optimizeBinaryDoubleFP(CI, B);
-    if (Ret)
+  if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(Name))
+    if (Value *Ret = optimizeBinaryDoubleFP(CI, B))
       return Ret;
-  }
 
   // Make sure this has 2 arguments of FP type which match the result type.
   FunctionType *FT = Callee->getFunctionType();
@@ -1309,14 +1314,12 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
 
   IRBuilder<>::FastMathFlagGuard Guard(B);
   FastMathFlags FMF;
-  Function *F = CI->getParent()->getParent();
-  if (canUseUnsafeFPMath(F)) {
+  if (CI->hasUnsafeAlgebra()) {
     // Unsafe algebra sets all fast-math-flags to true.
     FMF.setUnsafeAlgebra();
   } else {
     // At a minimum, no-nans-fp-math must be true.
-    Attribute Attr = F->getFnAttribute("no-nans-fp-math");
-    if (Attr.getValueAsString() != "true")
+    if (!CI->hasNoNaNs())
       return nullptr;
     // No-signed-zeros is implied by the definitions of fmax/fmin themselves:
     // "Ideally, fmax would be sensitive to the sign of zero, for example
@@ -1325,7 +1328,7 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
     FMF.setNoSignedZeros();
     FMF.setNoNaNs();
   }
-  B.SetFastMathFlags(FMF);
+  B.setFastMathFlags(FMF);
 
   // We have a relaxed floating-point environment. We can ignore NaN-handling
   // and transform to a compare and select. We do not have to consider errno or
@@ -1351,11 +1354,13 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
       !FT->getParamType(0)->isFloatingPointTy())
     return Ret;
 
-  if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
+  if (!CI->hasUnsafeAlgebra())
     return Ret;
   Value *Op1 = CI->getArgOperand(0);
   auto *OpC = dyn_cast<CallInst>(Op1);
-  if (!OpC)
+
+  // The earlier call must also be unsafe in order to do these transforms.
+  if (!OpC || !OpC->hasUnsafeAlgebra())
     return Ret;
 
   // log(pow(x,y)) -> y*log(x)
@@ -1366,7 +1371,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
   IRBuilder<>::FastMathFlagGuard Guard(B);
   FastMathFlags FMF;
   FMF.setUnsafeAlgebra();
-  B.SetFastMathFlags(FMF);
+  B.setFastMathFlags(FMF);
 
   LibFunc::Func Func;
   Function *F = OpC->getCalledFunction();
@@ -1389,71 +1394,81 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
 
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
-  
+
   Value *Ret = nullptr;
   if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" ||
                                    Callee->getIntrinsicID() == Intrinsic::sqrt))
     Ret = optimizeUnaryDoubleFP(CI, B, true);
-  if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
+
+  // FIXME: Refactor - this check is repeated all over this file and even in the
+  // preceding call to shrink double -> float.
+
+  // Make sure this has 1 argument of FP type, which matches the result type.
+  FunctionType *FT = Callee->getFunctionType();
+  if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) ||
+      !FT->getParamType(0)->isFloatingPointTy())
     return Ret;
 
-  Value *Op = CI->getArgOperand(0);
-  if (Instruction *I = dyn_cast<Instruction>(Op)) {
-    if (I->getOpcode() == Instruction::FMul && I->hasUnsafeAlgebra()) {
-      // We're looking for a repeated factor in a multiplication tree,
-      // so we can do this fold: sqrt(x * x) -> fabs(x);
-      // or this fold: sqrt(x * x * y) -> fabs(x) * sqrt(y).
-      Value *Op0 = I->getOperand(0);
-      Value *Op1 = I->getOperand(1);
-      Value *RepeatOp = nullptr;
-      Value *OtherOp = nullptr;
-      if (Op0 == Op1) {
-        // Simple match: the operands of the multiply are identical.
-        RepeatOp = Op0;
-      } else {
-        // Look for a more complicated pattern: one of the operands is itself
-        // a multiply, so search for a common factor in that multiply.
-        // Note: We don't bother looking any deeper than this first level or for
-        // variations of this pattern because instcombine's visitFMUL and/or the
-        // reassociation pass should give us this form.
-        Value *OtherMul0, *OtherMul1;
-        if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
-          // Pattern: sqrt((x * y) * z)
-          if (OtherMul0 == OtherMul1) {
-            // Matched: sqrt((x * x) * z)
-            RepeatOp = OtherMul0;
-            OtherOp = Op1;
-          }
-        }
-      }
-      if (RepeatOp) {
-        // Fast math flags for any created instructions should match the sqrt
-        // and multiply.
-        // FIXME: We're not checking the sqrt because it doesn't have
-        // fast-math-flags (see earlier comment).
-        IRBuilder<>::FastMathFlagGuard Guard(B);
-        B.SetFastMathFlags(I->getFastMathFlags());
-        // If we found a repeated factor, hoist it out of the square root and
-        // replace it with the fabs of that factor.
-        Module *M = Callee->getParent();
-        Type *ArgType = Op->getType();
-        Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
-        Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
-        if (OtherOp) {
-          // If we found a non-repeated factor, we still need to get its square
-          // root. We then multiply that by the value that was simplified out
-          // of the square root calculation.
-          Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
-          Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
-          return B.CreateFMul(FabsCall, SqrtCall);
-        }
-        return FabsCall;
+  if (!CI->hasUnsafeAlgebra())
+    return Ret;
+
+  Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0));
+  if (!I || I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
+    return Ret;
+
+  // We're looking for a repeated factor in a multiplication tree,
+  // so we can do this fold: sqrt(x * x) -> fabs(x);
+  // or this fold: sqrt((x * x) * y) -> fabs(x) * sqrt(y).
+  Value *Op0 = I->getOperand(0);
+  Value *Op1 = I->getOperand(1);
+  Value *RepeatOp = nullptr;
+  Value *OtherOp = nullptr;
+  if (Op0 == Op1) {
+    // Simple match: the operands of the multiply are identical.
+    RepeatOp = Op0;
+  } else {
+    // Look for a more complicated pattern: one of the operands is itself
+    // a multiply, so search for a common factor in that multiply.
+    // Note: We don't bother looking any deeper than this first level or for
+    // variations of this pattern because instcombine's visitFMUL and/or the
+    // reassociation pass should give us this form.
+    Value *OtherMul0, *OtherMul1;
+    if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
+      // Pattern: sqrt((x * y) * z)
+      if (OtherMul0 == OtherMul1 &&
+          cast<Instruction>(Op0)->hasUnsafeAlgebra()) {
+        // Matched: sqrt((x * x) * z)
+        RepeatOp = OtherMul0;
+        OtherOp = Op1;
       }
     }
   }
-  return Ret;
-}
+  if (!RepeatOp)
+    return Ret;
 
+  // Fast math flags for any created instructions should match the sqrt
+  // and multiply.
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  B.setFastMathFlags(I->getFastMathFlags());
+
+  // If we found a repeated factor, hoist it out of the square root and
+  // replace it with the fabs of that factor.
+  Module *M = Callee->getParent();
+  Type *ArgType = I->getType();
+  Value *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, ArgType);
+  Value *FabsCall = B.CreateCall(Fabs, RepeatOp, "fabs");
+  if (OtherOp) {
+    // If we found a non-repeated factor, we still need to get its square
+    // root. We then multiply that by the value that was simplified out
+    // of the square root calculation.
+    Value *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, ArgType);
+    Value *SqrtCall = B.CreateCall(Sqrt, OtherOp, "sqrt");
+    return B.CreateFMul(FabsCall, SqrtCall);
+  }
+  return FabsCall;
+}
+
+// TODO: Generalize to handle any trig function and its inverse.
 Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
@@ -1468,13 +1483,15 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) {
       !FT->getParamType(0)->isFloatingPointTy())
     return Ret;
 
-  if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
-    return Ret;
   Value *Op1 = CI->getArgOperand(0);
   auto *OpC = dyn_cast<CallInst>(Op1);
   if (!OpC)
     return Ret;
 
+  // Both calls must allow unsafe optimizations in order to remove them.
+  if (!CI->hasUnsafeAlgebra() || !OpC->hasUnsafeAlgebra())
+    return Ret;
+
   // tan(atan(x)) -> x
   // tanf(atanf(x)) -> x
   // tanl(atanl(x)) -> x
@@ -2171,7 +2188,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
   LibFunc::Func Func;
   Function *Callee = CI->getCalledFunction();
   StringRef FuncName = Callee->getName();
-  IRBuilder<> Builder(CI);
+
+  SmallVector<OperandBundleDef, 2> OpBundles;
+  CI->getOperandBundlesAsDefs(OpBundles);
+  IRBuilder<> Builder(CI, /*FPMathTag=*/nullptr, OpBundles);
   bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C;
 
   // Command-line parameter overrides function attribute.
@@ -2544,7 +2564,10 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) {
   LibFunc::Func Func;
   Function *Callee = CI->getCalledFunction();
   StringRef FuncName = Callee->getName();
-  IRBuilder<> Builder(CI);
+
+  SmallVector<OperandBundleDef, 2> OpBundles;
+  CI->getOperandBundlesAsDefs(OpBundles);
+  IRBuilder<> Builder(CI, /*FPMathTag=*/nullptr, OpBundles);
   bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C;
 
   // First, check that this is a known library functions.