Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index e104a0a9798d1b7077ef0cbefbcc878dbf38f13b..5cd611c4200adb819f93b491a18fd4fe28a47d2a 100644 (file)
@@ -13,8 +13,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/IntrinsicInst.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IntrinsicInst.h"
 #include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -37,7 +37,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
   if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))),
                       m_Value(B))) &&
       // The "1" can be any value known to be a power of 2.
-      isPowerOfTwo(PowerOf2, IC.getTargetData())) {
+      isPowerOfTwo(PowerOf2, IC.getDataLayout())) {
     A = IC.Builder->CreateSub(A, B);
     return IC.Builder->CreateShl(PowerOf2, A);
   }
@@ -46,7 +46,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
   // inexact.  Similarly for <<.
   if (BinaryOperator *I = dyn_cast<BinaryOperator>(V))
     if (I->isLogicalShift() &&
-        isPowerOfTwo(I->getOperand(0), IC.getTargetData())) {
+        isPowerOfTwo(I->getOperand(0), IC.getDataLayout())) {
       // We know that this is an exact/nuw shift and that the input is a
       // non-zero context as well.
       if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) {
@@ -252,6 +252,46 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
   return Changed ? &I : 0;
 }
 
+//
+// Detect pattern:
+//
+// log2(Y*0.5)
+//
+// And check for corresponding fast math flags
+//
+
+static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) {
+
+   if (!Op->hasOneUse())
+     return;
+
+   IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op);
+   if (!II)
+     return;
+   if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra())
+     return;
+   Log2 = II;
+
+   Value *OpLog2Of = II->getArgOperand(0);
+   if (!OpLog2Of->hasOneUse())
+     return;
+
+   Instruction *I = dyn_cast<Instruction>(OpLog2Of);
+   if (!I)
+     return;
+   if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
+     return;
+              
+   ConstantFP *CFP = dyn_cast<ConstantFP>(I->getOperand(0));
+   if (CFP && CFP->isExactlyValue(0.5)) {
+     Y = I->getOperand(1);
+     return;
+   }
+   CFP = dyn_cast<ConstantFP>(I->getOperand(1));
+   if (CFP && CFP->isExactlyValue(0.5))
+     Y = I->getOperand(0);
+} 
+
 Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -284,6 +324,33 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
     if (Value *Op1v = dyn_castFNegVal(Op1))
       return BinaryOperator::CreateFMul(Op0v, Op1v);
 
+  // Under unsafe algebra do:
+  // X * log2(0.5*Y) = X*log2(Y) - X
+  if (I.hasUnsafeAlgebra()) {
+    Value *OpX = NULL;
+    Value *OpY = NULL;
+    IntrinsicInst *Log2;
+    detectLog2OfHalf(Op0, OpY, Log2);
+    if (OpY) {
+      OpX = Op1;
+    } else {
+      detectLog2OfHalf(Op1, OpY, Log2);
+      if (OpY) {
+        OpX = Op0;
+      }
+    }
+    // if pattern detected emit alternate sequence
+    if (OpX && OpY) {
+      Log2->setArgOperand(0, OpY);
+      Value *FMulVal = Builder->CreateFMul(OpX, Log2);
+      Instruction *FMul = cast<Instruction>(FMulVal);
+      FMul->copyFastMathFlags(Log2);
+      Instruction *FSub = BinaryOperator::CreateFSub(FMulVal, OpX);
+      FSub->copyFastMathFlags(Log2);
+      return FSub;
+    }
+  }
+
   return Changed ? &I : 0;
 }
 
@@ -462,13 +529,13 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
     }
   }
 
-  // Udiv ((Lshl x, c1) , c2) ->  x / (C1 * 1<<C2);
-  if (Constant *C = dyn_cast<Constant>(Op1)) {
-    Value *X = 0, *C1 = 0;
-    if (match(Op0, m_LShr(m_Value(X), m_Value(C1)))) {
-      uint64_t NC = cast<ConstantInt>(C)->getZExtValue() *
-                    (1<< cast<ConstantInt>(C1)->getZExtValue());
-      return BinaryOperator::CreateUDiv(X, ConstantInt::get(I.getType(), NC));
+  // (x lshr C1) udiv C2 --> x udiv (C2 << C1)
+  if (ConstantInt *C2 = dyn_cast<ConstantInt>(Op1)) {
+    Value *X;
+    ConstantInt *C1;
+    if (match(Op0, m_LShr(m_Value(X), m_ConstantInt(C1)))) {
+      APInt NC = C2->getValue().shl(C1->getLimitedValue(C1->getBitWidth()-1));
+      return BinaryOperator::CreateUDiv(X, Builder->getInt(NC));
     }
   }
 
@@ -477,7 +544,8 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
     if (match(Op1, m_Shl(m_Power2(CI), m_Value(N))) ||
         match(Op1, m_ZExt(m_Shl(m_Power2(CI), m_Value(N))))) {
       if (*CI != 1)
-        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(),CI->logBase2()));
+        N = Builder->CreateAdd(N,
+                               ConstantInt::get(N->getType(), CI->logBase2()));
       if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1))
         N = Builder->CreateZExt(N, Z->getDestTy());
       if (I.isExact())
@@ -543,16 +611,6 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
                                           ConstantExpr::getNeg(RHS));
   }
 
-  // Sdiv ((Ashl x, c1) , c2) ->  x / (C1 * 1<<C2);
-  if (Constant *C = dyn_cast<Constant>(Op1)) {
-    Value *X = 0, *C1 = 0;
-    if (match(Op0, m_AShr(m_Value(X), m_Value(C1)))) {
-      uint64_t NC = cast<ConstantInt>(C)->getZExtValue() *
-                    (1<< cast<ConstantInt>(C1)->getZExtValue());
-      return BinaryOperator::CreateSDiv(X, ConstantInt::get(I.getType(), NC));
-    }
-  }
-
   // If the sign bits of both operands are zero (i.e. we can prove they are
   // unsigned inputs), turn this into a udiv.
   if (I.getType()->isIntegerTy()) {