if (ICmpInst::isSigned(Cond))
return nullptr;
- // Look through bitcasts.
- if (BitCastInst *BCI = dyn_cast<BitCastInst>(RHS))
- RHS = BCI->getOperand(0);
+ // Look through bitcasts and addrspacecasts. We do not however want to remove
+ // 0 GEPs.
+ if (!isa<GetElementPtrInst>(RHS))
+ RHS = RHS->stripPointerCasts();
Value *PtrBase = GEPLHS->getOperand(0);
if (DL && PtrBase == RHS && GEPLHS->isInBounds()) {
(GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
PtrBase->stripPointerCasts() ==
GEPRHS->getOperand(0)->stripPointerCasts()) {
+ Value *LOffset = EmitGEPOffset(GEPLHS);
+ Value *ROffset = EmitGEPOffset(GEPRHS);
+
+ // If we looked through an addrspacecast between different sized address
+ // spaces, the LHS and RHS pointers are different sized
+ // integers. Truncate to the smaller one.
+ Type *LHSIndexTy = LOffset->getType();
+ Type *RHSIndexTy = ROffset->getType();
+ if (LHSIndexTy != RHSIndexTy) {
+ if (LHSIndexTy->getPrimitiveSizeInBits() <
+ RHSIndexTy->getPrimitiveSizeInBits()) {
+ ROffset = Builder->CreateTrunc(ROffset, LHSIndexTy);
+ } else
+ LOffset = Builder->CreateTrunc(LOffset, RHSIndexTy);
+ }
+
Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond),
- EmitGEPOffset(GEPLHS),
- EmitGEPOffset(GEPRHS));
+ LOffset, ROffset);
return ReplaceInstUsesWith(I, Cmp);
}
}
// If one of the GEPs has all zero indices, recurse.
- bool AllZeros = true;
- for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i)
- if (!isa<Constant>(GEPLHS->getOperand(i)) ||
- !cast<Constant>(GEPLHS->getOperand(i))->isNullValue()) {
- AllZeros = false;
- break;
- }
- if (AllZeros)
+ if (GEPLHS->hasAllZeroIndices())
return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
ICmpInst::getSwappedPredicate(Cond), I);
// If the other GEP has all zero indices, recurse.
- AllZeros = true;
- for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
- if (!isa<Constant>(GEPRHS->getOperand(i)) ||
- !cast<Constant>(GEPRHS->getOperand(i))->isNullValue()) {
- AllZeros = false;
- break;
- }
- if (AllZeros)
+ if (GEPRHS->hasAllZeroIndices())
return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
Value *X, ConstantInt *CI,
ICmpInst::Predicate Pred) {
- // If we have X+0, exit early (simplifying logic below) and let it get folded
- // elsewhere. icmp X+0, X -> icmp X, X
- if (CI->isZero()) {
- bool isTrue = ICmpInst::isTrueWhenEqual(Pred);
- return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue));
- }
-
- // (X+4) == X -> false.
- if (Pred == ICmpInst::ICMP_EQ)
- return ReplaceInstUsesWith(ICI, Builder->getFalse());
-
- // (X+4) != X -> true.
- if (Pred == ICmpInst::ICMP_NE)
- return ReplaceInstUsesWith(ICI, Builder->getTrue());
-
// From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
// so the values can never be equal. Similarly for all other "or equals"
// operators.
return nullptr;
}
+/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" ->
+/// (icmp eq/ne A, Log2(const2/const1)) ->
+/// (icmp eq/ne A, Log2(const2) - Log2(const1)).
+Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
+ ConstantInt *CI1,
+ ConstantInt *CI2) {
+ assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+ auto getConstant = [&I, this](bool IsTrue) {
+ if (I.getPredicate() == I.ICMP_NE)
+ IsTrue = !IsTrue;
+ return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+ };
+
+ auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+ if (I.getPredicate() == I.ICMP_NE)
+ Pred = CmpInst::getInversePredicate(Pred);
+ return new ICmpInst(Pred, LHS, RHS);
+ };
+
+ APInt AP1 = CI1->getValue();
+ APInt AP2 = CI2->getValue();
+
+ // Don't bother doing any work for cases which InstSimplify handles.
+ if (AP2 == 0)
+ return nullptr;
+ bool IsAShr = isa<AShrOperator>(Op);
+ if (IsAShr) {
+ if (AP2.isAllOnesValue())
+ return nullptr;
+ if (AP2.isNegative() != AP1.isNegative())
+ return nullptr;
+ if (AP2.sgt(AP1))
+ return nullptr;
+ }
+
+ if (!AP1)
+ // 'A' must be large enough to shift out the highest set bit.
+ return getICmp(I.ICMP_UGT, A,
+ ConstantInt::get(A->getType(), AP2.logBase2()));
+
+ if (AP1 == AP2)
+ return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+
+ // Get the distance between the highest bit that's set.
+ int Shift;
+ // Both the constants are negative, take their positive to calculate log.
+ if (IsAShr && AP1.isNegative())
+ // Get the ones' complement of AP2 and AP1 when computing the distance.
+ Shift = (~AP2).logBase2() - (~AP1).logBase2();
+ else
+ Shift = AP2.logBase2() - AP1.logBase2();
+
+ if (Shift > 0) {
+ if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift))
+ return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+ }
+ // Shifting const2 will never be equal to const1.
+ return getConstant(false);
+}
+
+/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" ->
+/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)).
+Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A,
+ ConstantInt *CI1,
+ ConstantInt *CI2) {
+ assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+ auto getConstant = [&I, this](bool IsTrue) {
+ if (I.getPredicate() == I.ICMP_NE)
+ IsTrue = !IsTrue;
+ return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+ };
+
+ auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+ if (I.getPredicate() == I.ICMP_NE)
+ Pred = CmpInst::getInversePredicate(Pred);
+ return new ICmpInst(Pred, LHS, RHS);
+ };
+
+ APInt AP1 = CI1->getValue();
+ APInt AP2 = CI2->getValue();
+
+ // Don't bother doing any work for cases which InstSimplify handles.
+ if (AP2 == 0)
+ return nullptr;
+
+ unsigned AP2TrailingZeros = AP2.countTrailingZeros();
+
+ if (!AP1 && AP2TrailingZeros != 0)
+ return getICmp(I.ICMP_UGE, A,
+ ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros));
+
+ if (AP1 == AP2)
+ return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+
+ // Get the distance between the lowest bits that are set.
+ int Shift = AP1.countTrailingZeros() - AP2TrailingZeros;
+
+ if (Shift > 0 && AP2.shl(Shift) == AP1)
+ return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+
+ // Shifting const2 will never be equal to const1.
+ return getConstant(false);
+}
/// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)".
///
unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(),
SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0);
- computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne);
+ computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI);
// If all the high bits are known, we can do this xform.
if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) {
return &ICI;
}
+ // (icmp pred (and (or (lshr X, Y), X), 1), 0) -->
+ // (icmp pred (and X, (or (shl 1, Y), 1), 0))
+ //
+ // iff pred isn't signed
+ {
+ Value *X, *Y, *LShr;
+ if (!ICI.isSigned() && RHSV == 0) {
+ if (match(LHSI->getOperand(1), m_One())) {
+ Constant *One = cast<Constant>(LHSI->getOperand(1));
+ Value *Or = LHSI->getOperand(0);
+ if (match(Or, m_Or(m_Value(LShr), m_Value(X))) &&
+ match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) {
+ unsigned UsesRemoved = 0;
+ if (LHSI->hasOneUse())
+ ++UsesRemoved;
+ if (Or->hasOneUse())
+ ++UsesRemoved;
+ if (LShr->hasOneUse())
+ ++UsesRemoved;
+ Value *NewOr = nullptr;
+ // Compute X & ((1 << Y) | 1)
+ if (auto *C = dyn_cast<Constant>(Y)) {
+ if (UsesRemoved >= 1)
+ NewOr =
+ ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One);
+ } else {
+ if (UsesRemoved >= 3)
+ NewOr = Builder->CreateOr(Builder->CreateShl(One, Y,
+ LShr->getName(),
+ /*HasNUW=*/true),
+ One, Or->getName());
+ }
+ if (NewOr) {
+ Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName());
+ ICI.setOperand(0, NewAnd);
+ return &ICI;
+ }
+ }
+ }
+ }
+ }
+
// Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any
// bit set in (X & AndCst) will produce a result greater than RHSV.
if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
unsigned RHSLog2 = RHSV.logBase2();
// (1 << X) >= 2147483648 -> X >= 31 -> X == 31
- // (1 << X) > 2147483648 -> X > 31 -> false
- // (1 << X) <= 2147483648 -> X <= 31 -> true
// (1 << X) < 2147483648 -> X < 31 -> X != 31
if (RHSLog2 == TypeBits-1) {
if (Pred == ICmpInst::ICMP_UGE)
Pred = ICmpInst::ICMP_EQ;
- else if (Pred == ICmpInst::ICMP_UGT)
- return ReplaceInstUsesWith(ICI, Builder->getFalse());
- else if (Pred == ICmpInst::ICMP_ULE)
- return ReplaceInstUsesWith(ICI, Builder->getTrue());
else if (Pred == ICmpInst::ICMP_ULT)
Pred = ICmpInst::ICMP_NE;
}
if (RHSVIsPowerOf2)
return new ICmpInst(
Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
-
- return ReplaceInstUsesWith(
- ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse()
- : Builder->getTrue());
}
}
break;
// sign-extended; check for that condition. For example, if CI2 is 2^31 and
// the operands of the add are 64 bits wide, we need at least 33 sign bits.
unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
- if (IC.ComputeNumSignBits(A) < NeededSignBits ||
- IC.ComputeNumSignBits(B) < NeededSignBits)
+ if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
+ IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
return nullptr;
// In order to replace the original add with a narrower
/// replacement required.
static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
Value *OtherVal, InstCombiner &IC) {
+ // Don't bother doing this transformation for pointers, don't do it for
+ // vectors.
+ if (!isa<IntegerType>(MulVal->getType()))
+ return nullptr;
+
assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
- assert(isa<IntegerType>(MulVal->getType()));
Instruction *MulInstr = cast<Instruction>(MulVal);
assert(MulInstr->getOpcode() == Instruction::Mul);
- Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
- *RHS = cast<Instruction>(MulInstr->getOperand(1));
+ auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
+ *RHS = cast<ZExtOperator>(MulInstr->getOperand(1));
assert(LHS->getOpcode() == Instruction::ZExt);
assert(RHS->getOpcode() == Instruction::ZExt);
Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
Changed = true;
}
- if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL))
+ if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT))
return ReplaceInstUsesWith(I, V);
// comparing -val or val with non-zero is the same as just comparing val
Builder->getInt(CI->getValue()-1));
}
+ if (I.isEquality()) {
+ ConstantInt *CI2;
+ if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
+ match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
+ // (icmp eq/ne (ashr/lshr const2, A), const1)
+ if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2))
+ return Inst;
+ }
+ if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
+ // (icmp eq/ne (shl const2, A), const1)
+ if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2))
+ return Inst;
+ }
+ }
+
// If this comparison is a normal comparison, it demands all
// bits, if it is a sign bit comparison, it only demands the sign bit.
bool UnusedBit;
// bit is set. If the comparison is against zero, then this is a check
// to see if *that* bit is set.
APInt Op0KnownZeroInverted = ~Op0KnownZero;
- if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) {
+ if (~Op1KnownZero == 0) {
// If the LHS is an AND with the same constant, look through it.
Value *LHS = nullptr;
ConstantInt *LHSC = nullptr;
// If the LHS is 1 << x, and we know the result is a power of 2 like 8,
// then turn "((1 << x)&8) == 0" into "x != 3".
+ // or turn "((1 << x)&7) == 0" into "x > 2".
Value *X = nullptr;
if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
- unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros();
- return new ICmpInst(ICmpInst::ICMP_NE, X,
- ConstantInt::get(X->getType(), CmpVal));
+ APInt ValToCheck = Op0KnownZeroInverted;
+ if (ValToCheck.isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros();
+ return new ICmpInst(ICmpInst::ICMP_NE, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ } else if ((++ValToCheck).isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros() - 1;
+ return new ICmpInst(ICmpInst::ICMP_UGT, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ }
}
// If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
// bit is set. If the comparison is against zero, then this is a check
// to see if *that* bit is set.
APInt Op0KnownZeroInverted = ~Op0KnownZero;
- if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) {
+ if (~Op1KnownZero == 0) {
// If the LHS is an AND with the same constant, look through it.
Value *LHS = nullptr;
ConstantInt *LHSC = nullptr;
// If the LHS is 1 << x, and we know the result is a power of 2 like 8,
// then turn "((1 << x)&8) != 0" into "x == 3".
+ // or turn "((1 << x)&7) != 0" into "x < 3".
Value *X = nullptr;
if (match(LHS, m_Shl(m_One(), m_Value(X)))) {
- unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros();
- return new ICmpInst(ICmpInst::ICMP_EQ, X,
- ConstantInt::get(X->getType(), CmpVal));
+ APInt ValToCheck = Op0KnownZeroInverted;
+ if (ValToCheck.isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros();
+ return new ICmpInst(ICmpInst::ICMP_EQ, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ } else if ((++ValToCheck).isPowerOf2()) {
+ unsigned CmpVal = ValToCheck.countTrailingZeros();
+ return new ICmpInst(ICmpInst::ICMP_ULT, X,
+ ConstantInt::get(X->getType(), CmpVal));
+ }
}
// If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
if (BO1 && BO1->getOpcode() == Instruction::Add)
C = BO1->getOperand(0), D = BO1->getOperand(1);
+ // icmp (X+cst) < 0 --> X < -cst
+ if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero()))
+ if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B))
+ if (!RHSC->isMinValue(/*isSigned=*/true))
+ return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC));
+
// icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
return new ICmpInst(Pred, A == Op1 ? B : A,
// and (A & ~B) != 0 --> (A & B) == 0
// if A is a power of 2.
if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
- match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A) && I.isEquality())
+ match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A, false,
+ 0, AT, &I, DT) &&
+ I.isEquality())
return new ICmpInst(I.getInversePredicate(),
Builder->CreateAnd(A, B),
Op1);
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL))
+ if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT))
return ReplaceInstUsesWith(I, V);
// Simplify 'fcmp pred X, X'