Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineMulDivRem.cpp
index 668c34fc06c6158a1245d30cd16581f66f69f5eb..5cd611c4200adb819f93b491a18fd4fe28a47d2a 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IntrinsicInst.h"
 #include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
 
-/// SubOne - Subtract one from a ConstantInt.
-static Constant *SubOne(ConstantInt *C) {
-  return ConstantInt::get(C->getContext(), C->getValue()-1);
+
+/// simplifyValueKnownNonZero - The specific integer value is used in a context
+/// where it is known to be non-zero.  If this allows us to simplify the
+/// computation, do so and return the new operand, otherwise return null.
+static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) {
+  // If V has multiple uses, then we would have to do more analysis to determine
+  // if this is safe.  For example, the use could be in dynamically unreached
+  // code.
+  if (!V->hasOneUse()) return 0;
+  
+  bool MadeChange = false;
+
+  // ((1 << A) >>u B) --> (1 << (A-B))
+  // Because V cannot be zero, we know that B is less than A.
+  Value *A = 0, *B = 0, *PowerOf2 = 0;
+  if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))),
+                      m_Value(B))) &&
+      // The "1" can be any value known to be a power of 2.
+      isPowerOfTwo(PowerOf2, IC.getDataLayout())) {
+    A = IC.Builder->CreateSub(A, B);
+    return IC.Builder->CreateShl(PowerOf2, A);
+  }
+  
+  // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it
+  // inexact.  Similarly for <<.
+  if (BinaryOperator *I = dyn_cast<BinaryOperator>(V))
+    if (I->isLogicalShift() &&
+        isPowerOfTwo(I->getOperand(0), IC.getDataLayout())) {
+      // We know that this is an exact/nuw shift and that the input is a
+      // non-zero context as well.
+      if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) {
+        I->setOperand(0, V2);
+        MadeChange = true;
+      }
+      
+      if (I->getOpcode() == Instruction::LShr && !I->isExact()) {
+        I->setIsExact();
+        MadeChange = true;
+      }
+      
+      if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) {
+        I->setHasNoUnsignedWrap();
+        MadeChange = true;
+      }
+    }
+
+  // TODO: Lots more we could do here:
+  //    If V is a phi node, we can call this on each of its operands.
+  //    "select cond, X, 0" can simplify to "X".
+  
+  return MadeChange ? V : 0;
 }
 
+
 /// MultiplyOverflows - True if the multiply can not be expressed in an int
 /// this size.
 static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) {
   uint32_t W = C1->getBitWidth();
   APInt LHSExt = C1->getValue(), RHSExt = C2->getValue();
   if (sign) {
-    LHSExt.sext(W * 2);
-    RHSExt.sext(W * 2);
+    LHSExt = LHSExt.sext(W * 2);
+    RHSExt = RHSExt.sext(W * 2);
   } else {
-    LHSExt.zext(W * 2);
-    RHSExt.zext(W * 2);
+    LHSExt = LHSExt.zext(W * 2);
+    RHSExt = RHSExt.zext(W * 2);
   }
   
   APInt MulExt = LHSExt * RHSExt;
@@ -47,62 +97,71 @@ static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) {
 }
 
 Instruction *InstCombiner::visitMul(BinaryOperator &I) {
-  bool Changed = SimplifyCommutative(I);
+  bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (isa<UndefValue>(Op1))              // undef * X -> 0
-    return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+  if (Value *V = SimplifyMulInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
 
-  // Simplify mul instructions with a constant RHS.
-  if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1C)) {
-
-      // ((X << C1)*C2) == (X * (C2 << C1))
-      if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op0))
-        if (SI->getOpcode() == Instruction::Shl)
-          if (Constant *ShOp = dyn_cast<Constant>(SI->getOperand(1)))
-            return BinaryOperator::CreateMul(SI->getOperand(0),
-                                        ConstantExpr::getShl(CI, ShOp));
-
-      if (CI->isZero())
-        return ReplaceInstUsesWith(I, Op1C);  // X * 0  == 0
-      if (CI->equalsInt(1))                  // X * 1  == X
-        return ReplaceInstUsesWith(I, Op0);
-      if (CI->isAllOnesValue())              // X * -1 == 0 - X
-        return BinaryOperator::CreateNeg(Op0, I.getName());
-
-      const APInt& Val = cast<ConstantInt>(CI)->getValue();
-      if (Val.isPowerOf2()) {          // Replace X*(2^C) with X << C
-        return BinaryOperator::CreateShl(Op0,
-                 ConstantInt::get(Op0->getType(), Val.logBase2()));
-      }
-    } else if (isa<VectorType>(Op1C->getType())) {
-      if (Op1C->isNullValue())
-        return ReplaceInstUsesWith(I, Op1C);
-
-      if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1C)) {
-        if (Op1V->isAllOnesValue())              // X * -1 == 0 - X
-          return BinaryOperator::CreateNeg(Op0, I.getName());
-
-        // As above, vector X*splat(1.0) -> X in all defined cases.
-        if (Constant *Splat = Op1V->getSplatValue()) {
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(Splat))
-            if (CI->equalsInt(1))
-              return ReplaceInstUsesWith(I, Op0);
-        }
-      }
+  if (Value *V = SimplifyUsingDistributiveLaws(I))
+    return ReplaceInstUsesWith(I, V);
+
+  if (match(Op1, m_AllOnes()))  // X * -1 == 0 - X
+    return BinaryOperator::CreateNeg(Op0, I.getName());
+  
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+    
+    // ((X << C1)*C2) == (X * (C2 << C1))
+    if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op0))
+      if (SI->getOpcode() == Instruction::Shl)
+        if (Constant *ShOp = dyn_cast<Constant>(SI->getOperand(1)))
+          return BinaryOperator::CreateMul(SI->getOperand(0),
+                                           ConstantExpr::getShl(CI, ShOp));
+    
+    const APInt &Val = CI->getValue();
+    if (Val.isPowerOf2()) {          // Replace X*(2^C) with X << C
+      Constant *NewCst = ConstantInt::get(Op0->getType(), Val.logBase2());
+      BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, NewCst);
+      if (I.hasNoSignedWrap()) Shl->setHasNoSignedWrap();
+      if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap();
+      return Shl;
     }
     
-    if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0))
-      if (Op0I->getOpcode() == Instruction::Add && Op0I->hasOneUse() &&
-          isa<ConstantInt>(Op0I->getOperand(1)) && isa<ConstantInt>(Op1C)) {
-        // Canonicalize (X+C1)*C2 -> X*C2+C1*C2.
-        Value *Add = Builder->CreateMul(Op0I->getOperand(0), Op1C, "tmp");
-        Value *C1C2 = Builder->CreateMul(Op1C, Op0I->getOperand(1));
-        return BinaryOperator::CreateAdd(Add, C1C2);
-        
+    // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
+    { Value *X; ConstantInt *C1;
+      if (Op0->hasOneUse() &&
+          match(Op0, m_Add(m_Value(X), m_ConstantInt(C1)))) {
+        Value *Add = Builder->CreateMul(X, CI);
+        return BinaryOperator::CreateAdd(Add, Builder->CreateMul(C1, CI));
       }
+    }
 
+    // (Y - X) * (-(2**n)) -> (X - Y) * (2**n), for positive nonzero n
+    // (Y + const) * (-(2**n)) -> (-constY) * (2**n), for positive nonzero n
+    // The "* (2**n)" thus becomes a potential shifting opportunity.
+    {
+      const APInt &   Val = CI->getValue();
+      const APInt &PosVal = Val.abs();
+      if (Val.isNegative() && PosVal.isPowerOf2()) {
+        Value *X = 0, *Y = 0;
+        if (Op0->hasOneUse()) {
+          ConstantInt *C1;
+          Value *Sub = 0;
+          if (match(Op0, m_Sub(m_Value(Y), m_Value(X))))
+            Sub = Builder->CreateSub(X, Y, "suba");
+          else if (match(Op0, m_Add(m_Value(Y), m_ConstantInt(C1))))
+            Sub = Builder->CreateSub(Builder->CreateNeg(C1), Y, "subc");
+          if (Sub)
+            return
+              BinaryOperator::CreateMul(Sub,
+                                        ConstantInt::get(Y->getType(), PosVal));
+        }
+      }
+    }
+  }
+  
+  // Simplify mul instructions with a constant RHS.
+  if (isa<Constant>(Op1)) {    
     // Try to fold constant mul into select arguments.
     if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
       if (Instruction *R = FoldOpIntoSelect(I, SI))
@@ -135,8 +194,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
          BO->getOpcode() == Instruction::SDiv)) {
       Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1);
 
-      // If the division is exact, X % Y is zero.
-      if (SDivOperator *SDiv = dyn_cast<SDivOperator>(BO))
+      // If the division is exact, X % Y is zero, so we end up with X or -X.
+      if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO))
         if (SDiv->isExact()) {
           if (Op1BO == Op1C)
             return ReplaceInstUsesWith(I, Op0BO);
@@ -173,7 +232,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
   // If one of the operands of the multiply is a cast from a boolean value, then
   // we know the bool is either zero or one, so this is a 'masking' multiply.
   //   X * Y (where Y is 0 or 1) -> X & (0-Y)
-  if (!isa<VectorType>(I.getType())) {
+  if (!I.getType()->isVectorTy()) {
     // -2 is "-1 << 1" so it is all bits set except the low one.
     APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true);
     
@@ -185,7 +244,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
 
     if (BoolCast) {
       Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()),
-                                    BoolCast, "tmp");
+                                    BoolCast);
       return BinaryOperator::CreateAnd(V, OtherOp);
     }
   }
@@ -193,26 +252,62 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
   return Changed ? &I : 0;
 }
 
+//
+// Detect pattern:
+//
+// log2(Y*0.5)
+//
+// And check for corresponding fast math flags
+//
+
+static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) {
+
+   if (!Op->hasOneUse())
+     return;
+
+   IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op);
+   if (!II)
+     return;
+   if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra())
+     return;
+   Log2 = II;
+
+   Value *OpLog2Of = II->getArgOperand(0);
+   if (!OpLog2Of->hasOneUse())
+     return;
+
+   Instruction *I = dyn_cast<Instruction>(OpLog2Of);
+   if (!I)
+     return;
+   if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra())
+     return;
+              
+   ConstantFP *CFP = dyn_cast<ConstantFP>(I->getOperand(0));
+   if (CFP && CFP->isExactlyValue(0.5)) {
+     Y = I->getOperand(1);
+     return;
+   }
+   CFP = dyn_cast<ConstantFP>(I->getOperand(1));
+   if (CFP && CFP->isExactlyValue(0.5))
+     Y = I->getOperand(0);
+} 
+
 Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
-  bool Changed = SimplifyCommutative(I);
+  bool Changed = SimplifyAssociativeOrCommutative(I);
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  // Simplify mul instructions with a constant RHS...
+  // Simplify mul instructions with a constant RHS.
   if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
     if (ConstantFP *Op1F = dyn_cast<ConstantFP>(Op1C)) {
       // "In IEEE floating point, x*1 is not equivalent to x for nans.  However,
       // ANSI says we can drop signals, so we can do this anyway." (from GCC)
       if (Op1F->isExactlyValue(1.0))
-        return ReplaceInstUsesWith(I, Op0);  // Eliminate 'mul double %X, 1.0'
-    } else if (isa<VectorType>(Op1C->getType())) {
-      if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1C)) {
-        // As above, vector X*splat(1.0) -> X in all defined cases.
-        if (Constant *Splat = Op1V->getSplatValue()) {
-          if (ConstantFP *F = dyn_cast<ConstantFP>(Splat))
-            if (F->isExactlyValue(1.0))
-              return ReplaceInstUsesWith(I, Op0);
-        }
-      }
+        return ReplaceInstUsesWith(I, Op0);  // Eliminate 'fmul double %X, 1.0'
+    } else if (ConstantDataVector *Op1V = dyn_cast<ConstantDataVector>(Op1C)) {
+      // As above, vector X*splat(1.0) -> X in all defined cases.
+      if (ConstantFP *F = dyn_cast_or_null<ConstantFP>(Op1V->getSplatValue()))
+        if (F->isExactlyValue(1.0))
+          return ReplaceInstUsesWith(I, Op0);
     }
 
     // Try to fold constant mul into select arguments.
@@ -229,6 +324,33 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
     if (Value *Op1v = dyn_castFNegVal(Op1))
       return BinaryOperator::CreateFMul(Op0v, Op1v);
 
+  // Under unsafe algebra do:
+  // X * log2(0.5*Y) = X*log2(Y) - X
+  if (I.hasUnsafeAlgebra()) {
+    Value *OpX = NULL;
+    Value *OpY = NULL;
+    IntrinsicInst *Log2;
+    detectLog2OfHalf(Op0, OpY, Log2);
+    if (OpY) {
+      OpX = Op1;
+    } else {
+      detectLog2OfHalf(Op1, OpY, Log2);
+      if (OpY) {
+        OpX = Op0;
+      }
+    }
+    // if pattern detected emit alternate sequence
+    if (OpX && OpY) {
+      Log2->setArgOperand(0, OpY);
+      Value *FMulVal = Builder->CreateFMul(OpX, Log2);
+      Instruction *FMul = cast<Instruction>(FMulVal);
+      FMul->copyFastMathFlags(Log2);
+      Instruction *FSub = BinaryOperator::CreateFSub(FMulVal, OpX);
+      FSub->copyFastMathFlags(Log2);
+      return FSub;
+    }
+  }
+
   return Changed ? &I : 0;
 }
 
@@ -304,28 +426,6 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) {
 }
 
 
-/// This function implements the transforms on div instructions that work
-/// regardless of the kind of div instruction it is (udiv, sdiv, or fdiv). It is
-/// used by the visitors to those instructions.
-/// @brief Transforms common to all three div instructions
-Instruction *InstCombiner::commonDivTransforms(BinaryOperator &I) {
-  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
-  // undef / X -> 0        for integer.
-  // undef / X -> undef    for FP (the undef could be a snan).
-  if (isa<UndefValue>(Op0)) {
-    if (Op0->getType()->isFPOrFPVectorTy())
-      return ReplaceInstUsesWith(I, Op0);
-    return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-  }
-
-  // X / undef -> undef
-  if (isa<UndefValue>(Op1))
-    return ReplaceInstUsesWith(I, Op1);
-
-  return 0;
-}
-
 /// This function implements the transforms common to both integer division
 /// instructions (udiv and sdiv). It is called by the visitors to those integer
 /// division instructions.
@@ -333,31 +433,18 @@ Instruction *InstCombiner::commonDivTransforms(BinaryOperator &I) {
 Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  // (sdiv X, X) --> 1     (udiv X, X) --> 1
-  if (Op0 == Op1) {
-    if (const VectorType *Ty = dyn_cast<VectorType>(I.getType())) {
-      Constant *CI = ConstantInt::get(Ty->getElementType(), 1);
-      std::vector<Constant*> Elts(Ty->getNumElements(), CI);
-      return ReplaceInstUsesWith(I, ConstantVector::get(Elts));
-    }
-
-    Constant *CI = ConstantInt::get(I.getType(), 1);
-    return ReplaceInstUsesWith(I, CI);
+  // The RHS is known non-zero.
+  if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) {
+    I.setOperand(1, V);
+    return &I;
   }
   
-  if (Instruction *Common = commonDivTransforms(I))
-    return Common;
-  
   // Handle cases involving: [su]div X, (select Cond, Y, Z)
   // This does not apply for fdiv.
   if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
     return &I;
 
   if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
-    // div X, 1 == X
-    if (RHS->equalsInt(1))
-      return ReplaceInstUsesWith(I, Op0);
-
     // (X / C1) / C2  -> X / (C1*C2)
     if (Instruction *LHS = dyn_cast<Instruction>(Op0))
       if (Instruction::BinaryOps(LHS->getOpcode()) == I.getOpcode())
@@ -365,9 +452,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
           if (MultiplyOverflows(RHS, LHSRHS,
                                 I.getOpcode()==Instruction::SDiv))
             return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-          else 
-            return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0),
-                                      ConstantExpr::getMul(RHS, LHSRHS));
+          return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0),
+                                        ConstantExpr::getMul(RHS, LHSRHS));
         }
 
     if (!RHS->isZero()) { // avoid X udiv 0
@@ -380,90 +466,127 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
     }
   }
 
-  // 0 / X == 0, we don't need to preserve faults!
-  if (ConstantInt *LHS = dyn_cast<ConstantInt>(Op0))
-    if (LHS->equalsInt(0))
-      return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-
-  // It can't be division by zero, hence it must be division by one.
-  if (I.getType()->isIntegerTy(1))
-    return ReplaceInstUsesWith(I, Op0);
+  // See if we can fold away this div instruction.
+  if (SimplifyDemandedInstructionBits(I))
+    return &I;
 
-  if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1)) {
-    if (ConstantInt *X = cast_or_null<ConstantInt>(Op1V->getSplatValue()))
-      // div X, 1 == X
-      if (X->isOne())
-        return ReplaceInstUsesWith(I, Op0);
+  // (X - (X rem Y)) / Y -> X / Y; usually originates as ((X / Y) * Y) / Y
+  Value *X = 0, *Z = 0;
+  if (match(Op0, m_Sub(m_Value(X), m_Value(Z)))) { // (X - Z) / Y; Y = Op1
+    bool isSigned = I.getOpcode() == Instruction::SDiv;
+    if ((isSigned && match(Z, m_SRem(m_Specific(X), m_Specific(Op1)))) ||
+        (!isSigned && match(Z, m_URem(m_Specific(X), m_Specific(Op1)))))
+      return BinaryOperator::Create(I.getOpcode(), X, Op1);
   }
 
   return 0;
 }
 
+/// dyn_castZExtVal - Checks if V is a zext or constant that can
+/// be truncated to Ty without losing bits.
+static Value *dyn_castZExtVal(Value *V, Type *Ty) {
+  if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) {
+    if (Z->getSrcTy() == Ty)
+      return Z->getOperand(0);
+  } else if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
+    if (C->getValue().getActiveBits() <= cast<IntegerType>(Ty)->getBitWidth())
+      return ConstantExpr::getTrunc(C, Ty);
+  }
+  return 0;
+}
+
 Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (Value *V = SimplifyUDivInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
+
   // Handle the integer div common cases
   if (Instruction *Common = commonIDivTransforms(I))
     return Common;
-
-  if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) {
+  
+  { 
     // X udiv 2^C -> X >> C
     // Check to see if this is an unsigned division with an exact power of 2,
     // if so, convert to a right shift.
-    if (C->getValue().isPowerOf2())  // 0 not included in isPowerOf2
-      return BinaryOperator::CreateLShr(Op0, 
-            ConstantInt::get(Op0->getType(), C->getValue().logBase2()));
+    const APInt *C;
+    if (match(Op1, m_Power2(C))) {
+      BinaryOperator *LShr =
+      BinaryOperator::CreateLShr(Op0, 
+                                 ConstantInt::get(Op0->getType(), 
+                                                  C->logBase2()));
+      if (I.isExact()) LShr->setIsExact();
+      return LShr;
+    }
+  }
 
+  if (ConstantInt *C = dyn_cast<ConstantInt>(Op1)) {
     // X udiv C, where C >= signbit
     if (C->getValue().isNegative()) {
-      Value *IC = Builder->CreateICmpULT( Op0, C);
+      Value *IC = Builder->CreateICmpULT(Op0, C);
       return SelectInst::Create(IC, Constant::getNullValue(I.getType()),
                                 ConstantInt::get(I.getType(), 1));
     }
   }
 
+  // (x lshr C1) udiv C2 --> x udiv (C2 << C1)
+  if (ConstantInt *C2 = dyn_cast<ConstantInt>(Op1)) {
+    Value *X;
+    ConstantInt *C1;
+    if (match(Op0, m_LShr(m_Value(X), m_ConstantInt(C1)))) {
+      APInt NC = C2->getValue().shl(C1->getLimitedValue(C1->getBitWidth()-1));
+      return BinaryOperator::CreateUDiv(X, Builder->getInt(NC));
+    }
+  }
+
   // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
-  if (BinaryOperator *RHSI = dyn_cast<BinaryOperator>(I.getOperand(1))) {
-    if (RHSI->getOpcode() == Instruction::Shl &&
-        isa<ConstantInt>(RHSI->getOperand(0))) {
-      const APInt& C1 = cast<ConstantInt>(RHSI->getOperand(0))->getValue();
-      if (C1.isPowerOf2()) {
-        Value *N = RHSI->getOperand(1);
-        const Type *NTy = N->getType();
-        if (uint32_t C2 = C1.logBase2())
-          N = Builder->CreateAdd(N, ConstantInt::get(NTy, C2), "tmp");
-        return BinaryOperator::CreateLShr(Op0, N);
-      }
+  { const APInt *CI; Value *N;
+    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N))) ||
+        match(Op1, m_ZExt(m_Shl(m_Power2(CI), m_Value(N))))) {
+      if (*CI != 1)
+        N = Builder->CreateAdd(N,
+                               ConstantInt::get(N->getType(), CI->logBase2()));
+      if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1))
+        N = Builder->CreateZExt(N, Z->getDestTy());
+      if (I.isExact())
+        return BinaryOperator::CreateExactLShr(Op0, N);
+      return BinaryOperator::CreateLShr(Op0, N);
     }
   }
   
   // udiv X, (Select Cond, C1, C2) --> Select Cond, (shr X, C1), (shr X, C2)
   // where C1&C2 are powers of two.
-  if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) 
-    if (ConstantInt *STO = dyn_cast<ConstantInt>(SI->getOperand(1)))
-      if (ConstantInt *SFO = dyn_cast<ConstantInt>(SI->getOperand(2)))  {
-        const APInt &TVA = STO->getValue(), &FVA = SFO->getValue();
-        if (TVA.isPowerOf2() && FVA.isPowerOf2()) {
-          // Compute the shift amounts
-          uint32_t TSA = TVA.logBase2(), FSA = FVA.logBase2();
-          // Construct the "on true" case of the select
-          Constant *TC = ConstantInt::get(Op0->getType(), TSA);
-          Value *TSI = Builder->CreateLShr(Op0, TC, SI->getName()+".t");
+  { Value *Cond; const APInt *C1, *C2;
+    if (match(Op1, m_Select(m_Value(Cond), m_Power2(C1), m_Power2(C2)))) {
+      // Construct the "on true" case of the select
+      Value *TSI = Builder->CreateLShr(Op0, C1->logBase2(), Op1->getName()+".t",
+                                       I.isExact());
   
-          // Construct the "on false" case of the select
-          Constant *FC = ConstantInt::get(Op0->getType(), FSA); 
-          Value *FSI = Builder->CreateLShr(Op0, FC, SI->getName()+".f");
+      // Construct the "on false" case of the select
+      Value *FSI = Builder->CreateLShr(Op0, C2->logBase2(), Op1->getName()+".f",
+                                       I.isExact());
+      
+      // construct the select instruction and return it.
+      return SelectInst::Create(Cond, TSI, FSI);
+    }
+  }
+
+  // (zext A) udiv (zext B) --> zext (A udiv B)
+  if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0))
+    if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy()))
+      return new ZExtInst(Builder->CreateUDiv(ZOp0->getOperand(0), ZOp1, "div",
+                                              I.isExact()),
+                          I.getType());
 
-          // construct the select instruction and return it.
-          return SelectInst::Create(SI->getOperand(0), TSI, FSI, SI->getName());
-        }
-      }
   return 0;
 }
 
 Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (Value *V = SimplifySDivInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
+
   // Handle the integer div common cases
   if (Instruction *Common = commonIDivTransforms(I))
     return Common;
@@ -473,20 +596,17 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
     if (RHS->isAllOnesValue())
       return BinaryOperator::CreateNeg(Op0);
 
-    // sdiv X, C  -->  ashr X, log2(C)
-    if (cast<SDivOperator>(&I)->isExact() &&
-        RHS->getValue().isNonNegative() &&
+    // sdiv X, C  -->  ashr exact X, log2(C)
+    if (I.isExact() && RHS->getValue().isNonNegative() &&
         RHS->getValue().isPowerOf2()) {
       Value *ShAmt = llvm::ConstantInt::get(RHS->getType(),
                                             RHS->getValue().exactLogBase2());
-      return BinaryOperator::CreateAShr(Op0, ShAmt, I.getName());
+      return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName());
     }
 
     // -X/C  -->  X/-C  provided the negation doesn't overflow.
     if (SubOperator *Sub = dyn_cast<SubOperator>(Op0))
-      if (isa<Constant>(Sub->getOperand(0)) &&
-          cast<Constant>(Sub->getOperand(0))->isNullValue() &&
-          Sub->hasNoSignedWrap())
+      if (match(Sub->getOperand(0), m_Zero()) && Sub->hasNoSignedWrap())
         return BinaryOperator::CreateSDiv(Sub->getOperand(1),
                                           ConstantExpr::getNeg(RHS));
   }
@@ -500,9 +620,8 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
         // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
         return BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
       }
-      ConstantInt *ShiftedInt;
-      if (match(Op1, m_Shl(m_ConstantInt(ShiftedInt), m_Value())) &&
-          ShiftedInt->getValue().isPowerOf2()) {
+      
+      if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
         // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
         // Safe because the only negative value (1 << Y) can take on is
         // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have
@@ -516,27 +635,22 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
 }
 
 Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
-  return commonDivTransforms(I);
-}
-
-/// This function implements the transforms on rem instructions that work
-/// regardless of the kind of rem instruction it is (urem, srem, or frem). It 
-/// is used by the visitors to those instructions.
-/// @brief Transforms common to all three rem instructions
-Instruction *InstCombiner::commonRemTransforms(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (isa<UndefValue>(Op0)) {             // undef % X -> 0
-    if (I.getType()->isFPOrFPVectorTy())
-      return ReplaceInstUsesWith(I, Op0);  // X % undef -> undef (could be SNaN)
-    return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-  }
-  if (isa<UndefValue>(Op1))
-    return ReplaceInstUsesWith(I, Op1);  // X % undef -> undef
+  if (Value *V = SimplifyFDivInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
 
-  // Handle cases involving: rem X, (select Cond, Y, Z)
-  if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
-    return &I;
+  if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) {
+    const APFloat &Op1F = Op1C->getValueAPF();
+
+    // If the divisor has an exact multiplicative inverse we can turn the fdiv
+    // into a cheaper fmul.
+    APFloat Reciprocal(Op1F.getSemantics());
+    if (Op1F.getExactInverse(&Reciprocal)) {
+      ConstantFP *RFP = ConstantFP::get(Builder->getContext(), Reciprocal);
+      return BinaryOperator::CreateFMul(Op0, RFP);
+    }
+  }
 
   return 0;
 }
@@ -548,22 +662,17 @@ Instruction *InstCombiner::commonRemTransforms(BinaryOperator &I) {
 Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (Instruction *common = commonRemTransforms(I))
-    return common;
-
-  // 0 % X == 0 for integer, we don't need to preserve faults!
-  if (Constant *LHS = dyn_cast<Constant>(Op0))
-    if (LHS->isNullValue())
-      return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+  // The RHS is known non-zero.
+  if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this)) {
+    I.setOperand(1, V);
+    return &I;
+  }
 
-  if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
-    // X % 0 == undef, we don't need to preserve faults!
-    if (RHS->equalsInt(0))
-      return ReplaceInstUsesWith(I, UndefValue::get(I.getType()));
-    
-    if (RHS->equalsInt(1))  // X % 1 == 0
-      return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+  // Handle cases involving: rem X, (select Cond, Y, Z)
+  if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
+    return &I;
 
+  if (isa<ConstantInt>(Op1)) {
     if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
       if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
         if (Instruction *R = FoldOpIntoSelect(I, SI))
@@ -585,53 +694,52 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
 Instruction *InstCombiner::visitURem(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (Value *V = SimplifyURemInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
+
   if (Instruction *common = commonIRemTransforms(I))
     return common;
   
-  if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
-    // X urem C^2 -> X and C
-    // Check to see if this is an unsigned remainder with an exact power of 2,
-    // if so, convert to a bitwise and.
-    if (ConstantInt *C = dyn_cast<ConstantInt>(RHS))
-      if (C->getValue().isPowerOf2())
-        return BinaryOperator::CreateAnd(Op0, SubOne(C));
+  // X urem C^2 -> X and C-1
+  { const APInt *C;
+    if (match(Op1, m_Power2(C)))
+      return BinaryOperator::CreateAnd(Op0,
+                                       ConstantInt::get(I.getType(), *C-1));
   }
 
-  if (Instruction *RHSI = dyn_cast<Instruction>(I.getOperand(1))) {
-    // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)  
-    if (RHSI->getOpcode() == Instruction::Shl &&
-        isa<ConstantInt>(RHSI->getOperand(0))) {
-      if (cast<ConstantInt>(RHSI->getOperand(0))->getValue().isPowerOf2()) {
-        Constant *N1 = Constant::getAllOnesValue(I.getType());
-        Value *Add = Builder->CreateAdd(RHSI, N1, "tmp");
-        return BinaryOperator::CreateAnd(Op0, Add);
-      }
-    }
+  // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)  
+  if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
+    Constant *N1 = Constant::getAllOnesValue(I.getType());
+    Value *Add = Builder->CreateAdd(Op1, N1);
+    return BinaryOperator::CreateAnd(Op0, Add);
   }
 
-  // urem X, (select Cond, 2^C1, 2^C2) --> select Cond, (and X, C1), (and X, C2)
-  // where C1&C2 are powers of two.
-  if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) {
-    if (ConstantInt *STO = dyn_cast<ConstantInt>(SI->getOperand(1)))
-      if (ConstantInt *SFO = dyn_cast<ConstantInt>(SI->getOperand(2))) {
-        // STO == 0 and SFO == 0 handled above.
-        if ((STO->getValue().isPowerOf2()) && 
-            (SFO->getValue().isPowerOf2())) {
-          Value *TrueAnd = Builder->CreateAnd(Op0, SubOne(STO),
-                                              SI->getName()+".t");
-          Value *FalseAnd = Builder->CreateAnd(Op0, SubOne(SFO),
-                                               SI->getName()+".f");
-          return SelectInst::Create(SI->getOperand(0), TrueAnd, FalseAnd);
-        }
-      }
+  // urem X, (select Cond, 2^C1, 2^C2) -->
+  //    select Cond, (and X, C1-1), (and X, C2-1)
+  // when C1&C2 are powers of two.
+  { Value *Cond; const APInt *C1, *C2;
+    if (match(Op1, m_Select(m_Value(Cond), m_Power2(C1), m_Power2(C2)))) {
+      Value *TrueAnd = Builder->CreateAnd(Op0, *C1-1, Op1->getName()+".t");
+      Value *FalseAnd = Builder->CreateAnd(Op0, *C2-1, Op1->getName()+".f");
+      return SelectInst::Create(Cond, TrueAnd, FalseAnd);
+    }
   }
-  
+
+  // (zext A) urem (zext B) --> zext (A urem B)
+  if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0))
+    if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy()))
+      return new ZExtInst(Builder->CreateURem(ZOp0->getOperand(0), ZOp1),
+                          I.getType());
+
   return 0;
 }
 
 Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (Value *V = SimplifySRemInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
+
   // Handle the integer rem common cases
   if (Instruction *Common = commonIRemTransforms(I))
     return Common;
@@ -657,28 +765,36 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
   }
 
   // If it's a constant vector, flip any negative values positive.
-  if (ConstantVector *RHSV = dyn_cast<ConstantVector>(Op1)) {
-    unsigned VWidth = RHSV->getNumOperands();
+  if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
+    Constant *C = cast<Constant>(Op1);
+    unsigned VWidth = C->getType()->getVectorNumElements();
 
     bool hasNegative = false;
-    for (unsigned i = 0; !hasNegative && i != VWidth; ++i)
-      if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i)))
-        if (RHS->getValue().isNegative())
+    bool hasMissing = false;
+    for (unsigned i = 0; i != VWidth; ++i) {
+      Constant *Elt = C->getAggregateElement(i);
+      if (Elt == 0) {
+        hasMissing = true;
+        break;
+      }
+
+      if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elt))
+        if (RHS->isNegative())
           hasNegative = true;
+    }
 
-    if (hasNegative) {
-      std::vector<Constant *> Elts(VWidth);
+    if (hasNegative && !hasMissing) {
+      SmallVector<Constant *, 16> Elts(VWidth);
       for (unsigned i = 0; i != VWidth; ++i) {
-        if (ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV->getOperand(i))) {
-          if (RHS->getValue().isNegative())
+        Elts[i] = C->getAggregateElement(i);  // Handle undef, etc.
+        if (ConstantInt *RHS = dyn_cast<ConstantInt>(Elts[i])) {
+          if (RHS->isNegative())
             Elts[i] = cast<ConstantInt>(ConstantExpr::getNeg(RHS));
-          else
-            Elts[i] = RHS;
         }
       }
 
       Constant *NewRHSV = ConstantVector::get(Elts);
-      if (NewRHSV != RHSV) {
+      if (NewRHSV != C) {  // Don't loop on -MININT
         Worklist.AddValue(I.getOperand(1));
         I.setOperand(1, NewRHSV);
         return &I;
@@ -690,6 +806,14 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
 }
 
 Instruction *InstCombiner::visitFRem(BinaryOperator &I) {
-  return commonRemTransforms(I);
-}
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
+  if (Value *V = SimplifyFRemInst(Op0, Op1, TD))
+    return ReplaceInstUsesWith(I, V);
+
+  // Handle cases involving: rem X, (select Cond, Y, Z)
+  if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))
+    return &I;
+
+  return 0;
+}