DataLayout is mandatory, update the API to reflect it with references.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAddSub.cpp
index 37ae797fbec06292ee7ec95b8e77d24d1fdd8000..c608f84bc7bb5ec015480e9dced46671daac7dd2 100644 (file)
@@ -11,7 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "InstCombine.h"
+#include "InstCombineInternal.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/DataLayout.h"
@@ -751,8 +751,7 @@ Value *FAddCombine::createNaryFAdd
   return LastVal;
 }
 
-Value *FAddCombine::createFSub
-  (Value *Opnd0, Value *Opnd1) {
+Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFSub(Opnd0, Opnd1);
   if (Instruction *I = dyn_cast<Instruction>(V))
     createInstPostProc(I);
@@ -760,15 +759,14 @@ Value *FAddCombine::createFSub
 }
 
 Value *FAddCombine::createFNeg(Value *V) {
-  Value *Zero = cast<Value>(ConstantFP::get(V->getType(), 0.0));
+  Value *Zero = cast<Value>(ConstantFP::getZeroValueForNegation(V->getType()));
   Value *NewV = createFSub(Zero, V);
   if (Instruction *I = dyn_cast<Instruction>(NewV))
     createInstPostProc(I, true); // fneg's don't receive instruction numbers.
   return NewV;
 }
 
-Value *FAddCombine::createFAdd
-  (Value *Opnd0, Value *Opnd1) {
+Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFAdd(Opnd0, Opnd1);
   if (Instruction *I = dyn_cast<Instruction>(V))
     createInstPostProc(I);
@@ -789,8 +787,7 @@ Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) {
   return V;
 }
 
-void FAddCombine::createInstPostProc(Instruction *NewInstr,
-                                     bool NoNumber) {
+void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) {
   NewInstr->setDebugLoc(Instr->getDebugLoc());
 
   // Keep track of the number of instruction created.
@@ -840,8 +837,7 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) {
 // <C, V>             "fmul V, C"      false
 //
 // NOTE: Keep this function in sync with FAddCombine::calcInstrNumber.
-Value *FAddCombine::createAddendVal
-  (const FAddend &Opnd, bool &NeedNeg) {
+Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) {
   const FAddendCoef &Coeff = Opnd.getCoef();
 
   if (Opnd.isConstant()) {
@@ -894,9 +890,8 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero,
 ///    (sext (add LHS, RHS))  === (add (sext LHS), (sext RHS))
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
-/// TODO: Handle this for Vectors.
 bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
-                                            Instruction *CxtI) {
+                                            Instruction &CxtI) {
   // There are different heuristics we can use for this.  Here are some simple
   // ones.
 
@@ -914,46 +909,29 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
-      ComputeNumSignBits(RHS, 0, CxtI) > 1)
+  if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, &CxtI) > 1)
     return true;
 
-  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
-    int BitWidth = IT->getBitWidth();
-    APInt LHSKnownZero(BitWidth, 0);
-    APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
-
-    APInt RHSKnownZero(BitWidth, 0);
-    APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
-
-    // Addition of two 2's compliment numbers having opposite signs will never
-    // overflow.
-    if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
-        (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
-      return true;
-
-    // Check if carry bit of addition will not cause overflow.
-    if (checkRippleForAdd(LHSKnownZero, RHSKnownZero))
-      return true;
-    if (checkRippleForAdd(RHSKnownZero, LHSKnownZero))
-      return true;
-  }
-  return false;
-}
+  unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
+  APInt LHSKnownZero(BitWidth, 0);
+  APInt LHSKnownOne(BitWidth, 0);
+  computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI);
 
-/// WillNotOverflowUnsignedAdd - Return true if we can prove that:
-///    (zext (add LHS, RHS))  === (add (zext LHS), (zext RHS))
-bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
-                                              Instruction *CxtI) {
-  // There are different heuristics we can use for this. Here is a simple one.
-  // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
-  bool LHSKnownNonNegative, LHSKnownNegative;
-  bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
-  if (LHSKnownNonNegative && RHSKnownNonNegative)
+  APInt RHSKnownZero(BitWidth, 0);
+  APInt RHSKnownOne(BitWidth, 0);
+  computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
+
+  // Addition of two 2's compliment numbers having opposite signs will never
+  // overflow.
+  if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
+      (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
+    return true;
+
+  // Check if carry bit of addition will not cause overflow.
+  if (checkRippleForAdd(LHSKnownZero, RHSKnownZero))
+    return true;
+  if (checkRippleForAdd(RHSKnownZero, LHSKnownZero))
     return true;
 
   return false;
@@ -965,94 +943,49 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
 /// overflow to change the sign bit or have a carry out.
 /// TODO: Handle this for Vectors.
 bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
-                                            Instruction *CxtI) {
+                                            Instruction &CxtI) {
   // If LHS and RHS each have at least two sign bits, the subtraction
   // cannot overflow.
-  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
-      ComputeNumSignBits(RHS, 0, CxtI) > 1)
+  if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, &CxtI) > 1)
     return true;
 
-  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
-    unsigned BitWidth = IT->getBitWidth();
-    APInt LHSKnownZero(BitWidth, 0);
-    APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
+  unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
+  APInt LHSKnownZero(BitWidth, 0);
+  APInt LHSKnownOne(BitWidth, 0);
+  computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI);
 
-    APInt RHSKnownZero(BitWidth, 0);
-    APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
+  APInt RHSKnownZero(BitWidth, 0);
+  APInt RHSKnownOne(BitWidth, 0);
+  computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
 
-    // Subtraction of two 2's compliment numbers having identical signs will
-    // never overflow.
-    if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
-        (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
-      return true;
+  // Subtraction of two 2's compliment numbers having identical signs will
+  // never overflow.
+  if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
+      (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
+    return true;
 
-    // TODO: implement logic similar to checkRippleForAdd
-  }
+  // TODO: implement logic similar to checkRippleForAdd
   return false;
 }
 
 /// \brief Return true if we can prove that:
 ///    (sub LHS, RHS)  === (sub nuw LHS, RHS)
 bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS,
-                                              Instruction *CxtI) {
+                                              Instruction &CxtI) {
   // If the LHS is negative and the RHS is non-negative, no unsigned wrap.
   bool LHSKnownNonNegative, LHSKnownNegative;
   bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0,
+                 &CxtI);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0,
+                 &CxtI);
   if (LHSKnownNegative && RHSKnownNonNegative)
     return true;
 
   return false;
 }
 
-/// \brief Return true if we can prove that:
-///    (mul LHS, RHS)  === (mul nsw LHS, RHS)
-bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS,
-                                            Instruction *CxtI) {
-  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
-
-    // Multiplying n * m significant bits yields a result of n + m significant
-    // bits. If the total number of significant bits does not exceed the
-    // result bit width (minus 1), there is no overflow.
-    // This means if we have enough leading sign bits in the operands
-    // we can guarantee that the result does not overflow.
-    // Ref: "Hacker's Delight" by Henry Warren
-    unsigned BitWidth = IT->getBitWidth();
-
-    // Note that underestimating the number of sign bits gives a more
-    // conservative answer.
-    unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) +
-                        ComputeNumSignBits(RHS, 0, CxtI);
-
-    // First handle the easy case: if we have enough sign bits there's
-    // definitely no overflow. 
-    if (SignBits > BitWidth + 1)
-      return true;
-    
-    // There are two ambiguous cases where there can be no overflow:
-    //   SignBits == BitWidth + 1    and
-    //   SignBits == BitWidth    
-    // The second case is difficult to check, therefore we only handle the
-    // first case.
-    if (SignBits == BitWidth + 1) {
-      // It overflows only when both arguments are negative and the true
-      // product is exactly the minimum negative number.
-      // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
-      // For simplicity we just check if at least one side is not negative.
-      bool LHSNonNegative, LHSNegative;
-      bool RHSNonNegative, RHSNegative;
-      ComputeSignBit(LHS, LHSNonNegative, LHSNegative, DL, 0, AT, CxtI, DT);
-      ComputeSignBit(RHS, RHSNonNegative, RHSNegative, DL, 0, AT, CxtI, DT);
-      if (LHSNonNegative || RHSNonNegative)
-        return true;
-    }
-  }
-  return false;
-}
-
 // Checks if any operand is negative and we can convert add to sub.
 // This function checks for following negative patterns
 //   ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C))
@@ -1115,15 +1048,15 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
 }
 
 Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
-   bool Changed = SimplifyAssociativeOrCommutative(I);
-   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  bool Changed = SimplifyAssociativeOrCommutative(I);
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
 
-   if (Value *V = SimplifyVectorOp(I))
-     return ReplaceInstUsesWith(I, V);
+  if (Value *V = SimplifyVectorOp(I))
+    return ReplaceInstUsesWith(I, V);
 
-   if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
-                                  I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
-     return ReplaceInstUsesWith(I, V);
+  if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AC))
+    return ReplaceInstUsesWith(I, V);
 
    // (A*B)+(A*C) -> A*(B+C) etc
   if (Value *V = SimplifyUsingDistributiveLaws(I))
@@ -1312,7 +1245,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
         ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
         // Insert the new, smaller add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1325,10 +1258,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
       // Only do this if x/y have the same type, if at last one of them has a
       // single use (so we don't increase the number of sexts), and if the
       // integer add will not overflow.
-      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
+      if (LHSConv->getOperand(0)->getType() ==
+              RHSConv->getOperand(0)->getType() &&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0), &I)) {
+                                   RHSConv->getOperand(0), I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                              RHSConv->getOperand(0), "addconv");
@@ -1376,11 +1310,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   // TODO(jingyue): Consider WillNotOverflowSignedAdd and
   // WillNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS, &I)) {
+  if (!I.hasNoUnsignedWrap() &&
+      computeOverflowForUnsignedAdd(LHS, RHS, &I) ==
+          OverflowResult::NeverOverflows) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
@@ -1395,8 +1331,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL,
-                                  TLI, DT, AT))
+  if (Value *V =
+          SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(RHS)) {
@@ -1438,7 +1374,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1451,10 +1387,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       // Only do this if x/y have the same type, if at last one of them has a
       // single use (so we don't increase the number of int->fp conversions),
       // and if the integer add will not overflow.
-      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
+      if (LHSConv->getOperand(0)->getType() ==
+              RHSConv->getOperand(0)->getType() &&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0), &I)) {
+                                   RHSConv->getOperand(0), I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               RHSConv->getOperand(0),"addconv");
@@ -1503,8 +1440,6 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
 ///
 Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
                                                Type *Ty) {
-  assert(DL && "Must have target data info for this");
-
   // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
   // this.
   bool Swapped = false;
@@ -1574,7 +1509,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, V);
 
   if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AC))
     return ReplaceInstUsesWith(I, V);
 
   // (A*B)-(A*C) -> A*(B-C) etc
@@ -1729,26 +1664,24 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
 
   // Optimize pointer differences into the same array into a size.  Consider:
   //  &A[10] - &A[0]: we should compile this to "10".
-  if (DL) {
-    Value *LHSOp, *RHSOp;
-    if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
-        match(Op1, m_PtrToInt(m_Value(RHSOp))))
-      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
-        return ReplaceInstUsesWith(I, Res);
-
-    // trunc(p)-trunc(q) -> trunc(p-q)
-    if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
-        match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
-      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
-        return ReplaceInstUsesWith(I, Res);
-      }
+  Value *LHSOp, *RHSOp;
+  if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
+      match(Op1, m_PtrToInt(m_Value(RHSOp))))
+    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
+      return ReplaceInstUsesWith(I, Res);
+
+  // trunc(p)-trunc(q) -> trunc(p-q)
+  if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
+      match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
+    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
+      return ReplaceInstUsesWith(I, Res);
 
   bool Changed = false;
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) {
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
@@ -1762,10 +1695,18 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL,
-                                  TLI, DT, AT))
+  if (Value *V =
+          SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC))
     return ReplaceInstUsesWith(I, V);
 
+  // fsub nsz 0, X ==> fsub nsz -0.0, X
+  if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) {
+    // Subtraction from -0.0 is the canonical form of fneg.
+    Instruction *NewI = BinaryOperator::CreateFNeg(Op1);
+    NewI->copyFastMathFlags(&I);
+    return NewI;
+  }
+
   if (isa<Constant>(Op0))
     if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
       if (Instruction *NV = FoldOpIntoSelect(I, SI))