[SimplifyLibCalls] Tranform log(pow(x, y)) -> y*log(x).
authorDavide Italiano <davide@freebsd.org>
Sun, 29 Nov 2015 20:58:04 +0000 (20:58 +0000)
committerDavide Italiano <davide@freebsd.org>
Sun, 29 Nov 2015 20:58:04 +0000 (20:58 +0000)
This one is enabled only under -ffast-math. There are cases where the
difference between the value computed and the correct value is huge
even for ffast-math, e.g. as Steven pointed out:

x = -1, y = -4
log(pow(-1), 4) = 0
4*log(-1) = NaN

I checked what GCC does and apparently they do the same optimization
(which result in the dramatic difference). Future work might try to
make this (slightly) less worse.

Differential Revision: http://reviews.llvm.org/D14400

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@254263 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Transforms/Utils/SimplifyLibCalls.h
lib/Transforms/Utils/SimplifyLibCalls.cpp
test/Transforms/InstCombine/log-pow-nofastmath.ll [new file with mode: 0644]
test/Transforms/InstCombine/log-pow.ll [new file with mode: 0644]

index 55b0a13..410a075 100644 (file)
@@ -132,6 +132,7 @@ private:
   Value *optimizeExp2(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFabs(CallInst *CI, IRBuilder<> &B);
   Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B);
+  Value *optimizeLog(CallInst *CI, IRBuilder<> &B);
   Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B);
   Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B);
   Value *optimizeTan(CallInst *CI, IRBuilder<> &B);
index 6d3dfd6..c811e19 100644 (file)
@@ -1284,6 +1284,48 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) {
   return B.CreateSelect(Cmp, Op0, Op1);
 }
 
+Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
+  Function *Callee = CI->getCalledFunction();
+  Value *Ret = nullptr;
+  StringRef Name = Callee->getName();
+  if (UnsafeFPShrink && hasFloatVersion(Name))
+    Ret = optimizeUnaryDoubleFP(CI, B, true);
+  FunctionType *FT = Callee->getFunctionType();
+
+  // Just make sure this has 1 argument of FP type, which matches the
+  // result type.
+  if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) ||
+      !FT->getParamType(0)->isFloatingPointTy())
+    return Ret;
+
+  if (!canUseUnsafeFPMath(CI->getParent()->getParent()))
+    return Ret;
+  Value *Op1 = CI->getArgOperand(0);
+  auto *OpC = dyn_cast<CallInst>(Op1);
+  if (!OpC)
+    return Ret;
+
+  // log(pow(x,y)) -> y*log(x)
+  // This is only applicable to log, log2, log10.
+  if (Name != "log" && Name != "log2" && Name != "log10")
+    return Ret;
+
+  IRBuilder<>::FastMathFlagGuard Guard(B);
+  FastMathFlags FMF;
+  FMF.setUnsafeAlgebra();
+  B.SetFastMathFlags(FMF);
+
+  LibFunc::Func Func;
+  Function *F = OpC->getCalledFunction();
+  StringRef FuncName = F->getName();
+  if ((TLI->getLibFunc(FuncName, Func) && TLI->has(Func) &&
+      Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow)
+    return B.CreateFMul(OpC->getArgOperand(1),
+      EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
+                           Callee->getAttributes()), "mul");
+  return Ret;
+}
+
 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   
@@ -2088,6 +2130,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
       return optimizeExp2(CI, Builder);
     case Intrinsic::fabs:
       return optimizeFabs(CI, Builder);
+    case Intrinsic::log:
+      return optimizeLog(CI, Builder);
     case Intrinsic::sqrt:
       return optimizeSqrt(CI, Builder);
     default:
@@ -2170,6 +2214,12 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
       return optimizeFWrite(CI, Builder);
     case LibFunc::fputs:
       return optimizeFPuts(CI, Builder);
+    case LibFunc::log:
+    case LibFunc::log10:
+    case LibFunc::log1p:
+    case LibFunc::log2:
+    case LibFunc::logb:
+      return optimizeLog(CI, Builder);
     case LibFunc::puts:
       return optimizePuts(CI, Builder);
     case LibFunc::tan:
@@ -2203,11 +2253,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) {
     case LibFunc::exp:
     case LibFunc::exp10:
     case LibFunc::expm1:
-    case LibFunc::log:
-    case LibFunc::log10:
-    case LibFunc::log1p:
-    case LibFunc::log2:
-    case LibFunc::logb:
     case LibFunc::sin:
     case LibFunc::sinh:
     case LibFunc::tanh:
diff --git a/test/Transforms/InstCombine/log-pow-nofastmath.ll b/test/Transforms/InstCombine/log-pow-nofastmath.ll
new file mode 100644 (file)
index 0000000..0811e63
--- /dev/null
@@ -0,0 +1,17 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define double @mylog(double %x, double %y) #0 {
+entry:
+  %pow = call double @llvm.pow.f64(double %x, double %y)
+  %call = call double @log(double %pow) #0
+  ret double %call
+}
+
+; CHECK-LABEL: define double @mylog(
+; CHECK:   %pow = call double @llvm.pow.f64(double %x, double %y)
+; CHECK:   %call = call double @log(double %pow)
+; CHECK:   ret double %call
+; CHECK: }
+
+declare double @log(double) #0
+declare double @llvm.pow.f64(double, double)
diff --git a/test/Transforms/InstCombine/log-pow.ll b/test/Transforms/InstCombine/log-pow.ll
new file mode 100644 (file)
index 0000000..2cafcca
--- /dev/null
@@ -0,0 +1,19 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define double @mylog(double %x, double %y) #0 {
+entry:
+  %pow = call double @llvm.pow.f64(double %x, double %y)
+  %call = call double @log(double %pow) #0
+  ret double %call
+}
+
+; CHECK-LABEL: define double @mylog(
+; CHECK:   %log = call double @log(double %x) #0
+; CHECK:   %mul = fmul fast double %log, %y
+; CHECK:   ret double %mul
+; CHECK: }
+
+declare double @log(double) #0
+declare double @llvm.pow.f64(double, double)
+
+attributes #0 = { "unsafe-fp-math"="true" }