Make use of @llvm.assume in ValueTracking (computeKnownBits, etc.)
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAddSub.cpp
index 600e09433cc4f303b66ad10c3054cbd75aabaf61..6287536b3e9ce81a1b97c480963e586e60c80f9d 100644 (file)
@@ -895,7 +895,8 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero,
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
 /// TODO: Handle this for Vectors.
-bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
   // There are different heuristics we can use for this.  Here are some simple
   // ones.
 
@@ -913,18 +914,19 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
+  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, CxtI) > 1)
     return true;
 
   if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
     int BitWidth = IT->getBitWidth();
     APInt LHSKnownZero(BitWidth, 0);
     APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
 
     APInt RHSKnownZero(BitWidth, 0);
     APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
 
     // Addition of two 2's compliment numbers having opposite signs will never
     // overflow.
@@ -943,13 +945,14 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
 
 /// WillNotOverflowUnsignedAdd - Return true if we can prove that:
 ///    (zext (add LHS, RHS))  === (add (zext LHS), (zext RHS))
-bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
+                                              Instruction *CxtI) {
   // There are different heuristics we can use for this. Here is a simple one.
   // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
   bool LHSKnownNonNegative, LHSKnownNegative;
   bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0);
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
   if (LHSKnownNonNegative && RHSKnownNonNegative)
     return true;
 
@@ -961,21 +964,23 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
 /// TODO: Handle this for Vectors.
-bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
   // If LHS and RHS each have at least two sign bits, the subtraction
   // cannot overflow.
-  if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
+  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, CxtI) > 1)
     return true;
 
   if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
     unsigned BitWidth = IT->getBitWidth();
     APInt LHSKnownZero(BitWidth, 0);
     APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
 
     APInt RHSKnownZero(BitWidth, 0);
     APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
 
     // Subtraction of two 2's compliment numbers having identical signs will
     // never overflow.
@@ -990,12 +995,13 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS) {
 
 /// \brief Return true if we can prove that:
 ///    (sub LHS, RHS)  === (sub nuw LHS, RHS)
-bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS,
+                                              Instruction *CxtI) {
   // If the LHS is negative and the RHS is non-negative, no unsigned wrap.
   bool LHSKnownNonNegative, LHSKnownNegative;
   bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0);
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
   if (LHSKnownNegative && RHSKnownNonNegative)
     return true;
 
@@ -1071,7 +1077,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
      return ReplaceInstUsesWith(I, V);
 
    if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
-                                  I.hasNoUnsignedWrap(), DL))
+                                  I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
      return ReplaceInstUsesWith(I, V);
 
    // (A*B)+(A*C) -> A*(B+C) etc
@@ -1110,7 +1116,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
 
       if (ExtendAmt) {
         APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
-        if (!MaskedValueIsZero(XorLHS, Mask))
+        if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
           ExtendAmt = 0;
       }
 
@@ -1126,7 +1132,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
         IntegerType *IT = cast<IntegerType>(I.getType());
         APInt LHSKnownOne(IT->getBitWidth(), 0);
         APInt LHSKnownZero(IT->getBitWidth(), 0);
-        computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne);
+        computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I);
         if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue())
           return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
                                            XorLHS);
@@ -1179,11 +1185,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
     APInt LHSKnownOne(IT->getBitWidth(), 0);
     APInt LHSKnownZero(IT->getBitWidth(), 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I);
     if (LHSKnownZero != 0) {
       APInt RHSKnownOne(IT->getBitWidth(), 0);
       APInt RHSKnownZero(IT->getBitWidth(), 0);
-      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I);
 
       // No bits in common -> bitwise or.
       if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
@@ -1261,7 +1267,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
         ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
         // Insert the new, smaller add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1277,7 +1283,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
       if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0))) {
+                                   RHSConv->getOperand(0), &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                              RHSConv->getOperand(0), "addconv");
@@ -1325,11 +1331,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   // TODO(jingyue): Consider WillNotOverflowSignedAdd and
   // WillNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS)) {
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS, &I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
@@ -1344,7 +1350,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL))
+  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL,
+                                  TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(RHS)) {
@@ -1386,7 +1393,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1402,7 +1409,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0))) {
+                                   RHSConv->getOperand(0), &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               RHSConv->getOperand(0),"addconv");
@@ -1523,7 +1530,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, V);
 
   if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), DL))
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A*B)-(A*C) -> A*(B-C) etc
@@ -1673,11 +1680,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
       }
 
   bool Changed = false;
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1)) {
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
@@ -1691,7 +1698,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL))
+  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL,
+                                  TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(Op0))