From 0f019d6283f8819771a3bc164320abb5d09df8f1 Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Sun, 29 Nov 2015 20:58:04 +0000 Subject: [PATCH] [SimplifyLibCalls] Tranform log(pow(x, y)) -> y*log(x). This one is enabled only under -ffast-math. There are cases where the difference between the value computed and the correct value is huge even for ffast-math, e.g. as Steven pointed out: x = -1, y = -4 log(pow(-1), 4) = 0 4*log(-1) = NaN I checked what GCC does and apparently they do the same optimization (which result in the dramatic difference). Future work might try to make this (slightly) less worse. Differential Revision: http://reviews.llvm.org/D14400 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@254263 91177308-0d34-0410-b5e6-96231b3b80d8 --- .../llvm/Transforms/Utils/SimplifyLibCalls.h | 1 + lib/Transforms/Utils/SimplifyLibCalls.cpp | 55 +++++++++++++++++-- .../InstCombine/log-pow-nofastmath.ll | 17 ++++++ test/Transforms/InstCombine/log-pow.ll | 19 +++++++ 4 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 test/Transforms/InstCombine/log-pow-nofastmath.ll create mode 100644 test/Transforms/InstCombine/log-pow.ll diff --git a/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/include/llvm/Transforms/Utils/SimplifyLibCalls.h index 55b0a13ad3c..410a075aeb9 100644 --- a/include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ b/include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -132,6 +132,7 @@ private: Value *optimizeExp2(CallInst *CI, IRBuilder<> &B); Value *optimizeFabs(CallInst *CI, IRBuilder<> &B); Value *optimizeFMinFMax(CallInst *CI, IRBuilder<> &B); + Value *optimizeLog(CallInst *CI, IRBuilder<> &B); Value *optimizeSqrt(CallInst *CI, IRBuilder<> &B); Value *optimizeSinCosPi(CallInst *CI, IRBuilder<> &B); Value *optimizeTan(CallInst *CI, IRBuilder<> &B); diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 6d3dfd6750a..c811e19f8b2 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1284,6 +1284,48 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { return B.CreateSelect(Cmp, Op0, Op1); } +Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Value *Ret = nullptr; + StringRef Name = Callee->getName(); + if (UnsafeFPShrink && hasFloatVersion(Name)) + Ret = optimizeUnaryDoubleFP(CI, B, true); + FunctionType *FT = Callee->getFunctionType(); + + // Just make sure this has 1 argument of FP type, which matches the + // result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; + + if (!canUseUnsafeFPMath(CI->getParent()->getParent())) + return Ret; + Value *Op1 = CI->getArgOperand(0); + auto *OpC = dyn_cast(Op1); + if (!OpC) + return Ret; + + // log(pow(x,y)) -> y*log(x) + // This is only applicable to log, log2, log10. + if (Name != "log" && Name != "log2" && Name != "log10") + return Ret; + + IRBuilder<>::FastMathFlagGuard Guard(B); + FastMathFlags FMF; + FMF.setUnsafeAlgebra(); + B.SetFastMathFlags(FMF); + + LibFunc::Func Func; + Function *F = OpC->getCalledFunction(); + StringRef FuncName = F->getName(); + if ((TLI->getLibFunc(FuncName, Func) && TLI->has(Func) && + Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow) + return B.CreateFMul(OpC->getArgOperand(1), + EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B, + Callee->getAttributes()), "mul"); + return Ret; +} + Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); @@ -2088,6 +2130,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeExp2(CI, Builder); case Intrinsic::fabs: return optimizeFabs(CI, Builder); + case Intrinsic::log: + return optimizeLog(CI, Builder); case Intrinsic::sqrt: return optimizeSqrt(CI, Builder); default: @@ -2170,6 +2214,12 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeFWrite(CI, Builder); case LibFunc::fputs: return optimizeFPuts(CI, Builder); + case LibFunc::log: + case LibFunc::log10: + case LibFunc::log1p: + case LibFunc::log2: + case LibFunc::logb: + return optimizeLog(CI, Builder); case LibFunc::puts: return optimizePuts(CI, Builder); case LibFunc::tan: @@ -2203,11 +2253,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { case LibFunc::exp: case LibFunc::exp10: case LibFunc::expm1: - case LibFunc::log: - case LibFunc::log10: - case LibFunc::log1p: - case LibFunc::log2: - case LibFunc::logb: case LibFunc::sin: case LibFunc::sinh: case LibFunc::tanh: diff --git a/test/Transforms/InstCombine/log-pow-nofastmath.ll b/test/Transforms/InstCombine/log-pow-nofastmath.ll new file mode 100644 index 00000000000..0811e63cc74 --- /dev/null +++ b/test/Transforms/InstCombine/log-pow-nofastmath.ll @@ -0,0 +1,17 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +define double @mylog(double %x, double %y) #0 { +entry: + %pow = call double @llvm.pow.f64(double %x, double %y) + %call = call double @log(double %pow) #0 + ret double %call +} + +; CHECK-LABEL: define double @mylog( +; CHECK: %pow = call double @llvm.pow.f64(double %x, double %y) +; CHECK: %call = call double @log(double %pow) +; CHECK: ret double %call +; CHECK: } + +declare double @log(double) #0 +declare double @llvm.pow.f64(double, double) diff --git a/test/Transforms/InstCombine/log-pow.ll b/test/Transforms/InstCombine/log-pow.ll new file mode 100644 index 00000000000..2cafccabf0a --- /dev/null +++ b/test/Transforms/InstCombine/log-pow.ll @@ -0,0 +1,19 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +define double @mylog(double %x, double %y) #0 { +entry: + %pow = call double @llvm.pow.f64(double %x, double %y) + %call = call double @log(double %pow) #0 + ret double %call +} + +; CHECK-LABEL: define double @mylog( +; CHECK: %log = call double @log(double %x) #0 +; CHECK: %mul = fmul fast double %log, %y +; CHECK: ret double %mul +; CHECK: } + +declare double @log(double) #0 +declare double @llvm.pow.f64(double, double) + +attributes #0 = { "unsafe-fp-math"="true" } -- 2.34.1