fix formatting; NFC
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAddSub.cpp
index 6aede84a352459f6e347a1b0af7f1816662b4e2b..39b71f37f8b2795a0b4be863532b967f634d63dc 100644 (file)
@@ -32,7 +32,7 @@ namespace {
   ///
   class FAddendCoef {
   public:
-    // The constructor has to initialize a APFloat, which is uncessary for
+    // The constructor has to initialize a APFloat, which is unnecessary for
     // most addends which have coefficient either 1 or -1. So, the constructor
     // is expensive. In order to avoid the cost of the constructor, we should
     // reuse some instances whenever possible. The pre-created instances
@@ -751,8 +751,7 @@ Value *FAddCombine::createNaryFAdd
   return LastVal;
 }
 
-Value *FAddCombine::createFSub
-  (Value *Opnd0, Value *Opnd1) {
+Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFSub(Opnd0, Opnd1);
   if (Instruction *I = dyn_cast<Instruction>(V))
     createInstPostProc(I);
@@ -767,8 +766,7 @@ Value *FAddCombine::createFNeg(Value *V) {
   return NewV;
 }
 
-Value *FAddCombine::createFAdd
-  (Value *Opnd0, Value *Opnd1) {
+Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFAdd(Opnd0, Opnd1);
   if (Instruction *I = dyn_cast<Instruction>(V))
     createInstPostProc(I);
@@ -789,8 +787,7 @@ Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) {
   return V;
 }
 
-void FAddCombine::createInstPostProc(Instruction *NewInstr,
-                                     bool NoNumber) {
+void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) {
   NewInstr->setDebugLoc(Instr->getDebugLoc());
 
   // Keep track of the number of instruction created.
@@ -840,8 +837,7 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) {
 // <C, V>             "fmul V, C"      false
 //
 // NOTE: Keep this function in sync with FAddCombine::calcInstrNumber.
-Value *FAddCombine::createAddendVal
-  (const FAddend &Opnd, bool &NeedNeg) {
+Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) {
   const FAddendCoef &Coeff = Opnd.getCoef();
 
   if (Opnd.isConstant()) {
@@ -895,7 +891,8 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero,
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
 /// TODO: Handle this for Vectors.
-bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
   // There are different heuristics we can use for this.  Here are some simple
   // ones.
 
@@ -913,18 +910,19 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
+  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, CxtI) > 1)
     return true;
 
   if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
     int BitWidth = IT->getBitWidth();
     APInt LHSKnownZero(BitWidth, 0);
     APInt LHSKnownOne(BitWidth, 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
 
     APInt RHSKnownZero(BitWidth, 0);
     APInt RHSKnownOne(BitWidth, 0);
-    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
 
     // Addition of two 2's compliment numbers having opposite signs will never
     // overflow.
@@ -943,67 +941,176 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
 
 /// WillNotOverflowUnsignedAdd - Return true if we can prove that:
 ///    (zext (add LHS, RHS))  === (add (zext LHS), (zext RHS))
-bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
+bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
+                                              Instruction *CxtI) {
   // There are different heuristics we can use for this. Here is a simple one.
   // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
   bool LHSKnownNonNegative, LHSKnownNegative;
   bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0);
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
   if (LHSKnownNonNegative && RHSKnownNonNegative)
     return true;
 
   return false;
- }
+}
+
+/// \brief Return true if we can prove that:
+///    (sub LHS, RHS)  === (sub nsw LHS, RHS)
+/// This basically requires proving that the add in the original type would not
+/// overflow to change the sign bit or have a carry out.
+/// TODO: Handle this for Vectors.
+bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
+  // If LHS and RHS each have at least two sign bits, the subtraction
+  // cannot overflow.
+  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, CxtI) > 1)
+    return true;
+
+  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
+    unsigned BitWidth = IT->getBitWidth();
+    APInt LHSKnownZero(BitWidth, 0);
+    APInt LHSKnownOne(BitWidth, 0);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
+
+    APInt RHSKnownZero(BitWidth, 0);
+    APInt RHSKnownOne(BitWidth, 0);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
+
+    // Subtraction of two 2's compliment numbers having identical signs will
+    // never overflow.
+    if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
+        (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
+      return true;
+
+    // TODO: implement logic similar to checkRippleForAdd
+  }
+  return false;
+}
+
+/// \brief Return true if we can prove that:
+///    (sub LHS, RHS)  === (sub nuw LHS, RHS)
+bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS,
+                                              Instruction *CxtI) {
+  // If the LHS is negative and the RHS is non-negative, no unsigned wrap.
+  bool LHSKnownNonNegative, LHSKnownNegative;
+  bool RHSKnownNonNegative, RHSKnownNegative;
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0, AT, CxtI, DT);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0, AT, CxtI, DT);
+  if (LHSKnownNegative && RHSKnownNonNegative)
+    return true;
+
+  return false;
+}
+
+/// \brief Return true if we can prove that:
+///    (mul LHS, RHS)  === (mul nsw LHS, RHS)
+bool InstCombiner::WillNotOverflowSignedMul(Value *LHS, Value *RHS,
+                                            Instruction *CxtI) {
+  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
+
+    // Multiplying n * m significant bits yields a result of n + m significant
+    // bits. If the total number of significant bits does not exceed the
+    // result bit width (minus 1), there is no overflow.
+    // This means if we have enough leading sign bits in the operands
+    // we can guarantee that the result does not overflow.
+    // Ref: "Hacker's Delight" by Henry Warren
+    unsigned BitWidth = IT->getBitWidth();
+
+    // Note that underestimating the number of sign bits gives a more
+    // conservative answer.
+    unsigned SignBits = ComputeNumSignBits(LHS, 0, CxtI) +
+                        ComputeNumSignBits(RHS, 0, CxtI);
+
+    // First handle the easy case: if we have enough sign bits there's
+    // definitely no overflow. 
+    if (SignBits > BitWidth + 1)
+      return true;
+    
+    // There are two ambiguous cases where there can be no overflow:
+    //   SignBits == BitWidth + 1    and
+    //   SignBits == BitWidth    
+    // The second case is difficult to check, therefore we only handle the
+    // first case.
+    if (SignBits == BitWidth + 1) {
+      // It overflows only when both arguments are negative and the true
+      // product is exactly the minimum negative number.
+      // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
+      // For simplicity we just check if at least one side is not negative.
+      bool LHSNonNegative, LHSNegative;
+      bool RHSNonNegative, RHSNegative;
+      ComputeSignBit(LHS, LHSNonNegative, LHSNegative, DL, 0, AT, CxtI, DT);
+      ComputeSignBit(RHS, RHSNonNegative, RHSNegative, DL, 0, AT, CxtI, DT);
+      if (LHSNonNegative || RHSNonNegative)
+        return true;
+    }
+  }
+  return false;
+}
 
 // Checks if any operand is negative and we can convert add to sub.
- // This function checks for following negative patterns
- //   ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C))
- //   ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C)) if C is odd
- //   TODO: XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even
- Value *checkForNegativeOperand(BinaryOperator &I,
-                                InstCombiner::BuilderTy *Builder) {
-   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+// This function checks for following negative patterns
+//   ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C))
+//   ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C))
+//   XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even
+static Value *checkForNegativeOperand(BinaryOperator &I,
+                                      InstCombiner::BuilderTy *Builder) {
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+
+  // This function creates 2 instructions to replace ADD, we need at least one
+  // of LHS or RHS to have one use to ensure benefit in transform.
+  if (!LHS->hasOneUse() && !RHS->hasOneUse())
+    return nullptr;
 
-   // This function creates 2 instructions to replace ADD, we need at least one
-   // of LHS or RHS to have one use to ensure benefit in transform.
-   if (!LHS->hasOneUse() && !RHS->hasOneUse())
-     return nullptr;
-
-   Value *X = nullptr, *Y = nullptr, *Z = nullptr;
-   const APInt *C1 = nullptr, *C2 = nullptr;
-
-   // if ONE is on other side, swap
-   if (match(RHS, m_Add(m_Value(X), m_One())))
-     std::swap(LHS, RHS);
-
-   if (match(LHS, m_Add(m_Value(X), m_One()))) {
-     // if XOR on other side, swap
-     if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
-       std::swap(X, RHS);
-
-     if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
-       // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
-       // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
-       if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
-         Value *NewAnd = Builder->CreateAnd(Z, *C1);
-         return Builder->CreateSub(RHS, NewAnd, "sub");
-       } else if (C1->countTrailingZeros() == 0) {
-         // if C1 is ODD and
-         // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1))
-         // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1))
-         if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
-           Value *NewOr = Builder->CreateOr(Z, ~(*C1));
-           return Builder->CreateSub(RHS, NewOr, "sub");
-         }
-       }
-     }
-   }
-
-   return nullptr;
- }
-
- Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
+  Value *X = nullptr, *Y = nullptr, *Z = nullptr;
+  const APInt *C1 = nullptr, *C2 = nullptr;
+
+  // if ONE is on other side, swap
+  if (match(RHS, m_Add(m_Value(X), m_One())))
+    std::swap(LHS, RHS);
+
+  if (match(LHS, m_Add(m_Value(X), m_One()))) {
+    // if XOR on other side, swap
+    if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
+      std::swap(X, RHS);
+
+    if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) {
+      // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1))
+      // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1))
+      if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) {
+        Value *NewAnd = Builder->CreateAnd(Z, *C1);
+        return Builder->CreateSub(RHS, NewAnd, "sub");
+      } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) {
+        // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1))
+        // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1))
+        Value *NewOr = Builder->CreateOr(Z, ~(*C1));
+        return Builder->CreateSub(RHS, NewOr, "sub");
+      }
+    }
+  }
+
+  // Restore LHS and RHS
+  LHS = I.getOperand(0);
+  RHS = I.getOperand(1);
+
+  // if XOR is on other side, swap
+  if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1))))
+    std::swap(LHS, RHS);
+
+  // C2 is ODD
+  // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2))
+  // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
+  if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
+    if (C1->countTrailingZeros() == 0)
+      if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
+        Value *NewOr = Builder->CreateOr(Z, ~(*C2));
+        return Builder->CreateSub(RHS, NewOr, "sub");
+      }
+  return nullptr;
+}
+
+Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
    bool Changed = SimplifyAssociativeOrCommutative(I);
    Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
 
@@ -1011,7 +1118,7 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
      return ReplaceInstUsesWith(I, V);
 
    if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
-                                  I.hasNoUnsignedWrap(), DL))
+                                  I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
      return ReplaceInstUsesWith(I, V);
 
    // (A*B)+(A*C) -> A*(B+C) etc
@@ -1050,7 +1157,7 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
 
       if (ExtendAmt) {
         APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
-        if (!MaskedValueIsZero(XorLHS, Mask))
+        if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
           ExtendAmt = 0;
       }
 
@@ -1066,7 +1173,7 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
         IntegerType *IT = cast<IntegerType>(I.getType());
         APInt LHSKnownOne(IT->getBitWidth(), 0);
         APInt LHSKnownZero(IT->getBitWidth(), 0);
-        computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne);
+        computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I);
         if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue())
           return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI),
                                            XorLHS);
@@ -1119,11 +1226,11 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
   if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
     APInt LHSKnownOne(IT->getBitWidth(), 0);
     APInt LHSKnownZero(IT->getBitWidth(), 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I);
     if (LHSKnownZero != 0) {
       APInt RHSKnownOne(IT->getBitWidth(), 0);
       APInt RHSKnownZero(IT->getBitWidth(), 0);
-      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I);
 
       // No bits in common -> bitwise or.
       if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
@@ -1201,7 +1308,7 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
         ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
         // Insert the new, smaller add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1217,7 +1324,7 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
       if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0))) {
+                                   RHSConv->getOperand(0), &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                              RHSConv->getOperand(0), "addconv");
@@ -1226,7 +1333,7 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
     }
   }
 
-  // Check for (x & y) + (x ^ y)
+  // (add (xor A, B) (and A, B)) --> (or A, B)
   {
     Value *A = nullptr, *B = nullptr;
     if (match(RHS, m_Xor(m_Value(A), m_Value(B))) &&
@@ -1240,14 +1347,36 @@ bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
       return BinaryOperator::CreateOr(A, B);
   }
 
+  // (add (or A, B) (and A, B)) --> (add A, B)
+  {
+    Value *A = nullptr, *B = nullptr;
+    if (match(RHS, m_Or(m_Value(A), m_Value(B))) &&
+        (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
+         match(LHS, m_And(m_Specific(B), m_Specific(A))))) {
+      auto *New = BinaryOperator::CreateAdd(A, B);
+      New->setHasNoSignedWrap(I.hasNoSignedWrap());
+      New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+      return New;
+    }
+
+    if (match(LHS, m_Or(m_Value(A), m_Value(B))) &&
+        (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
+         match(RHS, m_And(m_Specific(B), m_Specific(A))))) {
+      auto *New = BinaryOperator::CreateAdd(A, B);
+      New->setHasNoSignedWrap(I.hasNoSignedWrap());
+      New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+      return New;
+    }
+  }
+
   // TODO(jingyue): Consider WillNotOverflowSignedAdd and
   // WillNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS)) {
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS, &I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
@@ -1262,7 +1391,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL))
+  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL,
+                                  TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(RHS)) {
@@ -1304,7 +1434,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1320,7 +1450,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0))) {
+                                   RHSConv->getOperand(0), &I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               RHSConv->getOperand(0),"addconv");
@@ -1342,11 +1472,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
             Z2 = dyn_cast<Constant>(B2); B = B1;
         } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) {
             Z1 = dyn_cast<Constant>(B1); B = B2;
-            Z2 = dyn_cast<Constant>(A2); A = A1; 
+            Z2 = dyn_cast<Constant>(A2); A = A1;
         }
-        
-        if (Z1 && Z2 && 
-            (I.hasNoSignedZeros() || 
+
+        if (Z1 && Z2 &&
+            (I.hasNoSignedZeros() ||
              (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) {
           return SelectInst::Create(C, A, B);
         }
@@ -1433,7 +1563,6 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
   return Builder->CreateIntCast(Result, Ty, true);
 }
 
-
 Instruction *InstCombiner::visitSub(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
@@ -1441,18 +1570,27 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, V);
 
   if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), DL))
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // (A*B)-(A*C) -> A*(B-C) etc
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return ReplaceInstUsesWith(I, V);
 
-  // If this is a 'B = x-(-A)', change to B = x+A.  This preserves NSW/NUW.
+  // If this is a 'B = x-(-A)', change to B = x+A.
   if (Value *V = dyn_castNegVal(Op1)) {
     BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V);
-    Res->setHasNoSignedWrap(I.hasNoSignedWrap());
-    Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+
+    if (const auto *BO = dyn_cast<BinaryOperator>(Op1)) {
+      assert(BO->getOpcode() == Instruction::Sub &&
+             "Expected a subtraction operator!");
+      if (BO->hasNoSignedWrap() && I.hasNoSignedWrap())
+        Res->setHasNoSignedWrap(true);
+    } else {
+      if (cast<Constant>(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap())
+        Res->setHasNoSignedWrap(true);
+    }
+
     return Res;
   }
 
@@ -1497,21 +1635,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
     // -(X >>u 31) -> (X >>s 31)
     // -(X >>s 31) -> (X >>u 31)
     if (C->isZero()) {
-      Value *X; ConstantInt *CI;
+      Value *X;
+      ConstantInt *CI;
       if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) &&
           // Verify we are shifting out everything but the sign bit.
-          CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
+          CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1)
         return BinaryOperator::CreateAShr(X, CI);
 
       if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) &&
           // Verify we are shifting out everything but the sign bit.
-          CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1)
+          CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1)
         return BinaryOperator::CreateLShr(X, CI);
     }
   }
 
 
-  { Value *Y;
+  {
+    Value *Y;
     // X-(X+Y) == -Y    X-(Y+X) == -Y
     if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) ||
         match(Op1, m_Add(m_Value(Y), m_Specific(Op0))))
@@ -1522,6 +1662,24 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
       return BinaryOperator::CreateNeg(Y);
   }
 
+  // (sub (or A, B) (xor A, B)) --> (and A, B)
+  {
+    Value *A = nullptr, *B = nullptr;
+    if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
+        (match(Op0, m_Or(m_Specific(A), m_Specific(B))) ||
+         match(Op0, m_Or(m_Specific(B), m_Specific(A)))))
+      return BinaryOperator::CreateAnd(A, B);
+  }
+
+  if (Op0->hasOneUse()) {
+    Value *Y = nullptr;
+    // ((X | Y) - X) --> (~X & Y)
+    if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) ||
+        match(Op0, m_Or(m_Specific(Op1), m_Value(Y))))
+      return BinaryOperator::CreateAnd(
+          Y, Builder->CreateNot(Op1, Op1->getName() + ".not"));
+  }
+
   if (Op1->hasOneUse()) {
     Value *X = nullptr, *Y = nullptr, *Z = nullptr;
     Constant *C = nullptr;
@@ -1539,9 +1697,9 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
       return BinaryOperator::CreateAnd(Op0,
                                   Builder->CreateNot(Y, Y->getName() + ".not"));
 
-    // 0 - (X sdiv C)  -> (X sdiv -C)
-    if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) &&
-        match(Op0, m_Zero()))
+    // 0 - (X sdiv C)  -> (X sdiv -C)  provided the negation doesn't overflow.
+    if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && match(Op0, m_Zero()) &&
+        C->isNotMinSignedValue() && !C->isOneValue())
       return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C));
 
     // 0 - (X << Y)  -> (-X << Y)   when X is freely negatable.
@@ -1579,9 +1737,19 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
         match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
       if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
         return ReplaceInstUsesWith(I, Res);
+      }
+
+  bool Changed = false;
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) {
+    Changed = true;
+    I.setHasNoSignedWrap(true);
+  }
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) {
+    Changed = true;
+    I.setHasNoUnsignedWrap(true);
   }
 
-  return nullptr;
+  return Changed ? &I : nullptr;
 }
 
 Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
@@ -1590,7 +1758,8 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
   if (Value *V = SimplifyVectorOp(I))
     return ReplaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL))
+  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL,
+                                  TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   if (isa<Constant>(Op0))