InstCombine: Don't assume that m_ZExt matches an Instruction
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index 00623b1cbf6d2487882f125f1133c44390cfa9f2..f7eb16cbb96dd5f985f0103db2db646176e2d4b8 100644 (file)
@@ -1052,66 +1052,83 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
   APInt AP1 = CI1->getValue();
   APInt AP2 = CI2->getValue();
 
-  if (!AP1) {
-    if (!AP2) {
-      // Both Constants are 0.
-      return getConstant(true);
-    }
-
-    if (cast<BinaryOperator>(Op)->isExact())
-      return getConstant(false);
-
-    if (AP2.isNegative()) {
-      // MSB is set, so a lshr with a large enough 'A' would be undefined.
-      return getConstant(false);
-    }
+  // Don't bother doing any work for cases which InstSimplify handles.
+  if (AP2 == 0)
+    return nullptr;
+  bool IsAShr = isa<AShrOperator>(Op);
+  if (IsAShr) {
+    if (AP2.isAllOnesValue())
+      return nullptr;
+    if (AP2.isNegative() != AP1.isNegative())
+      return nullptr;
+    if (AP2.sgt(AP1))
+      return nullptr;
+  }
 
+  if (!AP1)
     // 'A' must be large enough to shift out the highest set bit.
     return getICmp(I.ICMP_UGT, A,
                    ConstantInt::get(A->getType(), AP2.logBase2()));
-  }
-
-  if (!AP2) {
-    // Shifting 0 by any value gives 0.
-    return getConstant(false);
-  }
 
-  bool IsAShr = isa<AShrOperator>(Op);
-  if (AP1 == AP2) {
-    if (AP1.isAllOnesValue() && IsAShr) {
-      // Arithmatic shift of -1 is always -1.
-      return getConstant(true);
-    }
+  if (AP1 == AP2)
     return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
-  }
-
-  bool IsNegative = false;
-  if (IsAShr) {
-    if (AP1.isNegative() != AP2.isNegative()) {
-      // Arithmetic shift will never change the sign.
-      return getConstant(false);
-    }
-    // Both the constants are negative, take their positive to calculate log.
-    if (AP1.isNegative()) {
-      if (AP1.slt(AP2))
-        // Right-shifting won't increase the magnitude.
-        return getConstant(false);
-      IsNegative = true;
-    }
-  }
-
-  if (!IsNegative && AP1.ugt(AP2))
-    // Right-shifting will not increase the value.
-    return getConstant(false);
 
   // Get the distance between the highest bit that's set.
   int Shift;
-  if (IsNegative)
-    Shift = (-AP2).logBase2() - (-AP1).logBase2();
+  // Both the constants are negative, take their positive to calculate log.
+  if (IsAShr && AP1.isNegative())
+    // Get the ones' complement of AP2 and AP1 when computing the distance.
+    Shift = (~AP2).logBase2() - (~AP1).logBase2();
   else
     Shift = AP2.logBase2() - AP1.logBase2();
 
-  if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift))
+  if (Shift > 0) {
+    if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift))
+      return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+  }
+  // Shifting const2 will never be equal to const1.
+  return getConstant(false);
+}
+
+/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" ->
+/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)).
+Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A,
+                                             ConstantInt *CI1,
+                                             ConstantInt *CI2) {
+  assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+  auto getConstant = [&I, this](bool IsTrue) {
+    if (I.getPredicate() == I.ICMP_NE)
+      IsTrue = !IsTrue;
+    return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+  };
+
+  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+    if (I.getPredicate() == I.ICMP_NE)
+      Pred = CmpInst::getInversePredicate(Pred);
+    return new ICmpInst(Pred, LHS, RHS);
+  };
+
+  APInt AP1 = CI1->getValue();
+  APInt AP2 = CI2->getValue();
+
+  // Don't bother doing any work for cases which InstSimplify handles.
+  if (AP2 == 0)
+    return nullptr;
+
+  unsigned AP2TrailingZeros = AP2.countTrailingZeros();
+
+  if (!AP1 && AP2TrailingZeros != 0)
+    return getICmp(I.ICMP_UGE, A,
+                   ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros));
+
+  if (AP1 == AP2)
+    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+
+  // Get the distance between the lowest bits that are set.
+  int Shift = AP1.countTrailingZeros() - AP2TrailingZeros;
+
+  if (Shift > 0 && AP2.shl(Shift) == AP1)
     return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
 
   // Shifting const2 will never be equal to const1.
@@ -2143,8 +2160,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   Instruction *MulInstr = cast<Instruction>(MulVal);
   assert(MulInstr->getOpcode() == Instruction::Mul);
 
-  Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
-              *RHS = cast<Instruction>(MulInstr->getOperand(1));
+  auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
+       *RHS = cast<ZExtOperator>(MulInstr->getOperand(1));
   assert(LHS->getOpcode() == Instruction::ZExt);
   assert(RHS->getOpcode() == Instruction::ZExt);
   Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
@@ -2574,12 +2591,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                           Builder->getInt(CI->getValue()-1));
     }
 
-    // (icmp eq/ne (ashr/lshr const2, A), const1)
     if (I.isEquality()) {
       ConstantInt *CI2;
       if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
           match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
-        return FoldICmpCstShrCst(I, Op0, A, CI, CI2);
+        // (icmp eq/ne (ashr/lshr const2, A), const1)
+        if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2))
+          return Inst;
+      }
+      if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (shl const2, A), const1)
+        if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2))
+          return Inst;
       }
     }
 
@@ -2992,6 +3015,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     if (BO1 && BO1->getOpcode() == Instruction::Add)
       C = BO1->getOperand(0), D = BO1->getOperand(1);
 
+    // icmp (X+cst) < 0 --> X < -cst
+    if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero()))
+      if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B))
+        if (!RHSC->isMinValue(/*isSigned=*/true))
+          return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC));
+
     // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
     if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
       return new ICmpInst(Pred, A == Op1 ? B : A,