X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FUtils%2FSimplifyLibCalls.cpp;h=47e587fab7b6a94f585a36a26d47b909fde4a72d;hb=401b67d4eb0bed5dd9aa007f0a8241060cabb680;hp=67e1697c7da5900db07baba5e3dbad91ea3a1180;hpb=2748acd4155076c048c0742585408980b32ac4eb;p=oota-llvm.git diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 67e1697c7da..47e587fab7b 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -86,10 +86,9 @@ static bool isOnlyUsedInEqualityComparison(Value *V, Value *With) { } static bool callHasFloatingPointArgument(const CallInst *CI) { - for (const Use &OI : CI->operands()) - if (OI->getType()->isFloatingPointTy()) - return true; - return false; + return std::any_of(CI->op_begin(), CI->op_end(), [](const Use &OI) { + return OI->getType()->isFloatingPointTy(); + }); } /// \brief Check whether the overloaded unary floating point function @@ -1285,6 +1284,47 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { return B.CreateSelect(Cmp, Op0, Op1); } +Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { + Function *Callee = CI->getCalledFunction(); + Value *Ret = nullptr; + StringRef Name = Callee->getName(); + if (UnsafeFPShrink && hasFloatVersion(Name)) + Ret = optimizeUnaryDoubleFP(CI, B, true); + FunctionType *FT = Callee->getFunctionType(); + + // Just make sure this has 1 argument of FP type, which matches the + // result type. + if (FT->getNumParams() != 1 || FT->getReturnType() != FT->getParamType(0) || + !FT->getParamType(0)->isFloatingPointTy()) + return Ret; + + if (!canUseUnsafeFPMath(CI->getParent()->getParent())) + return Ret; + Value *Op1 = CI->getArgOperand(0); + auto *OpC = dyn_cast(Op1); + if (!OpC) + return Ret; + + // log(pow(x,y)) -> y*log(x) + // This is only applicable to log, log2, log10. + if (Name != "log" && Name != "log2" && Name != "log10") + return Ret; + + IRBuilder<>::FastMathFlagGuard Guard(B); + FastMathFlags FMF; + FMF.setUnsafeAlgebra(); + B.SetFastMathFlags(FMF); + + LibFunc::Func Func; + Function *F = OpC->getCalledFunction(); + if (F && ((TLI->getLibFunc(F->getName(), Func) && TLI->has(Func) && + Func == LibFunc::pow) || F->getIntrinsicID() == Intrinsic::pow)) + return B.CreateFMul(OpC->getArgOperand(1), + EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B, + Callee->getAttributes()), "mul"); + return Ret; +} + Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); @@ -1454,8 +1494,8 @@ LibCallSimplifier::classifyArgUse(Value *Val, BasicBlock *BB, bool IsFloat, Function *Callee = CI->getCalledFunction(); LibFunc::Func Func; - if (Callee && (!TLI->getLibFunc(Callee->getName(), Func) || !TLI->has(Func) || - !isTrigLibCall(CI))) + if (!Callee || !TLI->getLibFunc(Callee->getName(), Func) || !TLI->has(Func) || + !isTrigLibCall(CI)) return; if (IsFloat) { @@ -2089,6 +2129,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeExp2(CI, Builder); case Intrinsic::fabs: return optimizeFabs(CI, Builder); + case Intrinsic::log: + return optimizeLog(CI, Builder); case Intrinsic::sqrt: return optimizeSqrt(CI, Builder); default: @@ -2171,6 +2213,12 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeFWrite(CI, Builder); case LibFunc::fputs: return optimizeFPuts(CI, Builder); + case LibFunc::log: + case LibFunc::log10: + case LibFunc::log1p: + case LibFunc::log2: + case LibFunc::logb: + return optimizeLog(CI, Builder); case LibFunc::puts: return optimizePuts(CI, Builder); case LibFunc::tan: @@ -2204,11 +2252,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { case LibFunc::exp: case LibFunc::exp10: case LibFunc::expm1: - case LibFunc::log: - case LibFunc::log10: - case LibFunc::log1p: - case LibFunc::log2: - case LibFunc::logb: case LibFunc::sin: case LibFunc::sinh: case LibFunc::tanh: @@ -2257,12 +2300,10 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // // log, logf, logl: // * log(exp(x)) -> x -// * log(x**y) -> y*log(x) // * log(exp(y)) -> y*log(e) // * log(exp2(y)) -> y*log(2) // * log(exp10(y)) -> y*log(10) // * log(sqrt(x)) -> 0.5*log(x) -// * log(pow(x,y)) -> y*log(x) // // lround, lroundf, lroundl: // * lround(cnst) -> cnst'