Revert r219175 - [InstCombine] re-commit r218721 icmp-select-icmp optimization
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index 3a501132ebdba084e94fe951330ebb415aeb2e86..00623b1cbf6d2487882f125f1133c44390cfa9f2 100644 (file)
@@ -740,21 +740,6 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
                                             ICmpInst::Predicate Pred) {
-  // If we have X+0, exit early (simplifying logic below) and let it get folded
-  // elsewhere.   icmp X+0, X  -> icmp X, X
-  if (CI->isZero()) {
-    bool isTrue = ICmpInst::isTrueWhenEqual(Pred);
-    return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue));
-  }
-
-  // (X+4) == X -> false.
-  if (Pred == ICmpInst::ICMP_EQ)
-    return ReplaceInstUsesWith(ICI, Builder->getFalse());
-
-  // (X+4) != X -> true.
-  if (Pred == ICmpInst::ICMP_NE)
-    return ReplaceInstUsesWith(ICI, Builder->getTrue());
-
   // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
   // so the values can never be equal.  Similarly for all other "or equals"
   // operators.
@@ -1100,29 +1085,33 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
     return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
   }
 
+  bool IsNegative = false;
   if (IsAShr) {
     if (AP1.isNegative() != AP2.isNegative()) {
       // Arithmetic shift will never change the sign.
       return getConstant(false);
     }
-    // Both the constants are negative, take their positive to calculate
-    // log.
+    // Both the constants are negative, take their positive to calculate log.
     if (AP1.isNegative()) {
-      AP1 = -AP1;
-      AP2 = -AP2;
+      if (AP1.slt(AP2))
+        // Right-shifting won't increase the magnitude.
+        return getConstant(false);
+      IsNegative = true;
     }
   }
 
-  if (AP1.ugt(AP2)) {
+  if (!IsNegative && AP1.ugt(AP2))
     // Right-shifting will not increase the value.
     return getConstant(false);
-  }
 
   // Get the distance between the highest bit that's set.
-  int Shift = AP2.logBase2() - AP1.logBase2();
+  int Shift;
+  if (IsNegative)
+    Shift = (-AP2).logBase2() - (-AP1).logBase2();
+  else
+    Shift = AP2.logBase2() - AP1.logBase2();
 
-  // Use lshr here, since we've canonicalized to +ve numbers.
-  if (AP1 == AP2.lshr(Shift))
+  if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift))
     return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
 
   // Shifting const2 will never be equal to const1.
@@ -1144,7 +1133,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
       unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(),
              SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
       APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0);
-      computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne);
+      computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI);
 
       // If all the high bits are known, we can do this xform.
       if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) {
@@ -1366,6 +1355,48 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         return &ICI;
       }
 
+      // (icmp pred (and (or (lshr X, Y), X), 1), 0) -->
+      //    (icmp pred (and X, (or (shl 1, Y), 1), 0))
+      //
+      // iff pred isn't signed
+      {
+        Value *X, *Y, *LShr;
+        if (!ICI.isSigned() && RHSV == 0) {
+          if (match(LHSI->getOperand(1), m_One())) {
+            Constant *One = cast<Constant>(LHSI->getOperand(1));
+            Value *Or = LHSI->getOperand(0);
+            if (match(Or, m_Or(m_Value(LShr), m_Value(X))) &&
+                match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) {
+              unsigned UsesRemoved = 0;
+              if (LHSI->hasOneUse())
+                ++UsesRemoved;
+              if (Or->hasOneUse())
+                ++UsesRemoved;
+              if (LShr->hasOneUse())
+                ++UsesRemoved;
+              Value *NewOr = nullptr;
+              // Compute X & ((1 << Y) | 1)
+              if (auto *C = dyn_cast<Constant>(Y)) {
+                if (UsesRemoved >= 1)
+                  NewOr =
+                      ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One);
+              } else {
+                if (UsesRemoved >= 3)
+                  NewOr = Builder->CreateOr(Builder->CreateShl(One, Y,
+                                                               LShr->getName(),
+                                                               /*HasNUW=*/true),
+                                            One, Or->getName());
+              }
+              if (NewOr) {
+                Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName());
+                ICI.setOperand(0, NewAnd);
+                return &ICI;
+              }
+            }
+          }
+        }
+      }
+
       // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any
       // bit set in (X & AndCst) will produce a result greater than RHSV.
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
@@ -1461,16 +1492,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           unsigned RHSLog2 = RHSV.logBase2();
 
           // (1 << X) >= 2147483648 -> X >= 31 -> X == 31
-          // (1 << X) >  2147483648 -> X >  31 -> false
-          // (1 << X) <= 2147483648 -> X <= 31 -> true
           // (1 << X) <  2147483648 -> X <  31 -> X != 31
           if (RHSLog2 == TypeBits-1) {
             if (Pred == ICmpInst::ICMP_UGE)
               Pred = ICmpInst::ICMP_EQ;
-            else if (Pred == ICmpInst::ICMP_UGT)
-              return ReplaceInstUsesWith(ICI, Builder->getFalse());
-            else if (Pred == ICmpInst::ICMP_ULE)
-              return ReplaceInstUsesWith(ICI, Builder->getTrue());
             else if (Pred == ICmpInst::ICMP_ULT)
               Pred = ICmpInst::ICMP_NE;
           }
@@ -1505,10 +1530,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           if (RHSVIsPowerOf2)
             return new ICmpInst(
                 Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
-
-          return ReplaceInstUsesWith(
-              ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse()
-                                             : Builder->getTrue());
         }
       }
       break;
@@ -2016,8 +2037,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // sign-extended; check for that condition. For example, if CI2 is 2^31 and
   // the operands of the add are 64 bits wide, we need at least 33 sign bits.
   unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
-  if (IC.ComputeNumSignBits(A) < NeededSignBits ||
-      IC.ComputeNumSignBits(B) < NeededSignBits)
+  if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
+      IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
     return nullptr;
 
   // In order to replace the original add with a narrower
@@ -2425,7 +2446,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Changed = true;
   }
 
-  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL))
+  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // comparing -val or val with non-zero is the same as just comparing val
@@ -3205,7 +3226,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     // and       (A & ~B) != 0 --> (A & B) == 0
     // if A is a power of 2.
     if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
-        match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A) && I.isEquality())
+        match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A, false,
+                                                       0, AT, &I, DT) &&
+                                I.isEquality())
       return new ICmpInst(I.getInversePredicate(),
                           Builder->CreateAnd(A, B),
                           Op1);
@@ -3595,7 +3618,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL))
+  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'