Merging r258416 and r258428:
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index cf87ac1cf28bfead14b80bd1656c5bc7c47366b9..2f3c31128cf031e785d2415a26e5265990968009 100644 (file)
@@ -970,15 +970,34 @@ static Value *valueHasFloatPrecision(Value *Val) {
   return nullptr;
 }
 
-//===----------------------------------------------------------------------===//
-// Double -> Float Shrinking Optimizations for Unary Functions like 'floor'
+/// Any floating-point library function that we're trying to simplify will have
+/// a signature of the form: fptype foo(fptype param1, fptype param2, ...).
+/// CheckDoubleTy indicates that 'fptype' must be 'double'.
+static bool matchesFPLibFunctionSignature(const Function *F, unsigned NumParams,
+                                          bool CheckDoubleTy) {
+  FunctionType *FT = F->getFunctionType();
+  if (FT->getNumParams() != NumParams)
+    return false;
+
+  // The return type must match what we're looking for.
+  Type *RetTy = FT->getReturnType();
+  if (CheckDoubleTy ? !RetTy->isDoubleTy() : !RetTy->isFloatingPointTy())
+    return false;
 
-Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
-                                                bool CheckRetType) {
+  // Each parameter must match the return type, and therefore, match every other
+  // parameter too.
+  for (const Type *ParamTy : FT->params())
+    if (ParamTy != RetTy)
+      return false;
+
+  return true;
+}
+
+/// Shrink double -> float for unary functions like 'floor'.
+static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
+                                    bool CheckRetType) {
   Function *Callee = CI->getCalledFunction();
-  FunctionType *FT = Callee->getFunctionType();
-  if (FT->getNumParams() != 1 || !FT->getReturnType()->isDoubleTy() ||
-      !FT->getParamType(0)->isDoubleTy())
+  if (!matchesFPLibFunctionSignature(Callee, 1, true))
     return nullptr;
 
   if (CheckRetType) {
@@ -997,7 +1016,7 @@ Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
   
   // Propagate fast-math flags from the existing call to the new call.
   IRBuilder<>::FastMathFlagGuard Guard(B);
-  B.SetFastMathFlags(CI->getFastMathFlags());
+  B.setFastMathFlags(CI->getFastMathFlags());
 
   // floor((double)floatval) -> (double)floorf(floatval)
   if (Callee->isIntrinsic()) {
@@ -1013,15 +1032,10 @@ Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
   return B.CreateFPExt(V, B.getDoubleTy());
 }
 
-// Double -> Float Shrinking Optimizations for Binary Functions like 'fmin/fmax'
-Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {
+/// Shrink double -> float for binary functions like 'fmin/fmax'.
+static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
-  FunctionType *FT = Callee->getFunctionType();
-  // Just make sure this has 2 arguments of the same FP type, which match the
-  // result type.
-  if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) ||
-      FT->getParamType(0) != FT->getParamType(1) ||
-      !FT->getParamType(0)->isFloatingPointTy())
+  if (!matchesFPLibFunctionSignature(Callee, 2, true))
     return nullptr;
 
   // If this is something like 'fmin((double)floatval1, (double)floatval2)',
@@ -1035,7 +1049,7 @@ Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {
 
   // Propagate fast-math flags from the existing call to the new call.
   IRBuilder<>::FastMathFlagGuard Guard(B);
-  B.SetFastMathFlags(CI->getFastMathFlags());
+  B.setFastMathFlags(CI->getFastMathFlags());
 
   // fmin((double)floatval1, (double)floatval2)
   //                      -> (double)fminf(floatval1, floatval2)
@@ -1127,29 +1141,26 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                                   Callee->getAttributes());
   }
 
+  // FIXME: Use instruction-level FMF.
   bool UnsafeFPMath = canUseUnsafeFPMath(CI->getParent()->getParent());
 
-  // pow(exp(x), y) -> exp(x*y)
+  // pow(exp(x), y) -> exp(x * y)
   // pow(exp2(x), y) -> exp2(x * y)
-  // We enable these only under fast-math. Besides rounding
-  // differences the transformation changes overflow and
-  // underflow behavior quite dramatically.
+  // We enable these only with fast-math. Besides rounding differences, the
+  // transformation changes overflow and underflow behavior quite dramatically.
   // Example: x = 1000, y = 0.001.
   // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).
-  if (UnsafeFPMath) {
-    if (auto *OpC = dyn_cast<CallInst>(Op1)) {
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (OpC && OpC->hasUnsafeAlgebra() && CI->hasUnsafeAlgebra()) {
+    LibFunc::Func Func;
+    Function *OpCCallee = OpC->getCalledFunction();
+    if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) &&
+        TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2)) {
       IRBuilder<>::FastMathFlagGuard Guard(B);
-      FastMathFlags FMF;
-      FMF.setUnsafeAlgebra();
-      B.SetFastMathFlags(FMF);
-
-      LibFunc::Func Func;
-      Function *OpCCallee = OpC->getCalledFunction();
-      if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) &&
-          TLI->has(Func) && (Func == LibFunc::exp || Func == LibFunc::exp2))
-        return EmitUnaryFloatFnCall(
-            B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"),
-            OpCCallee->getName(), B, OpCCallee->getAttributes());
+      B.setFastMathFlags(CI->getFastMathFlags());
+      Value *FMul = B.CreateFMul(OpC->getArgOperand(0), Op2, "mul");
+      return EmitUnaryFloatFnCall(FMul, OpCCallee->getName(), B,
+                                  OpCCallee->getAttributes());
     }
   }
 
@@ -1167,9 +1178,12 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                       LibFunc::fabsl)) {
 
     // In -ffast-math, pow(x, 0.5) -> sqrt(x).
-    if (UnsafeFPMath)
+    if (CI->hasUnsafeAlgebra()) {
+      IRBuilder<>::FastMathFlagGuard Guard(B);
+      B.setFastMathFlags(CI->getFastMathFlags());
       return EmitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B,
                                   Callee->getAttributes());
+    }
 
     // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))).
     // This is faster than calling pow, and still handles negative zero
@@ -1328,7 +1342,7 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
     FMF.setNoSignedZeros();
     FMF.setNoNaNs();
   }
-  B.SetFastMathFlags(FMF);
+  B.setFastMathFlags(FMF);
 
   // We have a relaxed floating-point environment. We can ignore NaN-handling
   // and transform to a compare and select. We do not have to consider errno or
@@ -1354,11 +1368,13 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
       !FT->getParamType(0)->isFloatingPointTy())
     return Ret;
 
-  if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
+  if (!CI->hasUnsafeAlgebra())
     return Ret;
   Value *Op1 = CI->getArgOperand(0);
   auto *OpC = dyn_cast<CallInst>(Op1);
-  if (!OpC)
+
+  // The earlier call must also be unsafe in order to do these transforms.
+  if (!OpC || !OpC->hasUnsafeAlgebra())
     return Ret;
 
   // log(pow(x,y)) -> y*log(x)
@@ -1369,7 +1385,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
   IRBuilder<>::FastMathFlagGuard Guard(B);
   FastMathFlags FMF;
   FMF.setUnsafeAlgebra();
-  B.SetFastMathFlags(FMF);
+  B.setFastMathFlags(FMF);
 
   LibFunc::Func Func;
   Function *F = OpC->getCalledFunction();
@@ -1392,12 +1408,21 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
 
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
-  
+
   Value *Ret = nullptr;
   if (TLI->has(LibFunc::sqrtf) && (Callee->getName() == "sqrt" ||
                                    Callee->getIntrinsicID() == Intrinsic::sqrt))
     Ret = optimizeUnaryDoubleFP(CI, B, true);
 
+  // FIXME: Refactor - this check is repeated all over this file and even in the
+  // preceding call to shrink double -> float.
+
+  // Make sure this has 1 argument of FP type, which matches the result type.
+  FunctionType *FT = Callee->getFunctionType();
+  if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) ||
+      !FT->getParamType(0)->isFloatingPointTy())
+    return Ret;
+
   if (!CI->hasUnsafeAlgebra())
     return Ret;
 
@@ -1422,10 +1447,10 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
     // variations of this pattern because instcombine's visitFMUL and/or the
     // reassociation pass should give us this form.
     Value *OtherMul0, *OtherMul1;
-    // FIXME: This multiply must be unsafe to allow this transform.
     if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
       // Pattern: sqrt((x * y) * z)
-      if (OtherMul0 == OtherMul1) {
+      if (OtherMul0 == OtherMul1 &&
+          cast<Instruction>(Op0)->hasUnsafeAlgebra()) {
         // Matched: sqrt((x * x) * z)
         RepeatOp = OtherMul0;
         OtherOp = Op1;
@@ -1438,7 +1463,8 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
   // Fast math flags for any created instructions should match the sqrt
   // and multiply.
   IRBuilder<>::FastMathFlagGuard Guard(B);
-  B.SetFastMathFlags(I->getFastMathFlags());
+  B.setFastMathFlags(I->getFastMathFlags());
+
   // If we found a repeated factor, hoist it out of the square root and
   // replace it with the fabs of that factor.
   Module *M = Callee->getParent();