getParent() ^ 3 == getModule() ; NFCI
[oota-llvm.git] / lib / Transforms / Utils / SimplifyLibCalls.cpp
index 47e587fab7b6a94f585a36a26d47b909fde4a72d..81dea6d1b9ae3e7b2382b45c157f9dbe7c5cb4fc 100644 (file)
@@ -995,7 +995,7 @@ Value *LibCallSimplifier::optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B,
 
   // floor((double)floatval) -> (double)floorf(floatval)
   if (Callee->isIntrinsic()) {
-    Module *M = CI->getParent()->getParent()->getParent();
+    Module *M = CI->getModule();
     Intrinsic::ID IID = Callee->getIntrinsicID();
     Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy());
     V = B.CreateCall(F, V);
@@ -1058,6 +1058,31 @@ Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) {
   return Ret;
 }
 
+static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) {
+  // Multiplications calculated using Addition Chains.
+  // Refer: http://wwwhomes.uni-bielefeld.de/achim/addition_chain.html
+
+  assert(Exp != 0 && "Incorrect exponent 0 not handled");
+
+  if (InnerChain[Exp])
+    return InnerChain[Exp];
+
+  static const unsigned AddChain[33][2] = {
+      {0, 0}, // Unused.
+      {0, 0}, // Unused (base case = pow1).
+      {1, 1}, // Unused (pre-computed).
+      {1, 2},  {2, 2},   {2, 3},  {3, 3},   {2, 5},  {4, 4},
+      {1, 8},  {5, 5},   {1, 10}, {6, 6},   {4, 9},  {7, 7},
+      {3, 12}, {8, 8},   {8, 9},  {2, 16},  {1, 18}, {10, 10},
+      {6, 15}, {11, 11}, {3, 20}, {12, 12}, {8, 17}, {13, 13},
+      {3, 24}, {14, 14}, {4, 25}, {15, 15}, {3, 28}, {16, 16},
+  };
+
+  InnerChain[Exp] = B.CreateFMul(getPow(InnerChain, AddChain[Exp][0], B),
+                                 getPow(InnerChain, AddChain[Exp][1], B));
+  return InnerChain[Exp];
+}
+
 Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
   Function *Callee = CI->getCalledFunction();
   Value *Ret = nullptr;
@@ -1156,6 +1181,32 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) {
     return B.CreateFMul(Op1, Op1, "pow2");
   if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x
     return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip");
+
+  // In -ffast-math, generate repeated fmul instead of generating pow(x, n).
+  if (unsafeFPMath) {
+    APFloat V = abs(Op2C->getValueAPF());
+    // We limit to a max of 7 fmul(s). Thus max exponent is 32.
+    // This transformation applies to integer exponents only.
+    if (V.compare(APFloat(V.getSemantics(), 32.0)) == APFloat::cmpGreaterThan ||
+        !V.isInteger())
+      return nullptr;
+
+    // We will memoize intermediate products of the Addition Chain.
+    Value *InnerChain[33] = {nullptr};
+    InnerChain[1] = Op1;
+    InnerChain[2] = B.CreateFMul(Op1, Op1);
+
+    // We cannot readily convert a non-double type (like float) to a double.
+    // So we first convert V to something which could be converted to double.
+    bool ignored;
+    V.convert(APFloat::IEEEdouble, APFloat::rmTowardZero, &ignored);
+    Value *FMul = getPow(InnerChain, V.convertToDouble(), B);
+    // For negative exponents simply compute the reciprocal.
+    if (Op2C->isNegative())
+      FMul = B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), FMul);
+    return FMul;
+  }
+
   return nullptr;
 }
 
@@ -1322,6 +1373,15 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) {
     return B.CreateFMul(OpC->getArgOperand(1),
       EmitUnaryFloatFnCall(OpC->getOperand(0), Callee->getName(), B,
                            Callee->getAttributes()), "mul");
+
+  // log(exp2(y)) -> y*log(2)
+  if (F && Name == "log" && TLI->getLibFunc(F->getName(), Func) &&
+      TLI->has(Func) && Func == LibFunc::exp2)
+    return B.CreateFMul(
+        OpC->getArgOperand(0),
+        EmitUnaryFloatFnCall(ConstantFP::get(CI->getType(), 2.0),
+                             Callee->getName(), B, Callee->getAttributes()),
+        "logmul");
   return Ret;
 }
 
@@ -2301,7 +2361,6 @@ void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {
 // log, logf, logl:
 //   * log(exp(x))   -> x
 //   * log(exp(y))   -> y*log(e)
-//   * log(exp2(y))  -> y*log(2)
 //   * log(exp10(y)) -> y*log(10)
 //   * log(sqrt(x))  -> 0.5*log(x)
 //