Fixed GEP visitor in the InstCombine pass.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAddSub.cpp
index fbec98dbec857dda098ff7975aa68ffb5cb45539..17bbc5d2870dfa0d14034daacd10d091d39b6414 100644 (file)
@@ -1,4 +1,4 @@
-//===- InstCombineAddSub.cpp ----------------------------------------------===//
+//===- InstCombineAddSub.cpp ------------------------------------*- C++ -*-===//
 //
 //                     The LLVM Compiler Infrastructure
 //
 //
 //===----------------------------------------------------------------------===//
 
-#include "InstCombine.h"
+#include "InstCombineInternal.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/PatternMatch.h"
+
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -67,17 +68,17 @@ namespace {
 
   private:
     bool insaneIntVal(int V) { return V > 4 || V < -4; }
-    APFloat *getFpValPtr(void)
+    APFloat *getFpValPtr()
       { return reinterpret_cast<APFloat*>(&FpValBuf.buffer[0]); }
-    const APFloat *getFpValPtr(void) const
+    const APFloat *getFpValPtr() const
       { return reinterpret_cast<const APFloat*>(&FpValBuf.buffer[0]); }
 
-    const APFloat &getFpVal(void) const {
+    const APFloat &getFpVal() const {
       assert(IsFp && BufHasFpVal && "Incorret state");
       return *getFpValPtr();
     }
 
-    APFloat &getFpVal(void) {
+    APFloat &getFpVal() {
       assert(IsFp && BufHasFpVal && "Incorret state");
       return *getFpValPtr();
     }
@@ -92,8 +93,8 @@ namespace {
     // TODO: We should get rid of this function when APFloat can be constructed
     //       from an *SIGNED* integer.
     APFloat createAPFloatFromInt(const fltSemantics &Sem, int Val);
-  private:
 
+  private:
     bool IsFp;
 
     // True iff FpValBuf contains an instance of APFloat.
@@ -114,10 +115,10 @@ namespace {
   ///
   class FAddend {
   public:
-    FAddend() { Val = nullptr; }
+    FAddend() : Val(nullptr) {}
 
-    Value *getSymVal (void) const { return Val; }
-    const FAddendCoef &getCoef(void) const { return Coeff; }
+    Value *getSymVal() const { return Val; }
+    const FAddendCoef &getCoef() const { return Coeff; }
 
     bool isConstant() const { return Val == nullptr; }
     bool isZero() const { return Coeff.isZero(); }
@@ -182,7 +183,6 @@ namespace {
     InstCombiner::BuilderTy *Builder;
     Instruction *Instr;
 
-  private:
      // Debugging stuff are clustered here.
     #ifndef NDEBUG
       unsigned CreateInstrNum;
@@ -193,7 +193,8 @@ namespace {
       void incCreateInstNum() {}
     #endif
   };
-}
+
+} // anonymous namespace
 
 //===----------------------------------------------------------------------===//
 //
@@ -602,7 +603,6 @@ Value *FAddCombine::simplify(Instruction *I) {
 }
 
 Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
-
   unsigned AddendNum = Addends.size();
   assert(AddendNum <= 4 && "Too many addends");
 
@@ -886,12 +886,12 @@ static bool checkRippleForAdd(const APInt &Op0KnownZero,
   return Op0ZeroPosition >= Op1OnePosition;
 }
 
-/// WillNotOverflowSignedAdd - Return true if we can prove that:
+/// Return true if we can prove that:
 ///    (sext (add LHS, RHS))  === (add (sext LHS), (sext RHS))
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
 bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
-                                            Instruction *CxtI) {
+                                            Instruction &CxtI) {
   // There are different heuristics we can use for this.  Here are some simple
   // ones.
 
@@ -909,18 +909,18 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
-      ComputeNumSignBits(RHS, 0, CxtI) > 1)
+  if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, &CxtI) > 1)
     return true;
 
   unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
   APInt LHSKnownZero(BitWidth, 0);
   APInt LHSKnownOne(BitWidth, 0);
-  computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
+  computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI);
 
   APInt RHSKnownZero(BitWidth, 0);
   APInt RHSKnownOne(BitWidth, 0);
-  computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
+  computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
 
   // Addition of two 2's compliment numbers having opposite signs will never
   // overflow.
@@ -937,43 +937,27 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
   return false;
 }
 
-/// WillNotOverflowUnsignedAdd - Return true if we can prove that:
-///    (zext (add LHS, RHS))  === (add (zext LHS), (zext RHS))
-bool InstCombiner::WillNotOverflowUnsignedAdd(Value *LHS, Value *RHS,
-                                              Instruction *CxtI) {
-  // There are different heuristics we can use for this. Here is a simple one.
-  // If the sign bit of LHS and that of RHS are both zero, no unsigned wrap.
-  bool LHSKnownNonNegative, LHSKnownNegative;
-  bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI);
-  if (LHSKnownNonNegative && RHSKnownNonNegative)
-    return true;
-
-  return false;
-}
-
 /// \brief Return true if we can prove that:
 ///    (sub LHS, RHS)  === (sub nsw LHS, RHS)
 /// This basically requires proving that the add in the original type would not
 /// overflow to change the sign bit or have a carry out.
 /// TODO: Handle this for Vectors.
 bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
-                                            Instruction *CxtI) {
+                                            Instruction &CxtI) {
   // If LHS and RHS each have at least two sign bits, the subtraction
   // cannot overflow.
-  if (ComputeNumSignBits(LHS, 0, CxtI) > 1 &&
-      ComputeNumSignBits(RHS, 0, CxtI) > 1)
+  if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 &&
+      ComputeNumSignBits(RHS, 0, &CxtI) > 1)
     return true;
 
   unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
   APInt LHSKnownZero(BitWidth, 0);
   APInt LHSKnownOne(BitWidth, 0);
-  computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, CxtI);
+  computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI);
 
   APInt RHSKnownZero(BitWidth, 0);
   APInt RHSKnownOne(BitWidth, 0);
-  computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, CxtI);
+  computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
 
   // Subtraction of two 2's compliment numbers having identical signs will
   // never overflow.
@@ -988,12 +972,14 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
 /// \brief Return true if we can prove that:
 ///    (sub LHS, RHS)  === (sub nuw LHS, RHS)
 bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS,
-                                              Instruction *CxtI) {
+                                              Instruction &CxtI) {
   // If the LHS is negative and the RHS is non-negative, no unsigned wrap.
   bool LHSKnownNonNegative, LHSKnownNegative;
   bool RHSKnownNonNegative, RHSKnownNegative;
-  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, CxtI);
-  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, CxtI);
+  ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0,
+                 &CxtI);
+  ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0,
+                 &CxtI);
   if (LHSKnownNegative && RHSKnownNonNegative)
     return true;
 
@@ -1062,15 +1048,15 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
 }
 
 Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
-   bool Changed = SimplifyAssociativeOrCommutative(I);
-   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  bool Changed = SimplifyAssociativeOrCommutative(I);
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
 
-   if (Value *V = SimplifyVectorOp(I))
-     return ReplaceInstUsesWith(I, V);
+  if (Value *V = SimplifyVectorOp(I))
+    return ReplaceInstUsesWith(I, V);
 
-   if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
-                                  I.hasNoUnsignedWrap(), DL, TLI, DT, AC))
-     return ReplaceInstUsesWith(I, V);
+  if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
+                                 I.hasNoUnsignedWrap(), DL, TLI, DT, AC))
+    return ReplaceInstUsesWith(I, V);
 
    // (A*B)+(A*C) -> A*(B+C) etc
   if (Value *V = SimplifyUsingDistributiveLaws(I))
@@ -1174,20 +1160,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
     return ReplaceInstUsesWith(I, V);
 
   // A+B --> A|B iff A and B have no bits set in common.
-  if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
-    APInt LHSKnownOne(IT->getBitWidth(), 0);
-    APInt LHSKnownZero(IT->getBitWidth(), 0);
-    computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I);
-    if (LHSKnownZero != 0) {
-      APInt RHSKnownOne(IT->getBitWidth(), 0);
-      APInt RHSKnownZero(IT->getBitWidth(), 0);
-      computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I);
-
-      // No bits in common -> bitwise or.
-      if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
-        return BinaryOperator::CreateOr(LHS, RHS);
-    }
-  }
+  if (haveNoCommonBitsSet(LHS, RHS, DL, AC, &I, DT))
+    return BinaryOperator::CreateOr(LHS, RHS);
 
   if (Constant *CRHS = dyn_cast<Constant>(RHS)) {
     Value *X;
@@ -1259,7 +1233,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
         ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
         // Insert the new, smaller add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1272,10 +1246,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
       // Only do this if x/y have the same type, if at last one of them has a
       // single use (so we don't increase the number of sexts), and if the
       // integer add will not overflow.
-      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
+      if (LHSConv->getOperand(0)->getType() ==
+              RHSConv->getOperand(0)->getType() &&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0), &I)) {
+                                   RHSConv->getOperand(0), I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                              RHSConv->getOperand(0), "addconv");
@@ -1323,11 +1298,13 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   // TODO(jingyue): Consider WillNotOverflowSignedAdd and
   // WillNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, &I)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedAdd(LHS, RHS, &I)) {
+  if (!I.hasNoUnsignedWrap() &&
+      computeOverflowForUnsignedAdd(LHS, RHS, &I) ==
+          OverflowResult::NeverOverflows) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
@@ -1385,7 +1362,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
       if (LHSConv->hasOneUse() &&
           ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
-          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, &I)) {
+          WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               CI, "addconv");
@@ -1398,10 +1375,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
       // Only do this if x/y have the same type, if at last one of them has a
       // single use (so we don't increase the number of int->fp conversions),
       // and if the integer add will not overflow.
-      if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
+      if (LHSConv->getOperand(0)->getType() ==
+              RHSConv->getOperand(0)->getType() &&
           (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
           WillNotOverflowSignedAdd(LHSConv->getOperand(0),
-                                   RHSConv->getOperand(0), &I)) {
+                                   RHSConv->getOperand(0), I)) {
         // Insert the new integer add.
         Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
                                               RHSConv->getOperand(0),"addconv");
@@ -1443,15 +1421,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
   return Changed ? &I : nullptr;
 }
 
-
 /// Optimize pointer differences into the same array into a size.  Consider:
 ///  &A[10] - &A[0]: we should compile this to "10".  LHS/RHS are the pointer
 /// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
 ///
 Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
                                                Type *Ty) {
-  assert(DL && "Must have target data info for this");
-
   // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
   // this.
   bool Swapped = false;
@@ -1598,8 +1573,20 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
           CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1)
         return BinaryOperator::CreateLShr(X, CI);
     }
-  }
 
+    // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
+    // zero.
+    APInt IntVal = C->getValue();
+    if ((IntVal + 1).isPowerOf2()) {
+      unsigned BitWidth = I.getType()->getScalarSizeInBits();
+      APInt KnownZero(BitWidth, 0);
+      APInt KnownOne(BitWidth, 0);
+      computeKnownBits(&I, KnownZero, KnownOne, 0, &I);
+      if ((IntVal | KnownZero).isAllOnesValue()) {
+        return BinaryOperator::CreateXor(Op1, C);
+      }
+    }
+  }
 
   {
     Value *Y;
@@ -1676,26 +1663,24 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
 
   // Optimize pointer differences into the same array into a size.  Consider:
   //  &A[10] - &A[0]: we should compile this to "10".
-  if (DL) {
-    Value *LHSOp, *RHSOp;
-    if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
-        match(Op1, m_PtrToInt(m_Value(RHSOp))))
-      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
-        return ReplaceInstUsesWith(I, Res);
-
-    // trunc(p)-trunc(q) -> trunc(p-q)
-    if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
-        match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
-      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
-        return ReplaceInstUsesWith(I, Res);
-      }
+  Value *LHSOp, *RHSOp;
+  if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
+      match(Op1, m_PtrToInt(m_Value(RHSOp))))
+    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
+      return ReplaceInstUsesWith(I, Res);
+
+  // trunc(p)-trunc(q) -> trunc(p-q)
+  if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
+      match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
+    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
+      return ReplaceInstUsesWith(I, Res);
 
   bool Changed = false;
-  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, &I)) {
+  if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, &I)) {
+  if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }