Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index 2d29403097ce79384f76d37f092ab04bce233e49..5cd611c4200adb819f93b491a18fd4fe28a47d2a 100644 (file)
@@ -13,8 +13,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/IntrinsicInst.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/IntrinsicInst.h"
 #include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -37,8 +37,8 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
   if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))),
                       m_Value(B))) &&
       // The "1" can be any value known to be a power of 2.
-      isPowerOfTwo(PowerOf2, IC.getTargetData())) {
-    A = IC.Builder->CreateSub(A, B, "tmp");
+      isPowerOfTwo(PowerOf2, IC.getDataLayout())) {
+    A = IC.Builder->CreateSub(A, B);
     return IC.Builder->CreateShl(PowerOf2, A);
   }
   
@@ -46,7 +46,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
   // inexact.  Similarly for <<.
   if (BinaryOperator *I = dyn_cast<BinaryOperator>(V))
     if (I->isLogicalShift() &&
-        isPowerOfTwo(I->getOperand(0), IC.getTargetData())) {
+        isPowerOfTwo(I->getOperand(0), IC.getDataLayout())) {
       // We know that this is an exact/nuw shift and that the input is a
       // non-zero context as well.
       if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) {
@@ -131,7 +131,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
     { Value *X; ConstantInt *C1;
       if (Op0->hasOneUse() &&
           match(Op0, m_Add(m_Value(X), m_ConstantInt(C1)))) {
-        Value *Add = Builder->CreateMul(X, CI, "tmp");
+        Value *Add = Builder->CreateMul(X, CI);
         return BinaryOperator::CreateAdd(Add, Builder->CreateMul(C1, CI));
       }
     }
@@ -244,7 +244,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
 
     if (BoolCast) {
       Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()),
-                                    BoolCast, "tmp");
+                                    BoolCast);
       return BinaryOperator::CreateAnd(V, OtherOp);
     }
   }
@@ -252,26 +252,62 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
   return Changed ? &I : 0;
 }
 
+//
+// Detect pattern:
+//
+// log2(Y*0.5)
+//
+// And check for corresponding fast math flags
+//
+
+static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) {
+
+   if (!Op->hasOneUse())
+     return;
+
+   IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op);
+   if (!II)
+     return;
+   if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra())
+     return;
+   Log2 = II;
+
+   Value *OpLog2Of = II->getArgOperand(0);
+   if (!OpLog2Of->hasOneUse())
+     return;
+
+   Instruction *I = dyn_cast<Instruction>(OpLog2Of);
+   if (!I)
+     return;
+   if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
+     return;
+              
+   ConstantFP *CFP = dyn_cast<ConstantFP>(I->getOperand(0));
+   if (CFP && CFP->isExactlyValue(0.5)) {
+     Y = I->getOperand(1);
+     return;
+   }
+   CFP = dyn_cast<ConstantFP>(I->getOperand(1));
+   if (CFP && CFP->isExactlyValue(0.5))
+     Y = I->getOperand(0);
+} 
+
 Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  // Simplify mul instructions with a constant RHS...
+  // Simplify mul instructions with a constant RHS.
   if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
     if (ConstantFP *Op1F = dyn_cast<ConstantFP>(Op1C)) {
       // "In IEEE floating point, x*1 is not equivalent to x for nans.  However,
       // ANSI says we can drop signals, so we can do this anyway." (from GCC)
       if (Op1F->isExactlyValue(1.0))
         return ReplaceInstUsesWith(I, Op0);  // Eliminate 'fmul double %X, 1.0'
-    } else if (Op1C->getType()->isVectorTy()) {
-      if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1C)) {
-        // As above, vector X*splat(1.0) -> X in all defined cases.
-        if (Constant *Splat = Op1V->getSplatValue()) {
-          if (ConstantFP *F = dyn_cast<ConstantFP>(Splat))
-            if (F->isExactlyValue(1.0))
-              return ReplaceInstUsesWith(I, Op0);
-        }
-      }
+    } else if (ConstantDataVector *Op1V = dyn_cast<ConstantDataVector>(Op1C)) {
+      // As above, vector X*splat(1.0) -> X in all defined cases.
+      if (ConstantFP *F = dyn_cast_or_null<ConstantFP>(Op1V->getSplatValue()))
+        if (F->isExactlyValue(1.0))
+          return ReplaceInstUsesWith(I, Op0);
     }
 
     // Try to fold constant mul into select arguments.
@@ -288,6 +324,33 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
     if (Value *Op1v = dyn_castFNegVal(Op1))
       return BinaryOperator::CreateFMul(Op0v, Op1v);
 
+  // Under unsafe algebra do:
+  // X * log2(0.5*Y) = X*log2(Y) - X
+  if (I.hasUnsafeAlgebra()) {
+    Value *OpX = NULL;
+    Value *OpY = NULL;
+    IntrinsicInst *Log2;
+    detectLog2OfHalf(Op0, OpY, Log2);
+    if (OpY) {
+      OpX = Op1;
+    } else {
+      detectLog2OfHalf(Op1, OpY, Log2);
+      if (OpY) {
+        OpX = Op0;
+      }
+    }
+    // if pattern detected emit alternate sequence
+    if (OpX && OpY) {
+      Log2->setArgOperand(0, OpY);
+      Value *FMulVal = Builder->CreateFMul(OpX, Log2);
+      Instruction *FMul = cast<Instruction>(FMulVal);
+      FMul->copyFastMathFlags(Log2);
+      Instruction *FSub = BinaryOperator::CreateFSub(FMulVal, OpX);
+      FSub->copyFastMathFlags(Log2);
+      return FSub;
+    }
+  }
+
   return Changed ? &I : 0;
 }
 
@@ -421,7 +484,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
 
 /// dyn_castZExtVal - Checks if V is a zext or constant that can
 /// be truncated to Ty without losing bits.
-static Value *dyn_castZExtVal(Value *V, const Type *Ty) {
+static Value *dyn_castZExtVal(Value *V, Type *Ty) {
   if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) {
     if (Z->getSrcTy() == Ty)
       return Z->getOperand(0);
@@ -441,19 +504,23 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
   // Handle the integer div common cases
   if (Instruction *Common = commonIDivTransforms(I))
     return Common;
-
-  if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) {
+  
+  { 
     // X udiv 2^C -> X >> C
     // Check to see if this is an unsigned division with an exact power of 2,
     // if so, convert to a right shift.
-    if (C->getValue().isPowerOf2()) { // 0 not included in isPowerOf2
+    const APInt *C;
+    if (match(Op1, m_Power2(C))) {
       BinaryOperator *LShr =
-        BinaryOperator::CreateLShr(Op0, 
-            ConstantInt::get(Op0->getType(), C->getValue().logBase2()));
+      BinaryOperator::CreateLShr(Op0, 
+                                 ConstantInt::get(Op0->getType(), 
+                                                  C->logBase2()));
       if (I.isExact()) LShr->setIsExact();
       return LShr;
     }
+  }
 
+  if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) {
     // X udiv C, where C >= signbit
     if (C->getValue().isNegative()) {
       Value *IC = Builder->CreateICmpULT(Op0, C);
@@ -462,12 +529,25 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
     }
   }
 
+  // (x lshr C1) udiv C2 --> x udiv (C2 << C1)
+  if (ConstantInt *C2 = dyn_cast<ConstantInt>(Op1)) {
+    Value *X;
+    ConstantInt *C1;
+    if (match(Op0, m_LShr(m_Value(X), m_ConstantInt(C1)))) {
+      APInt NC = C2->getValue().shl(C1->getLimitedValue(C1->getBitWidth()-1));
+      return BinaryOperator::CreateUDiv(X, Builder->getInt(NC));
+    }
+  }
+
   // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
   { const APInt *CI; Value *N;
-    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N)))) {
+    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N))) ||
+        match(Op1, m_ZExt(m_Shl(m_Power2(CI), m_Value(N))))) {
       if (*CI != 1)
-        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(), CI->logBase2()),
-                               "tmp");
+        N = Builder->CreateAdd(N,
+                               ConstantInt::get(N->getType(), CI->logBase2()));
+      if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1))
+        N = Builder->CreateZExt(N, Z->getDestTy());
       if (I.isExact())
         return BinaryOperator::CreateExactLShr(Op0, N);
       return BinaryOperator::CreateLShr(Op0, N);
@@ -630,7 +710,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
   // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)  
   if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
     Constant *N1 = Constant::getAllOnesValue(I.getType());
-    Value *Add = Builder->CreateAdd(Op1, N1, "tmp");
+    Value *Add = Builder->CreateAdd(Op1, N1);
     return BinaryOperator::CreateAnd(Op0, Add);
   }
 
@@ -685,28 +765,36 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
   }
 
   // If it's a constant vector, flip any negative values positive.
-  if (ConstantVector *RHSV = dyn_cast<ConstantVector>(Op1)) {
-    unsigned VWidth = RHSV->getNumOperands();
+  if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
+    Constant *C = cast<Constant>(Op1);
+    unsigned VWidth = C->getType()->getVectorNumElements();
 
     bool hasNegative = false;
-    for (unsigned i = 0; !hasNegative && i != VWidth; ++i)
-      if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i)))
-        if (RHS->getValue().isNegative())
+    bool hasMissing = false;
+    for (unsigned i = 0; i != VWidth; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (Elt == 0) {
+        hasMissing = true;
+        break;
+      }
+
+      if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elt))
+        if (RHS->isNegative())
           hasNegative = true;
+    }
 
-    if (hasNegative) {
-      std::vector<Constant *> Elts(VWidth);
+    if (hasNegative && !hasMissing) {
+      SmallVector<Constant *, 16> Elts(VWidth);
       for (unsigned i = 0; i != VWidth; ++i) {
-        if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) {
-          if (RHS->getValue().isNegative())
+        Elts[i] = C->getAggregateElement(i);  // Handle undef, etc.
+        if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
+          if (RHS->isNegative())
             Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
-          else
-            Elts[i] = RHS;
         }
       }
 
       Constant *NewRHSV = ConstantVector::get(Elts);
-      if (NewRHSV != RHSV) {
+      if (NewRHSV != C) {  // Don't loop on -MININT
         Worklist.AddValue(I.getOperand(1));
         I.setOperand(1, NewRHSV);
         return &I;