X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineCompares.cpp;h=c2b9e71c6f285bf6e28f0192a02a8d77ad2c1ba1;hb=107ffd58f70651cc928416f6aaf267accc992c23;hp=861cf92d281e99e575e61764244147d710af8780;hpb=9ee17208115482441953127615231c59a2f4d052;p=oota-llvm.git diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 861cf92d281..c2b9e71c6f2 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -22,13 +22,17 @@ using namespace llvm; using namespace PatternMatch; +static ConstantInt *getOne(Constant *C) { + return ConstantInt::get(cast(C->getType()), 1); +} + /// AddOne - Add one to a ConstantInt static Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } /// SubOne - Subtract one from a ConstantInt -static Constant *SubOne(ConstantInt *C) { - return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); +static Constant *SubOne(Constant *C) { + return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); } static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { @@ -160,8 +164,8 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero, Max = KnownOne|UnknownBits; if (UnknownBits.isNegative()) { // Sign bit is unknown - Min.set(Min.getBitWidth()-1); - Max.clear(Max.getBitWidth()-1); + Min.setBit(Min.getBitWidth()-1); + Max.clearBit(Max.getBitWidth()-1); } } @@ -465,8 +469,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, /// /// If we can't emit an optimized form for this expression, this returns null. /// -static Value *EvaluateGEPOffsetExpression(User *GEP, Instruction &I, - InstCombiner &IC) { +static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { TargetData &TD = *IC.getTargetData(); gep_type_iterator GTI = gep_type_begin(GEP); @@ -529,10 +532,10 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, Instruction &I, // Cast to intptrty in case a truncation occurs. If an extension is needed, // we don't need to bother extending: the extension won't affect where the // computation crosses zero. - if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) - VariableIdx = new TruncInst(VariableIdx, - TD.getIntPtrType(VariableIdx->getContext()), - VariableIdx->getName(), &I); + if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) { + const Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext()); + VariableIdx = IC.Builder->CreateTrunc(VariableIdx, IntPtrTy); + } return VariableIdx; } @@ -554,11 +557,10 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, Instruction &I, // Okay, we can do this evaluation. Start by converting the index to intptr. 
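For orientation, a sketch of what FoldGEPICmp does with the value computed here: when an inbounds GEP is compared against its own base pointer, the offset expression returned by this helper replaces the pointer compare. Value names and types below are illustrative, not taken from the patch.

//   icmp eq (getelementptr inbounds i32* %p, i64 %i), %p
//     -->  icmp eq i64 %i, 0
// (one variable index, constant offset 0; scaling by the element size cannot
// change where the offset crosses zero, so the raw index is compared to 0)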
const Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext()); if (VariableIdx->getType() != IntPtrTy) - VariableIdx = CastInst::CreateIntegerCast(VariableIdx, IntPtrTy, - true /*SExt*/, - VariableIdx->getName(), &I); + VariableIdx = IC.Builder->CreateIntCast(VariableIdx, IntPtrTy, + true /*Signed*/); Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs); - return BinaryOperator::CreateAdd(VariableIdx, OffsetVal, "offset", &I); + return IC.Builder->CreateAdd(VariableIdx, OffsetVal, "offset"); } /// FoldGEPICmp - Fold comparisons between a GEP instruction and something @@ -576,7 +578,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can // output an optimized form. - Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, I, *this); + Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this); // If not, synthesize the offset the hard way. if (Offset == 0) @@ -630,6 +632,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, if (AllZeros) return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { // If the GEPs only differ by one index, compare it. unsigned NumDifferences = 0; // Keep track of # differences. @@ -652,7 +655,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ConstantInt::get(Type::getInt1Ty(I.getContext()), ICmpInst::isTrueWhenEqual(Cond))); - else if (NumDifferences == 1) { + else if (NumDifferences == 1 && GEPsInBounds) { Value *LHSV = GEPLHS->getOperand(DiffOperand); Value *RHSV = GEPRHS->getOperand(DiffOperand); // Make sure we do a signed comparison here. @@ -663,6 +666,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! if (TD && + GEPsInBounds && (isa(GEPLHS) || GEPLHS->hasOneUse()) && (isa(GEPRHS) || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) @@ -694,25 +698,14 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, if (Pred == ICmpInst::ICMP_NE) return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(X->getContext())); - // If this is an instruction (as opposed to constantexpr) get NUW/NSW info. - bool isNUW = false, isNSW = false; - if (BinaryOperator *Add = dyn_cast(TheAdd)) { - isNUW = Add->hasNoUnsignedWrap(); - isNSW = Add->hasNoSignedWrap(); - } - // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, - // so the values can never be equal. Similiarly for all other "or equals" + // so the values can never be equal. Similarly for all other "or equals" // operators. // (X+1) X >u (MAXUINT-1) --> X == 255 // (X+2) X >u (MAXUINT-2) --> X > 253 // (X+MAXUINT) X >u (MAXUINT-MAXUINT) --> X != 0 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { - // If this is an NUW add, then this is always false. 
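A brute-force check of the three unsigned rules listed above, written as standalone C++ with no LLVM dependency (purely illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned i = 0; i < 256; ++i) {
    uint8_t x = static_cast<uint8_t>(i);
    // (X+1) <u X        -->  X >u (MAXUINT-1)  -->  X == 255
    assert((static_cast<uint8_t>(x + 1) < x) == (x == 255));
    // (X+2) <u X        -->  X >u (MAXUINT-2)  -->  X >u 253
    assert((static_cast<uint8_t>(x + 2) < x) == (x > 253));
    // (X+MAXUINT) <u X  -->  X >u 0            -->  X != 0
    assert((static_cast<uint8_t>(x + 255) < x) == (x != 0));
  }
  return 0;
}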
- if (isNUW) - return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(X->getContext())); - Value *R = ConstantExpr::getSub(ConstantInt::getAllOnesValue(CI->getType()), CI); return new ICmpInst(ICmpInst::ICMP_UGT, X, R); @@ -721,12 +714,8 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, // (X+1) >u X --> X X != 255 // (X+2) >u X --> X X u X --> X X X == 0 - if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { - // If this is an NUW add, then this is always true. - if (isNUW) - return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(X->getContext())); + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantExpr::getNeg(CI)); - } unsigned BitWidth = CI->getType()->getPrimitiveSizeInBits(); ConstantInt *SMax = ConstantInt::get(X->getContext(), @@ -738,16 +727,8 @@ Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, // (X+MINSINT) X >s (MAXSINT-MINSINT) --> X >s -1 // (X+ -2) X >s (MAXSINT- -2) --> X >s 126 // (X+ -1) X >s (MAXSINT- -1) --> X != 127 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) { - // If this is an NSW add, then we have two cases: if the constant is - // positive, then this is always false, if negative, this is always true. - if (isNSW) { - bool isTrue = CI->getValue().isNegative(); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantExpr::getSub(SMax, CI)); - } // (X+ 1) >s X --> X X != 127 // (X+ 2) >s X --> X X s X --> X X s X --> X X == -128 - // If this is an NSW add, then we have two cases: if the constant is - // positive, then this is always true, if negative, this is always false. - if (isNSW) { - bool isTrue = !CI->getValue().isNegative(); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE); Constant *C = ConstantInt::get(X->getContext(), CI->getValue()-1); return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); @@ -782,7 +756,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // results than (x /s C1) getOpcode() == Instruction::SDiv; if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) return 0; @@ -790,9 +764,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, return 0; // The ProdOV computation fails on divide by zero. if (DivIsSigned && DivRHS->isAllOnesValue()) return 0; // The overflow computation also screws up here - if (DivRHS->isOne()) - return 0; // Not worth bothering, and eliminates some funny cases - // with INT_MIN. + if (DivRHS->isOne()) { + // This eliminates some funny cases with INT_MIN. + ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. + return &ICI; + } // Compute Prod = CI * DivRHS. We are essentially solving an equation // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and @@ -809,6 +785,10 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // Get the ICmp opcode ICmpInst::Predicate Pred = ICI.getPredicate(); + /// If the division is known to be exact, then there is no remainder from the + /// divide, so the covered range size is unit, otherwise it is the divisor. + ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; + // Figure out the interval that is being checked. 
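An illustrative i8 instance of the interval logic, showing what the new RangeSize does for exact division (constants are examples only):

//   X /u 5 == 3            checks  X in [15, 20)   (RangeSize = DivRHS = 5)
//   X /u 5 == 3, exact     checks  X == 15         (RangeSize = 1: an exact
//                                                    divide has no remainder,
//                                                    so one input per result)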
For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). // Compute this interval based on the constants involved and the signedness of @@ -818,38 +798,43 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; Constant *LoBound = 0, *HiBound = 0; - + if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) LoBound = Prod; HiOverflow = LoOverflow = ProdOV; - if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false); + if (!HiOverflow) { + // If this is not an exact divide, then many values in the range collapse + // to the same result value. + HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false); + } + } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0. if (CmpRHSV == 0) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) - LoBound = cast(ConstantExpr::getNeg(SubOne(DivRHS))); - HiBound = DivRHS; + LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); + HiBound = RangeSize; } else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true); + HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = AddOne(Prod); LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - ConstantInt* DivNeg = - cast(ConstantExpr::getNeg(DivRHS)); + ConstantInt *DivNeg =cast(ConstantExpr::getNeg(RangeSize)); LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; - } + } } } else if (DivRHS->getValue().isNegative()) { // Divisor is < 0. + if (DivI->isExact()) + RangeSize = cast(ConstantExpr::getNeg(RangeSize)); if (CmpRHSV == 0) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) - LoBound = AddOne(DivRHS); - HiBound = cast(ConstantExpr::getNeg(DivRHS)); + LoBound = AddOne(RangeSize); + HiBound = cast(ConstantExpr::getNeg(RangeSize)); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN @@ -859,12 +844,12 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, HiBound = AddOne(Prod); HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0; + LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) - HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true); + HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true); } // Dividing by a negative swaps the condition. LT <-> GT @@ -883,9 +868,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, if (LoOverflow) return new ICmpInst(DivIsSigned ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, HiBound); - return ReplaceInstUsesWith(ICI, - InsertRangeTest(X, LoBound, HiBound, DivIsSigned, - true)); + return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); @@ -908,15 +892,102 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext())); - else if (HiOverflow == -1) // High bound less than input range. + if (HiOverflow == -1) // High bound less than input range. return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); if (Pred == ICmpInst::ICMP_UGT) return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - else - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } } +/// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)". +Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, + ConstantInt *ShAmt) { + const APInt &CmpRHSV = cast(ICI.getOperand(1))->getValue(); + + // Check that the shift amount is in range. If not, don't perform + // undefined shifts. When the shift is visited it will be + // simplified. + uint32_t TypeBits = CmpRHSV.getBitWidth(); + uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits || ShAmtVal == 0) + return 0; + + if (!ICI.isEquality()) { + // If we have an unsigned comparison and an ashr, we can't simplify this. + // Similarly for signed comparisons with lshr. + if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) + return 0; + + // Otherwise, all lshr and all exact ashr's are equivalent to a udiv/sdiv by + // a power of 2. Since we already have logic to simplify these, transform + // to div and then simplify the resultant comparison. + if (Shr->getOpcode() == Instruction::AShr && + !Shr->isExact()) + return 0; + + // Revisit the shift (to delete it). + Worklist.Add(Shr); + + Constant *DivCst = + ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + + Value *Tmp = + Shr->getOpcode() == Instruction::AShr ? + Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) : + Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()); + + ICI.setOperand(0, Tmp); + + // If the builder folded the binop, just return it. + BinaryOperator *TheDiv = dyn_cast(Tmp); + if (TheDiv == 0) + return &ICI; + + // Otherwise, fold this div/compare. + assert(TheDiv->getOpcode() == Instruction::SDiv || + TheDiv->getOpcode() == Instruction::UDiv); + + Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast(DivCst)); + assert(Res && "This div/cst should have folded!"); + return Res; + } + + + // If we are comparing against bits always shifted out, the + // comparison cannot succeed. + APInt Comp = CmpRHSV << ShAmtVal; + ConstantInt *ShiftedCmpRHS = ConstantInt::get(ICI.getContext(), Comp); + if (Shr->getOpcode() == Instruction::LShr) + Comp = Comp.lshr(ShAmtVal); + else + Comp = Comp.ashr(ShAmtVal); + + if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. 
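A sketch of the rewrite this performs for a non-equality compare of a logical shift (names and constants illustrative):

//   icmp ult (lshr i32 %x, 2), 10
//     -->  icmp ult (udiv i32 %x, 4), 10     (lshr by 2 is udiv by 4)
//     -->  icmp ult i32 %x, 40               (folded further by FoldICmpDivCst)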
+ bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; + Constant *Cst = ConstantInt::get(Type::getInt1Ty(ICI.getContext()), + IsICMP_NE); + return ReplaceInstUsesWith(ICI, Cst); + } + + // Otherwise, check to see if the bits shifted out are known to be zero. + // If so, we can compare against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + if (Shr->hasOneUse() && Shr->isExact()) + return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS); + + if (Shr->hasOneUse()) { + // Otherwise strength reduce the shift into an and. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(ICI.getContext(), Val); + + Value *And = Builder->CreateAnd(Shr->getOperand(0), + Mask, Shr->getName()+".mask"); + return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); + } + return 0; +} + /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -939,8 +1010,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // If all the high bits are known, we can do this xform. if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { // Pull in the high bits from known-ones set. - APInt NewRHS(RHS->getValue()); - NewRHS.zext(SrcBits); + APInt NewRHS = RHS->getValue().zext(SrcBits); NewRHS |= KnownOne; return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), ConstantInt::get(ICI.getContext(), NewRHS)); @@ -1022,10 +1092,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, (AndCST->getValue().isNonNegative() && RHSV.isNonNegative()))) { uint32_t BitWidth = cast(Cast->getOperand(0)->getType())->getBitWidth(); - APInt NewCST = AndCST->getValue(); - NewCST.zext(BitWidth); - APInt NewCI = RHSV; - NewCI.zext(BitWidth); + APInt NewCST = AndCST->getValue().zext(BitWidth); + APInt NewCI = RHSV.zext(BitWidth); Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), ConstantInt::get(ICI.getContext(), NewCST), @@ -1145,7 +1213,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, Constant::getNullValue(P->getType())); Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, @@ -1185,6 +1252,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return ReplaceInstUsesWith(ICI, Cst); } + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + if (cast(LHSI)->hasNoUnsignedWrap()) + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantExpr::getLShr(RHS, ShAmt)); + if (LHSI->hasOneUse()) { // Otherwise strength reduce the shift into an and. 
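Two illustrative i8 instances of the nearby shl handling: the first uses the new no-unsigned-wrap shortcut, the second the strength reduction to an and (assuming a single-use shift):

//   icmp eq (shl nuw i8 %x, 3), 40  -->  icmp eq i8 %x, 5
//   icmp eq (shl i8 %x, 3), 40      -->  icmp eq (and i8 %x, 31), 5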
uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); @@ -1195,8 +1268,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Value *And = Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); return new ICmpInst(ICI.getPredicate(), And, - ConstantInt::get(ICI.getContext(), - RHSV.lshr(ShAmtVal))); + ConstantExpr::getLShr(RHS, ShAmt)); } } @@ -1205,8 +1277,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (LHSI->hasOneUse() && isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { // (X << 31) (X&1) != 0 - Constant *Mask = ConstantInt::get(ICI.getContext(), APInt(TypeBits, 1) << - (TypeBits-ShAmt->getZExtValue()-1)); + Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(), + APInt::getOneBitSet(TypeBits, + TypeBits-ShAmt->getZExtValue()-1)); Value *And = Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask"); return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, @@ -1217,53 +1290,17 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) case Instruction::AShr: { - // Only handle equality comparisons of shift-by-constant. - ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1)); - if (!ShAmt || !ICI.isEquality()) break; - - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = RHSV.getBitWidth(); - if (ShAmt->uge(TypeBits)) - break; - - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = RHSV << ShAmtVal; - if (LHSI->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); - - if (Comp != RHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = ConstantInt::get(Type::getInt1Ty(ICI.getContext()), - IsICMP_NE); - return ReplaceInstUsesWith(ICI, Cst); - } - - // Otherwise, check to see if the bits shifted out are known to be zero. - // If so, we can compare against the unshifted value: - // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (LHSI->hasOneUse() && - MaskedValueIsZero(LHSI->getOperand(0), - APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) { - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getShl(RHS, ShAmt)); + // Handle equality comparisons of shift-by-constant. + BinaryOperator *BO = cast(LHSI); + if (ConstantInt *ShAmt = dyn_cast(LHSI->getOperand(1))) { + if (Instruction *Res = FoldICmpShrCst(ICI, BO, ShAmt)) + return Res; } - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(ICI.getContext(), Val); - - Value *And = Builder->CreateAnd(LHSI->getOperand(0), - Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getShl(RHS, ShAmt)); + + // Handle exact shr's. 
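An illustrative instance of the exact-shift equality case (single-use shift, compare against zero):

//   icmp eq (lshr exact i32 %x, 3), 0  -->  icmp eq i32 %x, 0
// (an exact shift discards only zero bits, so the result is zero iff %x is)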
+ if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { + if (RHSV.isMinValue()) + return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); } break; } @@ -1347,9 +1384,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (Value *NegVal = dyn_castNegVal(BOp1)) return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); - else if (Value *NegVal = dyn_castNegVal(BOp0)) + if (Value *NegVal = dyn_castNegVal(BOp0)) return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); - else if (BO->hasOneUse()) { + if (BO->hasOneUse()) { Value *Neg = Builder->CreateNeg(BOp1); Neg->takeName(BO); return new ICmpInst(ICI.getPredicate(), BOp0, Neg); @@ -1374,7 +1411,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, case Instruction::Or: // If bits are being or'd in that are not present in the constant we // are comparing against, then the comparison could never succeed! - if (Constant *BOC = dyn_cast(BO->getOperand(1))) { + if (ConstantInt *BOC = dyn_cast(BO->getOperand(1))) { Constant *NotCI = ConstantExpr::getNot(RHS); if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) return ReplaceInstUsesWith(ICI, @@ -1423,7 +1460,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, switch (II->getIntrinsicID()) { case Intrinsic::bswap: Worklist.Add(II); - ICI.setOperand(0, II->getOperand(1)); + ICI.setOperand(0, II->getArgOperand(0)); ICI.setOperand(1, ConstantInt::get(II->getContext(), RHSV.byteSwap())); return &ICI; case Intrinsic::ctlz: @@ -1431,7 +1468,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // ctz(A) == bitwidth(a) -> A == 0 and likewise for != if (RHSV == RHS->getType()->getBitWidth()) { Worklist.Add(II); - ICI.setOperand(0, II->getOperand(1)); + ICI.setOperand(0, II->getArgOperand(0)); ICI.setOperand(1, ConstantInt::get(RHS->getType(), 0)); return &ICI; } @@ -1440,13 +1477,13 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // popcount(A) == 0 -> A == 0 and likewise for != if (RHS->isZero()) { Worklist.Add(II); - ICI.setOperand(0, II->getOperand(1)); + ICI.setOperand(0, II->getArgOperand(0)); ICI.setOperand(1, RHS); return &ICI; } break; default: - break; + break; } } } @@ -1543,50 +1580,174 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // The re-extended constant changed so the constant cannot be represented // in the shorter type. Consequently, we cannot emit a simple comparison. + // All the cases that fold to true or false will have already been handled + // by SimplifyICmpInst, so only deal with the tricky case. - // First, handle some easy cases. We know the result cannot be equal at this - // point so handle the ICI.isEquality() cases - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext())); - if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); + if (isSignedCmp || !isSignedExt) + return 0; // Evaluate the comparison for LT (we invert for GT below). LE and GE cases // should have been folded away previously and not enter in here. - Value *Result; - if (isSignedCmp) { - // We're performing a signed comparison. - if (cast(CI)->getValue().isNegative()) - Result = ConstantInt::getFalse(ICI.getContext()); // X < (small) --> false - else - Result = ConstantInt::getTrue(ICI.getContext()); // X < (large) --> true - } else { - // We're performing an unsigned comparison. 
- if (isSignedExt) { - // We're performing an unsigned comp with a sign extended value. - // This is true if the input is >= 0. [aka >s -1] - Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICI.getName()); - } else { - // Unsigned extend & unsigned compare -> always true. - Result = ConstantInt::getTrue(ICI.getContext()); - } - } + + // We're performing an unsigned comp with a sign extended value. + // This is true if the input is >= 0. [aka >s -1] + Constant *NegOne = Constant::getAllOnesValue(SrcTy); + Value *Result = Builder->CreateICmpSGT(LHSCIOp, NegOne, ICI.getName()); // Finally, return the value computed. - if (ICI.getPredicate() == ICmpInst::ICMP_ULT || - ICI.getPredicate() == ICmpInst::ICMP_SLT) + if (ICI.getPredicate() == ICmpInst::ICMP_ULT) return ReplaceInstUsesWith(ICI, Result); - assert((ICI.getPredicate()==ICmpInst::ICMP_UGT || - ICI.getPredicate()==ICmpInst::ICMP_SGT) && - "ICmp should be folded!"); - if (Constant *CI = dyn_cast(Result)) - return ReplaceInstUsesWith(ICI, ConstantExpr::getNot(CI)); + assert(ICI.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); return BinaryOperator::CreateNot(Result); } +/// ProcessUGT_ADDCST_ADD - The caller has matched a pattern of the form: +/// I = icmp ugt (add (add A, B), CI2), CI1 +/// If this is of the form: +/// sum = a + b +/// if (sum+128 >u 255) +/// Then replace it with llvm.sadd.with.overflow.i8. +/// +static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, + ConstantInt *CI2, ConstantInt *CI1, + InstCombiner &IC) { + // The transformation we're trying to do here is to transform this into an + // llvm.sadd.with.overflow. To do this, we have to replace the original add + // with a narrower add, and discard the add-with-constant that is part of the + // range check (if we can't eliminate it, this isn't profitable). + + // In order to eliminate the add-with-constant, the compare can be its only + // use. + Instruction *AddWithCst = cast(I.getOperand(0)); + if (!AddWithCst->hasOneUse()) return 0; + + // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. + if (!CI2->getValue().isPowerOf2()) return 0; + unsigned NewWidth = CI2->getValue().countTrailingZeros(); + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return 0; + + // The width of the new add formed is 1 more than the bias. + ++NewWidth; + + // Check to see that CI1 is an all-ones value with NewWidth bits. + if (CI1->getBitWidth() == NewWidth || + CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) + return 0; + + // In order to replace the original add with a narrower + // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant + // and truncates that discard the high bits of the add. Verify that this is + // the case. + Instruction *OrigAdd = cast(AddWithCst->getOperand(0)); + for (Value::use_iterator UI = OrigAdd->use_begin(), E = OrigAdd->use_end(); + UI != E; ++UI) { + if (*UI == AddWithCst) continue; + + // Only accept truncates for now. We would really like a nice recursive + // predicate like SimplifyDemandedBits, but which goes downwards the use-def + // chain to see which bits of a value are actually demanded. If the + // original add had another add which was then immediately truncated, we + // could still do the transformation. 
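The idiom being matched and the intended rewrite, sketched in IR with illustrative value names; the real code emits the trunc/call/zext through the builder at the position of the original add:

//   %sum = add i32 %a, %b            ; %a, %b typically hold sign-extended i8 values
//   %chk = add i32 %sum, 128
//   %cmp = icmp ugt i32 %chk, 255    ; "did the i8 addition overflow?"
// becomes, roughly:
//   %a8   = trunc i32 %a to i8
//   %b8   = trunc i32 %b to i8
//   %sadd = call { i8, i1 } @llvm.sadd.with.overflow.i8(i8 %a8, i8 %b8)
//   %val  = extractvalue { i8, i1 } %sadd, 0
//   %wide = zext i8 %val to i32      ; replaces all uses of the original %sum
//   %cmp  = extractvalue { i8, i1 } %sadd, 1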
+ TruncInst *TI = dyn_cast(*UI); + if (TI == 0 || + TI->getType()->getPrimitiveSizeInBits() > NewWidth) return 0; + } + + // If the pattern matches, truncate the inputs to the narrower type and + // use the sadd_with_overflow intrinsic to efficiently compute both the + // result and the overflow bit. + Module *M = I.getParent()->getParent()->getParent(); + + const Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); + Value *F = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow, + &NewType, 1); + + InstCombiner::BuilderTy *Builder = IC.Builder; + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. + Builder->SetInsertPoint(OrigAdd); + + Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); + Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); + CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB, "sadd"); + Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); + Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); + + // The inner add was the result of the narrow add, zero extended to the + // wider type. Replace it with the result computed by the intrinsic. + IC.ReplaceInstUsesWith(*OrigAdd, ZExt); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "sadd.overflow"); +} +static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, + InstCombiner &IC) { + // Don't bother doing this transformation for pointers, don't do it for + // vectors. + if (!isa(OrigAddV->getType())) return 0; + + // If the add is a constant expr, then we don't bother transforming it. + Instruction *OrigAdd = dyn_cast(OrigAddV); + if (OrigAdd == 0) return 0; + + Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1); + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. + InstCombiner::BuilderTy *Builder = IC.Builder; + Builder->SetInsertPoint(OrigAdd); + + Module *M = I.getParent()->getParent()->getParent(); + const Type *Ty = LHS->getType(); + Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, &Ty,1); + CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd"); + Value *Add = Builder->CreateExtractValue(Call, 0); + + IC.ReplaceInstUsesWith(*OrigAdd, Add); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "uadd.overflow"); +} + +// DemandedBitsLHSMask - When performing a comparison against a constant, +// it is possible that not all the bits in the LHS are demanded. This helper +// method computes the mask that IS demanded. +static APInt DemandedBitsLHSMask(ICmpInst &I, + unsigned BitWidth, bool isSignCheck) { + if (isSignCheck) + return APInt::getSignBit(BitWidth); + + ConstantInt *CI = dyn_cast(I.getOperand(1)); + if (!CI) return APInt::getAllOnesValue(BitWidth); + const APInt &RHS = CI->getValue(); + + switch (I.getPredicate()) { + // For a UGT comparison, we don't care about any bits that + // correspond to the trailing ones of the comparand. The value of these + // bits doesn't impact the outcome of the comparison, because any value + // greater than the RHS must differ in a bit higher than these due to carry. 
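A standalone sanity check of the ICMP_UGT rule described above, in plain C++ with no LLVM dependency (purely illustrative):

#include <cassert>

int main() {
  // For "icmp ugt i8 %x, 7": 7 has three trailing ones, so the mask of
  // demanded bits is ~0b00000111 = 0xF8 and the low three bits never matter.
  const unsigned RHS = 7, Mask = 0xF8;
  for (unsigned x = 0; x < 256; ++x)
    assert((x > RHS) == ((x & Mask) > RHS));
  return 0;
}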
+ case ICmpInst::ICMP_UGT: { + unsigned trailingOnes = RHS.countTrailingOnes(); + APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes); + return ~lowBitsSet; + } + + // Similarly, for a ULT comparison, we don't care about the trailing zeros. + // Any value less than the RHS must differ in a higher bit because of carries. + case ICmpInst::ICMP_ULT: { + unsigned trailingZeros = RHS.countTrailingZeros(); + APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros); + return ~lowBitsSet; + } + + default: + return APInt::getAllOnesValue(BitWidth); + } + +} Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; @@ -1649,17 +1810,37 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } unsigned BitWidth = 0; - if (TD) - BitWidth = TD->getTypeSizeInBits(Ty->getScalarType()); - else if (Ty->isIntOrIntVectorTy()) + if (Ty->isIntOrIntVectorTy()) BitWidth = Ty->getScalarSizeInBits(); - + else if (TD) // Pointers require TD info to get their size. + BitWidth = TD->getTypeSizeInBits(Ty->getScalarType()); + bool isSignBit = false; // See if we are doing a comparison with a constant. if (ConstantInt *CI = dyn_cast(Op1)) { Value *A = 0, *B = 0; + // Match the following pattern, which is a common idiom when writing + // overflow-safe integer arithmetic function. The source performs an + // addition in wider type, and explicitly checks for overflow using + // comparisons against INT_MIN and INT_MAX. Simplify this by using the + // sadd_with_overflow intrinsic. + // + // TODO: This could probably be generalized to handle other overflow-safe + // operations if we worked out the formulas to compute the appropriate + // magic constants. + // + // sum = a + b + // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 + { + ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI + if (I.getPredicate() == ICmpInst::ICMP_UGT && + match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) + if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this)) + return Res; + } + // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) if (I.isEquality() && CI->isZero() && match(Op0, m_Sub(m_Value(A), m_Value(B)))) { @@ -1682,11 +1863,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return new ICmpInst(ICmpInst::ICMP_SLT, Op0, ConstantInt::get(CI->getContext(), CI->getValue()+1)); case ICmpInst::ICMP_UGE: - assert(!CI->isMinValue(false)); // A >=u MIN -> TRUE + assert(!CI->isMinValue(false)); // A >=u MIN -> TRUE return new ICmpInst(ICmpInst::ICMP_UGT, Op0, ConstantInt::get(CI->getContext(), CI->getValue()-1)); case ICmpInst::ICMP_SGE: - assert(!CI->isMinValue(true)); // A >=s MIN -> TRUE + assert(!CI->isMinValue(true)); // A >=s MIN -> TRUE return new ICmpInst(ICmpInst::ICMP_SGT, Op0, ConstantInt::get(CI->getContext(), CI->getValue()-1)); } @@ -1704,8 +1885,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); if (SimplifyDemandedBits(I.getOperandUse(0), - isSignBit ? APInt::getSignBit(BitWidth) - : APInt::getAllOnesValue(BitWidth), + DemandedBitsLHSMask(I, BitWidth, isSignBit), Op0KnownZero, Op0KnownOne, 0)) return &I; if (SimplifyDemandedBits(I.getOperandUse(1), @@ -1735,28 +1915,94 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // that code below can assume that Min != Max. 
if (!isa(Op0) && Op0Min == Op0Max) return new ICmpInst(I.getPredicate(), - ConstantInt::get(I.getContext(), Op0Min), Op1); + ConstantInt::get(Op0->getType(), Op0Min), Op1); if (!isa(Op1) && Op1Min == Op1Max) return new ICmpInst(I.getPredicate(), Op0, - ConstantInt::get(I.getContext(), Op1Min)); + ConstantInt::get(Op1->getType(), Op1Min)); // Based on the range information we know about the LHS, see if we can - // simplify this comparison. For example, (x&4) < 8 is always true. + // simplify this comparison. For example, (x&4) < 8 is always true. switch (I.getPredicate()) { default: llvm_unreachable("Unknown icmp opcode!"); - case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_EQ: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + // If all bits are known zero except for one, then we know at most one + // bit is set. If the comparison is against zero, then this is a check + // to see if *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = 0; + ConstantInt *LHSC = 0; + if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || + LHSC->getValue() != Op0KnownZeroInverted) + LHS = Op0; + + // If the LHS is 1 << x, and we know the result is a power of 2 like 8, + // then turn "((1 << x)&8) == 0" into "x != 3". + Value *X = 0; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(X->getType(), CmpVal)); + } + + // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, + // then turn "((8 >>u x)&1) == 0" into "x != 3". + const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) + return new ICmpInst(ICmpInst::ICMP_NE, X, + ConstantInt::get(X->getType(), + CI->countTrailingZeros())); + } + break; - case ICmpInst::ICMP_NE: + } + case ICmpInst::ICMP_NE: { if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + // If all bits are known zero except for one, then we know at most one + // bit is set. If the comparison is against zero, then this is a check + // to see if *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0 && Op0KnownZeroInverted.isPowerOf2()) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = 0; + ConstantInt *LHSC = 0; + if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || + LHSC->getValue() != Op0KnownZeroInverted) + LHS = Op0; + + // If the LHS is 1 << x, and we know the result is a power of 2 like 8, + // then turn "((1 << x)&8) != 0" into "x == 3". + Value *X = 0; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + unsigned CmpVal = Op0KnownZeroInverted.countTrailingZeros(); + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(X->getType(), CmpVal)); + } + + // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, + // then turn "((8 >>u x)&1) != 0" into "x == 3". 
+ const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) + return new ICmpInst(ICmpInst::ICMP_EQ, X, + ConstantInt::get(X->getType(), + CI->countTrailingZeros())); + } + break; + } case ICmpInst::ICMP_ULT: if (Op0Max.ult(Op1Min)) // A true if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.uge(Op1Max)) // A false if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); if (ConstantInt *CI = dyn_cast(Op1)) { @@ -1772,9 +2018,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_UGT: if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -1791,9 +2037,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_SLT: if (Op0Max.slt(Op1Min)) // A true if max(A) < min(C) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sge(Op1Max)) // A false if min(A) >= max(C) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); if (ConstantInt *CI = dyn_cast(Op1)) { @@ -1804,9 +2050,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { break; case ICmpInst::ICMP_SGT: if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -1819,30 +2065,30 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { case ICmpInst::ICMP_SGE: assert(!isa(Op1) && "ICMP_SGE with ConstantInt not folded!"); if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_SLE: assert(!isa(Op1) && "ICMP_SLE with ConstantInt not folded!"); if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if 
(Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_UGE: assert(!isa(Op1) && "ICMP_UGE with ConstantInt not folded!"); if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; case ICmpInst::ICMP_ULE: assert(!isa(Op1) && "ICMP_ULE with ConstantInt not folded!"); if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext())); + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); break; } @@ -1894,7 +2140,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I, true)) + if (Instruction *NV = FoldOpIntoPhi(I)) return NV; break; case Instruction::Select: { @@ -1924,35 +2170,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } break; } - case Instruction::Call: - // If we have (malloc != null), and if the malloc has a single use, we - // can assume it is successful and remove the malloc. - if (isMalloc(LHSI) && LHSI->hasOneUse() && - isa(RHSC)) { - // Need to explicitly erase malloc call here, instead of adding it to - // Worklist, because it won't get DCE'd from the Worklist since - // isInstructionTriviallyDead() returns false for function calls. - // It is OK to replace LHSI/MallocCall with Undef because the - // instruction that uses it will be erased via Worklist. - if (extractMallocCall(LHSI)) { - LHSI->replaceAllUsesWith(UndefValue::get(LHSI->getType())); - EraseInstFromFunction(*LHSI); - return ReplaceInstUsesWith(I, - ConstantInt::get(Type::getInt1Ty(I.getContext()), - !I.isTrueWhenEqual())); - } - if (CallInst* MallocCall = extractMallocCallFromBitCast(LHSI)) - if (MallocCall->hasOneUse()) { - MallocCall->replaceAllUsesWith( - UndefValue::get(MallocCall->getType())); - EraseInstFromFunction(*MallocCall); - Worklist.Add(LHSI); // The malloc's bitcast use. - return ReplaceInstUsesWith(I, - ConstantInt::get(Type::getInt1Ty(I.getContext()), - !I.isTrueWhenEqual())); - } - } - break; case Instruction::IntToPtr: // icmp pred inttoptr(X), null -> icmp pred X, 0 if (RHSC->isNullValue() && TD && @@ -2024,79 +2241,213 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *R = visitICmpInstWithCastAndCast(I)) return R; } - - // See if it's the same type of instruction on the left and right. - if (BinaryOperator *Op0I = dyn_cast(Op0)) { - if (BinaryOperator *Op1I = dyn_cast(Op1)) { - if (Op0I->getOpcode() == Op1I->getOpcode() && Op0I->hasOneUse() && - Op1I->hasOneUse() && Op0I->getOperand(1) == Op1I->getOperand(1)) { - switch (Op0I->getOpcode()) { + + // Special logic for binary operators. 
+ BinaryOperator *BO0 = dyn_cast(Op0); + BinaryOperator *BO1 = dyn_cast(Op1); + if (BO0 || BO1) { + CmpInst::Predicate Pred = I.getPredicate(); + bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; + if (BO0 && isa(BO0)) + NoOp0WrapProblem = ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); + if (BO1 && isa(BO1)) + NoOp1WrapProblem = ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); + + // Analyze the case when either Op0 or Op1 is an add instruction. + // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). + Value *A = 0, *B = 0, *C = 0, *D = 0; + if (BO0 && BO0->getOpcode() == Instruction::Add) + A = BO0->getOperand(0), B = BO0->getOperand(1); + if (BO1 && BO1->getOpcode() == Instruction::Add) + C = BO1->getOperand(0), D = BO1->getOperand(1); + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == Op1 || B == Op1) && NoOp0WrapProblem) + return new ICmpInst(Pred, A == Op1 ? B : A, + Constant::getNullValue(Op1->getType())); + + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == Op0 || D == Op0) && NoOp1WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), + C == Op0 ? D : C); + + // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && + NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). + Value *Y = (A == C || A == D) ? B : A; + Value *Z = (C == A || C == B) ? D : C; + return new ICmpInst(Pred, Y, Z); + } + + // Analyze the case when either Op0 or Op1 is a sub instruction. + // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). + A = 0; B = 0; C = 0; D = 0; + if (BO0 && BO0->getOpcode() == Instruction::Sub) + A = BO0->getOperand(0), B = BO0->getOperand(1); + if (BO1 && BO1->getOpcode() == Instruction::Sub) + C = BO1->getOperand(0), D = BO1->getOperand(1); + + // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. + if (A == Op1 && NoOp0WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); + + // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. + if (C == Op0 && NoOp1WrapProblem) + return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); + + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. + if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, A, C); + + // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. + if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. 
+ BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, D, B); + + BinaryOperator *SRem = NULL; + // icmp (srem X, Y), Y + if (BO0 && BO0->getOpcode() == Instruction::SRem && + Op1 == BO0->getOperand(1)) + SRem = BO0; + // icmp Y, (srem X, Y) + else if (BO1 && BO1->getOpcode() == Instruction::SRem && + Op0 == BO1->getOperand(1)) + SRem = BO1; + if (SRem) { + // We don't check hasOneUse to avoid increasing register pressure because + // the value we use is the same value this instruction was already using. + switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { default: break; - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b - return new ICmpInst(I.getPredicate(), Op0I->getOperand(0), - Op1I->getOperand(0)); - // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b - if (ConstantInt *CI = dyn_cast(Op0I->getOperand(1))) { - if (CI->getValue().isSignBit()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - return new ICmpInst(Pred, Op0I->getOperand(0), - Op1I->getOperand(0)); - } - - if (CI->getValue().isMaxSignedValue()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - Pred = I.getSwappedPredicate(Pred); - return new ICmpInst(Pred, Op0I->getOperand(0), - Op1I->getOperand(0)); - } + case ICmpInst::ICMP_EQ: + return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + case ICmpInst::ICMP_NE: + return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), + Constant::getAllOnesValue(SRem->getType())); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), + Constant::getNullValue(SRem->getType())); + } + } + + if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && + BO0->hasOneUse() && BO1->hasOneUse() && + BO0->getOperand(1) == BO1->getOperand(1)) { + switch (BO0->getOpcode()) { + default: break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b + if (ConstantInt *CI = dyn_cast(BO0->getOperand(1))) { + if (CI->getValue().isSignBit()) { + ICmpInst::Predicate Pred = I.isSigned() + ? I.getUnsignedPredicate() + : I.getSignedPredicate(); + return new ICmpInst(Pred, BO0->getOperand(0), + BO1->getOperand(0)); + } + + if (CI->getValue().isMaxSignedValue()) { + ICmpInst::Predicate Pred = I.isSigned() + ? I.getUnsignedPredicate() + : I.getSignedPredicate(); + Pred = I.getSwappedPredicate(Pred); + return new ICmpInst(Pred, BO0->getOperand(0), + BO1->getOperand(0)); } + } + break; + case Instruction::Mul: + if (!I.isEquality()) break; - case Instruction::Mul: - if (!I.isEquality()) - break; - if (ConstantInt *CI = dyn_cast(Op0I->getOperand(1))) { - // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask - // Mask = -1 >> count-trailing-zeros(Cst). 
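A worked example of the multiply mask above (i32, equality, single-use multiplies, constants illustrative): 12 has two trailing zero bits, so Mask = -1 >>u 2 = 0x3FFFFFFF and

//   icmp eq (mul i32 %a, 12), (mul i32 %b, 12)
//     -->  icmp eq (and i32 %a, 0x3FFFFFFF), (and i32 %b, 0x3FFFFFFF)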
- if (!CI->isZero() && !CI->isOne()) { - const APInt &AP = CI->getValue(); - ConstantInt *Mask = ConstantInt::get(I.getContext(), - APInt::getLowBitsSet(AP.getBitWidth(), - AP.getBitWidth() - - AP.countTrailingZeros())); - Value *And1 = Builder->CreateAnd(Op0I->getOperand(0), Mask); - Value *And2 = Builder->CreateAnd(Op1I->getOperand(0), Mask); - return new ICmpInst(I.getPredicate(), And1, And2); - } + if (ConstantInt *CI = dyn_cast(BO0->getOperand(1))) { + // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask + // Mask = -1 >> count-trailing-zeros(Cst). + if (!CI->isZero() && !CI->isOne()) { + const APInt &AP = CI->getValue(); + ConstantInt *Mask = ConstantInt::get(I.getContext(), + APInt::getLowBitsSet(AP.getBitWidth(), + AP.getBitWidth() - + AP.countTrailingZeros())); + Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); + Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); + return new ICmpInst(I.getPredicate(), And1, And2); } - break; } + break; + case Instruction::UDiv: + case Instruction::LShr: + if (I.isSigned()) + break; + // fall-through + case Instruction::SDiv: + case Instruction::AShr: + if (!BO0->isExact() || !BO1->isExact()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + case Instruction::Shl: { + bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); + bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); + if (!NUW && !NSW) + break; + if (!NSW && I.isSigned()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + } } } } - // ~x < ~y --> y < x { Value *A, *B; - if (match(Op0, m_Not(m_Value(A))) && - match(Op1, m_Not(m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, A); + // ~x < ~y --> y < x + // ~x < cst --> ~cst < x + if (match(Op0, m_Not(m_Value(A)))) { + if (match(Op1, m_Not(m_Value(B)))) + return new ICmpInst(I.getPredicate(), B, A); + if (ConstantInt *RHSC = dyn_cast(Op1)) + return new ICmpInst(I.getPredicate(), ConstantExpr::getNot(RHSC), A); + } + + // (a+b) llvm.uadd.with.overflow. + // (a+b) llvm.uadd.with.overflow. + if (I.getPredicate() == ICmpInst::ICMP_ULT && + match(Op0, m_Add(m_Value(A), m_Value(B))) && + (Op1 == A || Op1 == B)) + if (Instruction *R = ProcessUAddIdiom(I, Op0, *this)) + return R; + + // a >u (a+b) --> llvm.uadd.with.overflow. + // b >u (a+b) --> llvm.uadd.with.overflow. + if (I.getPredicate() == ICmpInst::ICMP_UGT && + match(Op1, m_Add(m_Value(A), m_Value(B))) && + (Op0 == A || Op0 == B)) + if (Instruction *R = ProcessUAddIdiom(I, Op1, *this)) + return R; } if (I.isEquality()) { Value *A, *B, *C, *D; - - // -x == -y --> x == y - if (match(Op0, m_Neg(m_Value(A))) && - match(Op1, m_Neg(m_Value(B)))) - return new ICmpInst(I.getPredicate(), A, B); - + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 Value *OtherVal = A == Op1 ? 
B : A; @@ -2131,20 +2482,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Constant::getNullValue(A->getType())); } - // (A-B) == A -> B == 0 - if (match(Op0, m_Sub(m_Specific(Op1), m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, - Constant::getNullValue(B->getType())); - - // A == (A-B) -> B == 0 - if (match(Op1, m_Sub(m_Specific(Op0), m_Value(B)))) - return new ICmpInst(I.getPredicate(), B, - Constant::getNullValue(B->getType())); - // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 - if (Op0->hasOneUse() && Op1->hasOneUse() && - match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_And(m_Value(C), m_Value(D)))) { + if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { Value *X = 0, *Y = 0, *Z = 0; if (A == C) { @@ -2165,6 +2505,32 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return &I; } } + + // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to + // "icmp (and X, mask), cst" + uint64_t ShAmt = 0; + ConstantInt *Cst1; + if (Op0->hasOneUse() && + match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), + m_ConstantInt(ShAmt))))) && + match(Op1, m_ConstantInt(Cst1)) && + // Only do this when A has multiple uses. This is most important to do + // when it exposes other optimizations. + !A->hasOneUse()) { + unsigned ASize =cast(A->getType())->getPrimitiveSizeInBits(); + + if (ShAmt < ASize) { + APInt MaskV = + APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); + MaskV <<= ShAmt; + + APInt CmpV = Cst1->getValue().zext(ASize); + CmpV <<= ShAmt; + + Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); + return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); + } + } } { @@ -2421,12 +2787,48 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (Constant *RHSC = dyn_cast(Op1)) { if (Instruction *LHSI = dyn_cast(Op0)) switch (LHSI->getOpcode()) { + case Instruction::FPExt: { + // fcmp (fpext x), C -> fcmp x, (fptrunc C) if fptrunc is lossless + FPExtInst *LHSExt = cast(LHSI); + ConstantFP *RHSF = dyn_cast(RHSC); + if (!RHSF) + break; + + // We can't convert a PPC double double. + if (RHSF->getType()->isPPC_FP128Ty()) + break; + + const fltSemantics *Sem; + // FIXME: This shouldn't be here. + if (LHSExt->getSrcTy()->isFloatTy()) + Sem = &APFloat::IEEEsingle; + else if (LHSExt->getSrcTy()->isDoubleTy()) + Sem = &APFloat::IEEEdouble; + else if (LHSExt->getSrcTy()->isFP128Ty()) + Sem = &APFloat::IEEEquad; + else if (LHSExt->getSrcTy()->isX86_FP80Ty()) + Sem = &APFloat::x87DoubleExtended; + else + break; + + bool Lossy; + APFloat F = RHSF->getValueAPF(); + F.convert(*Sem, APFloat::rmNearestTiesToEven, &Lossy); + + // Avoid lossy conversions and denormals. + if (!Lossy && + F.compare(APFloat::getSmallestNormalized(*Sem)) != + APFloat::cmpLessThan) + return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), + ConstantFP::get(RHSC->getContext(), F)); + break; + } case Instruction::PHI: // Only fold fcmp into the PHI if the phi and fcmp are in the same // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. 
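An illustrative instance of the fpext-versus-constant fold added above: 1.5 converts to float exactly and is well above the smallest normalized value, so the compare can be narrowed.

//   fcmp olt (fpext float %x to double), 1.5  -->  fcmp olt float %x, 1.5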
if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I, true)) + if (Instruction *NV = FoldOpIntoPhi(I)) return NV; break; case Instruction::SIToFP: @@ -2459,6 +2861,14 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); break; } + case Instruction::FSub: { + // fcmp pred (fneg x), C -> fcmp swap(pred) x, -C + Value *Op; + if (match(LHSI, m_FNeg(m_Value(Op)))) + return new FCmpInst(I.getSwappedPredicate(), Op, + ConstantExpr::getFNeg(RHSC)); + break; + } case Instruction::Load: if (GetElementPtrInst *GEP = dyn_cast(LHSI->getOperand(0))) { @@ -2472,5 +2882,17 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { } } + // fcmp pred (fneg x), (fneg y) -> fcmp swap(pred) x, y + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return new FCmpInst(I.getSwappedPredicate(), X, Y); + + // fcmp (fpext x), (fpext y) -> fcmp x, y + if (FPExtInst *LHSExt = dyn_cast(Op0)) + if (FPExtInst *RHSExt = dyn_cast(Op1)) + if (LHSExt->getSrcTy() == RHSExt->getSrcTy()) + return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), + RHSExt->getOperand(0)); + return Changed ? &I : 0; }
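Illustrative instances of the two operand-pair folds added at the end (fneg is spelled as a subtraction from -0.0 at this point in LLVM IR):

//   fcmp ogt (fsub float -0.0, %x), (fsub float -0.0, %y)  -->  fcmp olt float %x, %y
//   fcmp ord (fpext float %x to double), (fpext float %y to double)
//     -->  fcmp ord float %x, %y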