[InstCombine] mark ADD with nuw if no unsigned overflow
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAddSub.cpp
index 4d8a1efd64ea311c2a80925b75c78102a8c4ccb7..f6658964ae851f82995180b058bab8aea56d63cf 100644 (file)
@@ -889,29 +889,93 @@ static inline Value *dyn_castFoldableMul(Value *V, Constant *&CST) {
   return nullptr;
 }
 
+// If one of the operands only has one non-zero bit, and if the other
+// operand has a known-zero bit in a more significant place than it (not
+// including the sign bit) the ripple may go up to and fill the zero, but
+// won't change the sign. For example, (X & ~4) + 1.
+static bool checkRippleForAdd(const APInt &Op0KnownZero,
+                              const APInt &Op1KnownZero) {
+  APInt Op1MaybeOne = ~Op1KnownZero;
+  // Make sure that one of the operand has at most one bit set to 1.
+  if (Op1MaybeOne.countPopulation() != 1)
+    return false;
+
+  // Find the most significant known 0 other than the sign bit.
+  int BitWidth = Op0KnownZero.getBitWidth();
+  APInt Op0KnownZeroTemp(Op0KnownZero);
+  Op0KnownZeroTemp.clearBit(BitWidth - 1);
+  int Op0ZeroPosition = BitWidth - Op0KnownZeroTemp.countLeadingZeros() - 1;
+
+  int Op1OnePosition = BitWidth - Op1MaybeOne.countLeadingZeros() - 1;
+  assert(Op1OnePosition >= 0);
+
+  // This also covers the case of no known zero, since in that case
+  // Op0ZeroPosition is -1.
+  return Op0ZeroPosition >= Op1OnePosition;
+}
 
 /// WillNotOverflowSignedAdd - Return true if we can prove that:
 ///    (sext (add LHS, RHS))  === (add (sext LHS), (sext RHS))
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
+/// TODO: Handle this for Vectors.
 bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
   // There are different heuristics we can use for this.  Here are some simple
   // ones.
 
-  // Add has the property that adding any two 2's complement numbers can only
-  // have one carry bit which can change a sign.  As such, if LHS and RHS each
-  // have at least two sign bits, we know that the addition of the two values
-  // will sign extend fine.
+  // If LHS and RHS each have at least two sign bits, the addition will look
+  // like
+  //
+  // XX..... +
+  // YY.....
+  //
+  // If the carry into the most significant position is 0, X and Y can't both
+  // be 1 and therefore the carry out of the addition is also 0.
+  //
+  // If the carry into the most significant position is 1, X and Y can't both
+  // be 0 and therefore the carry out of the addition is also 1.
+  //
+  // Since the carry into the most significant position is always equal to
+  // the carry out of the addition, there is no signed overflow.
   if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1)
     return true;
 
+  if (IntegerType *IT = dyn_cast<IntegerType>(LHS->getType())) {
+    int BitWidth = IT->getBitWidth();
+    APInt LHSKnownZero(BitWidth, 0);
+    APInt LHSKnownOne(BitWidth, 0);
+    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne);
 
-  // If one of the operands only has one non-zero bit, and if the other operand
-  // has a known-zero bit in a more significant place than it (not including the
-  // sign bit) the ripple may go up to and fill the zero, but won't change the
-  // sign.  For example, (X & ~4) + 1.
+    APInt RHSKnownZero(BitWidth, 0);
+    APInt RHSKnownOne(BitWidth, 0);
+    computeKnownBits(RHS, RHSKnownZero, RHSKnownOne);
+
+    // Addition of two 2's compliment numbers having opposite signs will never
+    // overflow.
+    if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
+        (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
+      return true;
+
+    // Check if carry bit of addition will not cause overflow.
+    if (checkRippleForAdd(LHSKnownZero, RHSKnownZero))
+      return true;
+    if (checkRippleForAdd(RHSKnownZero, LHSKnownZero))
+      return true;
+  }
+  return false;
+}
 
-  // TODO: Implement.
+/// WillNotOverflowUnsignedAdd - Return true if we can prove that:
+///    (zext (add LHS, RHS))  === (add (zext LHS), (zext RHS))
+bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS) {
+  // There are different heuristics we can use for this. Here is a simple one.
+  // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
+  bool LHSKnownNonNegative, LHSKnownNegative;
+  bool RHSKnownNonNegative, RHSKnownNegative;
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, 0);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, 0);
+  if (LHSKnownNonNegative && RHSKnownNonNegative)
+    return true;
 
   return false;
 }
@@ -1191,10 +1255,17 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
       return BinaryOperator::CreateOr(A, B);
   }
 
+  // TODO(jingyue): Consider WillNotOverflowSignedAdd and
+  // WillNotOverflowUnsignedAdd to reduce the number of invocations of
+  // computeKnownBits.
   if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS)) {
+    Changed = true;
+    I.setHasNoUnsignedWrap(true);
+  }
 
   return Changed ? &I : nullptr;
 }