[SimplifyLibCalls] Use hasFloatVersion(). NFCI.
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index 9fabf5de7f4e63d154d9b92c000dc1a19d4057a9..76a28e13bafe29fcd974b2a4d3b28eb2132036f1 100644 (file)
@@ -1049,9 +1049,9 @@ Value *LibCallSimplifier::optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) {
 Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
-  if (UnsafeFPShrink && Callee->getName() == "cos" && TLI->has(LibFunc::cosf)) {
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "cos" && hasFloatVersion(Name))
     Ret = optimizeUnaryDoubleFP(CI, B, true);
-  }
 
   FunctionType *FT = Callee->getFunctionType();
   // Just make sure this has 1 argument of FP type, which matches the
@@ -1071,11 +1071,10 @@ Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) {
 
 Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
-
   Value *Ret = nullptr;
-  if (UnsafeFPShrink && Callee->getName() == "pow" && TLI->has(LibFunc::powf)) {
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "pow" && hasFloatVersion(Name))
     Ret = optimizeUnaryDoubleFP(CI, B, true);
-  }
 
   FunctionType *FT = Callee->getFunctionType();
   // Just make sure this has 2 arguments of the same FP type, which match the
@@ -1103,6 +1102,32 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
                                   Callee->getAttributes());
   }
 
+  // pow(exp(x), y) -> exp(x*y)
+  // pow(exp2(x), y) -> exp2(x * y)
+  // We enable these only under fast-math. Besides rounding
+  // differences the transformation changes overflow and
+  // underflow behavior quite dramatically.
+  // Example: x = 1000, y = 0.001.
+  // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1).
+  if (canUseUnsafeFPMath(CI->getParent()->getParent())) {
+    if (auto *OpC = dyn_cast<CallInst>(Op1)) {
+      IRBuilder<>::FastMathFlagGuard Guard(B);
+      FastMathFlags FMF;
+      FMF.setUnsafeAlgebra();
+      B.SetFastMathFlags(FMF);
+
+      LibFunc::Func Func;
+      Function *Callee = OpC->getCalledFunction();
+      StringRef FuncName = Callee->getName();
+
+      if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func) &&
+          (Func == LibFunc::exp || Func == LibFunc::exp2))
+        return EmitUnaryFloatFnCall(
+            B.CreateFMul(OpC->getArgOperand(0), Op2, "mul"), FuncName, B,
+            Callee->getAttributes());
+    }
+  }
+
   ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2);
   if (!Op2C)
     return Ret;
@@ -1142,12 +1167,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
 Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Function *Caller = CI->getParent()->getParent();
-
   Value *Ret = nullptr;
-  if (UnsafeFPShrink && Callee->getName() == "exp2" &&
-      TLI->has(LibFunc::exp2f)) {
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "exp2" && hasFloatVersion(Name))
     Ret = optimizeUnaryDoubleFP(CI, B, true);
-  }
 
   FunctionType *FT = Callee->getFunctionType();
   // Just make sure this has 1 argument of FP type, which matches the
@@ -1196,11 +1219,10 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) {
 
 Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
-
   Value *Ret = nullptr;
-  if (Callee->getName() == "fabs" && TLI->has(LibFunc::fabsf)) {
+  StringRef Name = Callee->getName();
+  if (Name == "fabs" && hasFloatVersion(Name))
     Ret = optimizeUnaryDoubleFP(CI, B, false);
-  }
 
   FunctionType *FT = Callee->getFunctionType();
   // Make sure this has 1 argument of FP type which matches the result type.
@@ -1222,8 +1244,9 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
   // If we can shrink the call to a float function rather than a double
   // function, do that first.
   Function *Callee = CI->getCalledFunction();
-  if ((Callee->getName() == "fmin" && TLI->has(LibFunc::fminf)) ||
-      (Callee->getName() == "fmax" && TLI->has(LibFunc::fmaxf))) {
+  StringRef Name = Callee->getName();
+  if ((Name == "fmin" && hasFloatVersion(Name)) ||
+      (Name == "fmax" && hasFloatVersion(Name))) {
     Value *Ret = optimizeBinaryDoubleFP(CI, B);
     if (Ret)
       return Ret;
@@ -1333,6 +1356,41 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
   return Ret;
 }
 
+Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+  Value *Ret = nullptr;
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && Name == "tan" && hasFloatVersion(Name))
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  FunctionType *FT = Callee->getFunctionType();
+
+  // Just make sure this has 1 argument of FP type, which matches the
+  // result type.
+  if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) ||
+      !FT->getParamType(0)->isFloatingPointTy())
+    return Ret;
+
+  if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
+    return Ret;
+  Value *Op1 = CI->getArgOperand(0);
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (!OpC)
+    return Ret;
+
+  // tan(atan(x)) -> x
+  // tanf(atanf(x)) -> x
+  // tanl(atanl(x)) -> x
+  LibFunc::Func Func;
+  Function *F = OpC->getCalledFunction();
+  StringRef FuncName = F->getName();
+  if (TLI->getLibFunc(FuncName, Func) && TLI->has(Func) &&
+      ((Func == LibFunc::atan && Callee->getName() == "tan") ||
+       (Func == LibFunc::atanf && Callee->getName() == "tanf") ||
+       (Func == LibFunc::atanl && Callee->getName() == "tanl")))
+    Ret = OpC->getArgOperand(0);
+  return Ret;
+}
+
 static bool isTrigLibCall(CallInst *CI);
 static void insertSinCosCall(IRBuilder<> &B, Function *OrigCallee, Value *Arg,
                              bool UseFloat, Value *&Sin, Value *&Cos,
@@ -2120,6 +2178,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
       return optimizeFPuts(CI, Builder);
     case LibFunc::puts:
       return optimizePuts(CI, Builder);
+    case LibFunc::tan:
+    case LibFunc::tanf:
+    case LibFunc::tanl:
+      return optimizeTan(CI, Builder);
     case LibFunc::perror:
       return optimizeErrorReporting(CI, Builder);
     case LibFunc::vfprintf:
@@ -2154,7 +2216,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
     case LibFunc::logb:
     case LibFunc::sin:
     case LibFunc::sinh:
-    case LibFunc::tan:
     case LibFunc::tanh:
       if (UnsafeFPShrink && hasFloatVersion(FuncName))
         return optimizeUnaryDoubleFP(CI, Builder, true);