X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineCompares.cpp;h=ab0d1b10c2455f5cf11dcfa30b6b4176dbb84189;hb=46216f7f993c9cd71b71d46e37054b4c27cbaf40;hp=fd2b68a9ecdb8aeb0dee5ea2edcb02ce9f594282;hpb=11c29bafd584da2e39ee5d885ca2d53035bc1372;p=oota-llvm.git diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index fd2b68a9ecd..ab0d1b10c24 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -11,32 +11,35 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/Support/ConstantRange.h" -#include "llvm/Support/GetElementPtrTypeIterator.h" -#include "llvm/Support/PatternMatch.h" -#include "llvm/Target/TargetLibraryInfo.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Analysis/TargetLibraryInfo.h" + using namespace llvm; using namespace PatternMatch; +#define DEBUG_TYPE "instcombine" + +// How many times is a select replaced by one of its operands? +STATISTIC(NumSel, "Number of select opts"); + +// Initialization Routines + static ConstantInt *getOne(Constant *C) { return ConstantInt::get(cast(C->getType()), 1); } -/// AddOne - Add one to a ConstantInt -static Constant *AddOne(Constant *C) { - return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); -} -/// SubOne - Subtract one from a ConstantInt -static Constant *SubOne(Constant *C) { - return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); -} - static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { return cast(ConstantExpr::getExtractElement(V, Idx)); } @@ -226,15 +229,12 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, Instruction *InstCombiner:: FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, ConstantInt *AndCst) { - // We need TD information to know the pointer size unless this is inbounds. - if (!GEP->isInBounds() && TD == 0) return 0; - Constant *Init = GV->getInitializer(); if (!isa(Init) && !isa(Init)) - return 0; + return nullptr; uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); - if (ArrayElementCount > 1024) return 0; // Don't blow up on huge arrays. + if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays. // There are many forms of this optimization we can handle, for now, just do // the simple index into a single-dimensional array. @@ -244,7 +244,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, !isa(GEP->getOperand(1)) || !cast(GEP->getOperand(1))->isZero() || isa(GEP->getOperand(2))) - return 0; + return nullptr; // Check that indices after the variable are constants and in-range for the // type they index. Collect the indices. This is typically for arrays of @@ -254,18 +254,18 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, Type *EltTy = Init->getType()->getArrayElementType(); for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) { ConstantInt *Idx = dyn_cast(GEP->getOperand(i)); - if (Idx == 0) return 0; // Variable index. + if (!Idx) return nullptr; // Variable index. uint64_t IdxVal = Idx->getZExtValue(); - if ((unsigned)IdxVal != IdxVal) return 0; // Too large array index. + if ((unsigned)IdxVal != IdxVal) return nullptr; // Too large array index. if (StructType *STy = dyn_cast(EltTy)) EltTy = STy->getElementType(IdxVal); else if (ArrayType *ATy = dyn_cast(EltTy)) { - if (IdxVal >= ATy->getNumElements()) return 0; + if (IdxVal >= ATy->getNumElements()) return nullptr; EltTy = ATy->getElementType(); } else { - return 0; // Unknown type. + return nullptr; // Unknown type. } LaterIndices.push_back(IdxVal); @@ -299,12 +299,11 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // the array, this will fully represent all the comparison results. uint64_t MagicBitvector = 0; - // Scan the array and see if one of our patterns matches. Constant *CompareRHS = cast(ICI.getOperand(1)); for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) { Constant *Elt = Init->getAggregateElement(i); - if (Elt == 0) return 0; + if (!Elt) return nullptr; // If this is indexing an array of structures, get the structure element. if (!LaterIndices.empty()) @@ -315,7 +314,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // Find out if the comparison would be true or false for the i'th element. Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, - CompareRHS, TD, TLI); + CompareRHS, DL, TLI); // If the result is undef for this element, ignore it. if (isa(C)) { // Extend range state machines to cover this element in case there is an @@ -329,7 +328,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // If we can't compute the result for any of the elements, we have to give // up evaluating the entire conditional. - if (!isa(C)) return 0; + if (!isa(C)) return nullptr; // Otherwise, we know if the comparison is true or false for this element, // update our state machines. @@ -383,7 +382,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, if ((i & 8) == 0 && i >= 64 && SecondTrueElement == Overdefined && SecondFalseElement == Overdefined && TrueRangeEnd == Overdefined && FalseRangeEnd == Overdefined) - return 0; + return nullptr; } // Now that we've scanned the entire array, emit our new comparison(s). We @@ -393,9 +392,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // If the index is larger than the pointer size of the target, truncate the // index down like the GEP would do implicitly. We don't have to do this for // an inbounds GEP because the index can't be out of range. - if (!GEP->isInBounds() && - Idx->getType()->getPrimitiveSizeInBits() > TD->getPointerSizeInBits()) - Idx = Builder->CreateTrunc(Idx, TD->getIntPtrType(Idx->getContext())); + if (!GEP->isInBounds()) { + Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); + unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); + if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize) + Idx = Builder->CreateTrunc(Idx, IntPtrTy); + } // If the comparison is only true for one or two elements, emit direct // comparisons. @@ -472,7 +474,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // of this load, replace it with computation that does: // ((magic_cst >> i) & 1) != 0 { - Type *Ty = 0; + Type *Ty = nullptr; // Look for an appropriate type: // - The type of Idx if the magic fits @@ -480,12 +482,10 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // - Default to i32 if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth()) Ty = Idx->getType(); - else if (TD) - Ty = TD->getSmallestLegalIntType(Init->getContext(), ArrayElementCount); - else if (ArrayElementCount <= 32) - Ty = Type::getInt32Ty(Init->getContext()); + else + Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount); - if (Ty != 0) { + if (Ty) { Value *V = Builder->CreateIntCast(Idx, Ty, false); V = Builder->CreateLShr(ConstantInt::get(Ty, MagicBitvector), V); V = Builder->CreateAnd(ConstantInt::get(Ty, 1), V); @@ -493,7 +493,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, } } - return 0; + return nullptr; } @@ -507,8 +507,8 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, /// /// If we can't emit an optimized form for this expression, this returns null. /// -static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { - DataLayout &TD = *IC.getDataLayout(); +static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, + const DataLayout &DL) { gep_type_iterator GTI = gep_type_begin(GEP); // Check to see if this gep only has a single variable index. If so, and if @@ -525,9 +525,9 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { // Handle a struct index, which adds its field offset to the pointer. if (StructType *STy = dyn_cast(*GTI)) { - Offset += TD.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); + Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { - uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType()); + uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); Offset += Size*CI->getSExtValue(); } } else { @@ -538,40 +538,42 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { // If there are no variable indices, we must have a constant offset, just // evaluate it the general way. - if (i == e) return 0; + if (i == e) return nullptr; Value *VariableIdx = GEP->getOperand(i); // Determine the scale factor of the variable element. For example, this is // 4 if the variable index is into an array of i32. - uint64_t VariableScale = TD.getTypeAllocSize(GTI.getIndexedType()); + uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType()); // Verify that there are no other variable indices. If so, emit the hard way. for (++i, ++GTI; i != e; ++i, ++GTI) { ConstantInt *CI = dyn_cast(GEP->getOperand(i)); - if (!CI) return 0; + if (!CI) return nullptr; // Compute the aggregate offset of constant indices. if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. if (StructType *STy = dyn_cast(*GTI)) { - Offset += TD.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); + Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { - uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType()); + uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); Offset += Size*CI->getSExtValue(); } } + + // Okay, we know we have a single variable index, which must be a // pointer/array/vector index. If there is no offset, life is simple, return // the index. - unsigned IntPtrWidth = TD.getPointerSizeInBits(); + Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType()); + unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth(); if (Offset == 0) { // Cast to intptrty in case a truncation occurs. If an extension is needed, // we don't need to bother extending: the extension won't affect where the // computation crosses zero. if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) { - Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext()); VariableIdx = IC.Builder->CreateTrunc(VariableIdx, IntPtrTy); } return VariableIdx; @@ -590,10 +592,9 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { // multiple of the variable scale. int64_t NewOffs = Offset / (int64_t)VariableScale; if (Offset != NewOffs*(int64_t)VariableScale) - return 0; + return nullptr; // Okay, we can do this evaluation. Start by converting the index to intptr. - Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext()); if (VariableIdx->getType() != IntPtrTy) VariableIdx = IC.Builder->CreateIntCast(VariableIdx, IntPtrTy, true /*Signed*/); @@ -612,22 +613,23 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // e.g. "&foo[0] (RHS)) - RHS = BCI->getOperand(0); + // Look through bitcasts and addrspacecasts. We do not however want to remove + // 0 GEPs. + if (!isa(RHS)) + RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - if (TD && PtrBase == RHS && GEPLHS->isInBounds()) { + if (PtrBase == RHS && GEPLHS->isInBounds()) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can // output an optimized form. - Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this); + Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this, DL); // If not, synthesize the offset the hard way. - if (Offset == 0) + if (!Offset) Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, Constant::getNullValue(Offset->getType())); @@ -652,43 +654,44 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If we're comparing GEPs with two base pointers that only differ in type // and both GEPs have only constant indices or just one use, then fold // the compare with the adjusted indices. - if (TD && GEPLHS->isInBounds() && GEPRHS->isInBounds() && + if (GEPLHS->isInBounds() && GEPRHS->isInBounds() && (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) && PtrBase->stripPointerCasts() == - GEPRHS->getOperand(0)->stripPointerCasts()) { + GEPRHS->getOperand(0)->stripPointerCasts()) { + Value *LOffset = EmitGEPOffset(GEPLHS); + Value *ROffset = EmitGEPOffset(GEPRHS); + + // If we looked through an addrspacecast between different sized address + // spaces, the LHS and RHS pointers are different sized + // integers. Truncate to the smaller one. + Type *LHSIndexTy = LOffset->getType(); + Type *RHSIndexTy = ROffset->getType(); + if (LHSIndexTy != RHSIndexTy) { + if (LHSIndexTy->getPrimitiveSizeInBits() < + RHSIndexTy->getPrimitiveSizeInBits()) { + ROffset = Builder->CreateTrunc(ROffset, LHSIndexTy); + } else + LOffset = Builder->CreateTrunc(LOffset, RHSIndexTy); + } + Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond), - EmitGEPOffset(GEPLHS), - EmitGEPOffset(GEPRHS)); + LOffset, ROffset); return ReplaceInstUsesWith(I, Cmp); } // Otherwise, the base pointers are different and the indices are // different, bail out. - return 0; + return nullptr; } // If one of the GEPs has all zero indices, recurse. - bool AllZeros = true; - for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) - if (!isa(GEPLHS->getOperand(i)) || - !cast(GEPLHS->getOperand(i))->isNullValue()) { - AllZeros = false; - break; - } - if (AllZeros) + if (GEPLHS->hasAllZeroIndices()) return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. - AllZeros = true; - for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) - if (!isa(GEPRHS->getOperand(i)) || - !cast(GEPRHS->getOperand(i))->isNullValue()) { - AllZeros = false; - break; - } - if (AllZeros) + if (GEPRHS->hasAllZeroIndices()) return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); @@ -723,9 +726,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! - if (TD && - GEPsInBounds && - (isa(GEPLHS) || GEPLHS->hasOneUse()) && + if (GEPsInBounds && (isa(GEPLHS) || GEPLHS->hasOneUse()) && (isa(GEPRHS) || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) Value *L = EmitGEPOffset(GEPLHS); @@ -733,29 +734,13 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return new ICmpInst(ICmpInst::getSignedPredicate(Cond), L, R); } } - return 0; + return nullptr; } /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::FoldICmpAddOpCst(ICmpInst &ICI, +Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, - ICmpInst::Predicate Pred, - Value *TheAdd) { - // If we have X+0, exit early (simplifying logic below) and let it get folded - // elsewhere. icmp X+0, X -> icmp X, X - if (CI->isZero()) { - bool isTrue = ICmpInst::isTrueWhenEqual(Pred); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - - // (X+4) == X -> false. - if (Pred == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - - // (X+4) != X -> true. - if (Pred == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); - + ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -817,11 +802,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // if it finds it. bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) - return 0; + return nullptr; if (DivRHS->isZero()) - return 0; // The ProdOV computation fails on divide by zero. + return nullptr; // The ProdOV computation fails on divide by zero. if (DivIsSigned && DivRHS->isAllOnesValue()) - return 0; // The overflow computation also screws up here + return nullptr; // The overflow computation also screws up here if (DivRHS->isOne()) { // This eliminates some funny cases with INT_MIN. ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. @@ -855,7 +840,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // overflow variable is set to 0 if it's corresponding bound variable is valid // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; - Constant *LoBound = 0, *HiBound = 0; + Constant *LoBound = nullptr, *HiBound = nullptr; if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) @@ -895,7 +880,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, HiBound = cast(ConstantExpr::getNeg(RangeSize)); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN + HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN } } else if (CmpRHSV.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) @@ -969,20 +954,20 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, uint32_t TypeBits = CmpRHSV.getBitWidth(); uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) - return 0; + return nullptr; if (!ICI.isEquality()) { // If we have an unsigned comparison and an ashr, we can't simplify this. // Similarly for signed comparisons with lshr. if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) - return 0; + return nullptr; // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv // by a power of 2. Since we already have logic to simplify these, // transform to div and then simplify the resultant comparison. if (Shr->getOpcode() == Instruction::AShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return 0; + return nullptr; // Revisit the shift (to delete it). Worklist.Add(Shr); @@ -999,7 +984,7 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, // If the builder folded the binop, just return it. BinaryOperator *TheDiv = dyn_cast(Tmp); - if (TheDiv == 0) + if (!TheDiv) return &ICI; // Otherwise, fold this div/compare. @@ -1042,9 +1027,114 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, Mask, Shr->getName()+".mask"); return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); } - return 0; + return nullptr; +} + +/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// (icmp eq/ne A, Log2(const2/const1)) -> +/// (icmp eq/ne A, Log2(const2) - Log2(const1)). +Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + bool IsAShr = isa(Op); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the highest bit that's set. + int Shift; + // Both the constants are negative, take their positive to calculate log. + if (IsAShr && AP1.isNegative()) + // Get the ones' complement of AP2 and AP1 when computing the distance. + Shift = (~AP2).logBase2() - (~AP1).logBase2(); + else + Shift = AP2.logBase2() - AP1.logBase2(); + + if (Shift > 0) { + if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift)) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + // Shifting const2 will never be equal to const1. + return getConstant(false); } +/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). +Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp(I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + return getConstant(false); +} /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -1061,7 +1151,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - ComputeMaskedBits(LHSI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); // If all the high bits are known, we can do this xform. if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { @@ -1074,17 +1164,17 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } break; - case Instruction::Xor: // (icmp pred (xor X, XorCST), CI) - if (ConstantInt *XorCST = dyn_cast(LHSI->getOperand(1))) { + case Instruction::Xor: // (icmp pred (xor X, XorCst), CI) + if (ConstantInt *XorCst = dyn_cast(LHSI->getOperand(1))) { // If this is a comparison that tests the signbit (X < 0) or (x > -1), // fold the xor. if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) || (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) { Value *CompareVal = LHSI->getOperand(0); - // If the sign bit of the XorCST is not set, there is no change to + // If the sign bit of the XorCst is not set, there is no change to // the operation, just stop using the Xor. - if (!XorCST->isNegative()) { + if (!XorCst->isNegative()) { ICI.setOperand(0, CompareVal); Worklist.Add(LHSI); return &ICI; @@ -1106,8 +1196,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (LHSI->hasOneUse()) { // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) - if (!ICI.isEquality() && XorCST->getValue().isSignBit()) { - const APInt &SignBit = XorCST->getValue(); + if (!ICI.isEquality() && XorCst->getValue().isSignBit()) { + const APInt &SignBit = XorCst->getValue(); ICmpInst::Predicate Pred = ICI.isSigned() ? ICI.getUnsignedPredicate() : ICI.getSignedPredicate(); @@ -1116,8 +1206,8 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) - if (!ICI.isEquality() && XorCST->isMaxValue(true)) { - const APInt &NotSignBit = XorCST->getValue(); + if (!ICI.isEquality() && XorCst->isMaxValue(true)) { + const APInt &NotSignBit = XorCst->getValue(); ICmpInst::Predicate Pred = ICI.isSigned() ? ICI.getUnsignedPredicate() : ICI.getSignedPredicate(); @@ -1126,12 +1216,24 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Builder->getInt(RHSV ^ NotSignBit)); } } + + // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) + // iff -C is a power of 2 + if (ICI.getPredicate() == ICmpInst::ICMP_UGT && + XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); + + // (icmp ult (xor X, C), -C) -> (icmp uge X, C) + // iff -C is a power of 2 + if (ICI.getPredicate() == ICmpInst::ICMP_ULT && + XorCst->getValue() == -RHSV && RHSV.isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); } break; - case Instruction::And: // (icmp pred (and X, AndCST), RHS) + case Instruction::And: // (icmp pred (and X, AndCst), RHS) if (LHSI->hasOneUse() && isa(LHSI->getOperand(1)) && LHSI->getOperand(0)->hasOneUse()) { - ConstantInt *AndCST = cast(LHSI->getOperand(1)); + ConstantInt *AndCst = cast(LHSI->getOperand(1)); // If the LHS is an AND of a truncating cast, we can widen the // and/compare to be the input width without changing the value @@ -1142,10 +1244,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // Extending a relational comparison when we're checking the sign // bit would not work. if (ICI.isEquality() || - (!AndCST->isNegative() && RHSV.isNonNegative())) { + (!AndCst->isNegative() && RHSV.isNonNegative())) { Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getZExt(AndCST, Cast->getSrcTy())); + ConstantExpr::getZExt(AndCst, Cast->getSrcTy())); NewAnd->takeName(LHSI); return new ICmpInst(ICI.getPredicate(), NewAnd, ConstantExpr::getZExt(RHS, Cast->getSrcTy())); @@ -1161,7 +1263,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) { Value *NewAnd = Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getTrunc(AndCST, Ty)); + ConstantExpr::getTrunc(AndCst, Ty)); NewAnd->takeName(LHSI); return new ICmpInst(ICI.getPredicate(), NewAnd, ConstantExpr::getTrunc(RHS, Ty)); @@ -1174,41 +1276,58 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // access. BinaryOperator *Shift = dyn_cast(LHSI->getOperand(0)); if (Shift && !Shift->isShift()) - Shift = 0; + Shift = nullptr; ConstantInt *ShAmt; - ShAmt = Shift ? dyn_cast(Shift->getOperand(1)) : 0; - Type *Ty = Shift ? Shift->getType() : 0; // Type of the shift. - Type *AndTy = AndCST->getType(); // Type of the and. + ShAmt = Shift ? dyn_cast(Shift->getOperand(1)) : nullptr; - // We can fold this as long as we can't shift unknown bits - // into the mask. This can only happen with signed shift - // rights, as they sign-extend. + // This seemingly simple opportunity to fold away a shift turns out to + // be rather complicated. See PR17827 + // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. if (ShAmt) { - bool CanFold = Shift->isLogicalShift(); - if (!CanFold) { - // To test for the bad case of the signed shr, see if any - // of the bits shifted in could be tested after the mask. - uint32_t TyBits = Ty->getPrimitiveSizeInBits(); - int ShAmtVal = TyBits - ShAmt->getLimitedValue(TyBits); - - uint32_t BitWidth = AndTy->getPrimitiveSizeInBits(); - if ((APInt::getHighBitsSet(BitWidth, BitWidth-ShAmtVal) & - AndCST->getValue()) == 0) + bool CanFold = false; + unsigned ShiftOpcode = Shift->getOpcode(); + if (ShiftOpcode == Instruction::AShr) { + // There may be some constraints that make this possible, + // but nothing simple has been discovered yet. + CanFold = false; + } else if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the mask value and + // comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT + // solver. + if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) CanFold = true; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not + // signed. We can also fold a signed comparison if the shifted mask + // value and the shifted comparison value are not negative. + // These constraints may not be obvious, but we can prove that they + // are correct using an SMT solver. + if (!ICI.isSigned()) + CanFold = true; + else { + ConstantInt *ShiftedAndCst = + cast(ConstantExpr::getShl(AndCst, ShAmt)); + ConstantInt *ShiftedRHSCst = + cast(ConstantExpr::getShl(RHS, ShAmt)); + + if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) + CanFold = true; + } } if (CanFold) { Constant *NewCst; - if (Shift->getOpcode() == Instruction::Shl) + if (ShiftOpcode == Instruction::Shl) NewCst = ConstantExpr::getLShr(RHS, ShAmt); else NewCst = ConstantExpr::getShl(RHS, ShAmt); // Check to see if we are shifting out any of the bits being // compared. - if (ConstantExpr::get(Shift->getOpcode(), - NewCst, ShAmt) != RHS) { + if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { // If we shifted bits out, the fold is not going to work out. // As a special case, check to see if this means that the // result is always true or false now. @@ -1218,12 +1337,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return ReplaceInstUsesWith(ICI, Builder->getTrue()); } else { ICI.setOperand(1, NewCst); - Constant *NewAndCST; - if (Shift->getOpcode() == Instruction::Shl) - NewAndCST = ConstantExpr::getLShr(AndCST, ShAmt); + Constant *NewAndCst; + if (ShiftOpcode == Instruction::Shl) + NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); else - NewAndCST = ConstantExpr::getShl(AndCST, ShAmt); - LHSI->setOperand(1, NewAndCST); + NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); + LHSI->setOperand(1, NewAndCst); LHSI->setOperand(0, Shift->getOperand(0)); Worklist.Add(Shift); // Shift is dead. return &ICI; @@ -1240,10 +1359,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // Compute C << Y. Value *NS; if (Shift->getOpcode() == Instruction::LShr) { - NS = Builder->CreateShl(AndCST, Shift->getOperand(1)); + NS = Builder->CreateShl(AndCst, Shift->getOperand(1)); } else { // Insert a logical shift. - NS = Builder->CreateLShr(AndCST, Shift->getOperand(1)); + NS = Builder->CreateLShr(AndCst, Shift->getOperand(1)); } // Compute X & (C << Y). @@ -1254,12 +1373,54 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return &ICI; } - // Replace ((X & AndCST) > RHSV) with ((X & AndCST) != 0), if any - // bit set in (X & AndCST) will produce a result greater than RHSV. + // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> + // (icmp pred (and X, (or (shl 1, Y), 1), 0)) + // + // iff pred isn't signed + { + Value *X, *Y, *LShr; + if (!ICI.isSigned() && RHSV == 0) { + if (match(LHSI->getOperand(1), m_One())) { + Constant *One = cast(LHSI->getOperand(1)); + Value *Or = LHSI->getOperand(0); + if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && + match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { + unsigned UsesRemoved = 0; + if (LHSI->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + Value *NewOr = nullptr; + // Compute X & ((1 << Y) | 1) + if (auto *C = dyn_cast(Y)) { + if (UsesRemoved >= 1) + NewOr = + ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, + LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); + ICI.setOperand(0, NewAnd); + return &ICI; + } + } + } + } + } + + // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any + // bit set in (X & AndCst) will produce a result greater than RHSV. if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - unsigned NTZ = AndCST->getValue().countTrailingZeros(); - if ((NTZ < AndCST->getBitWidth()) && - APInt::getOneBitSet(AndCST->getBitWidth(), NTZ).ugt(RHSV)) + unsigned NTZ = AndCst->getValue().countTrailingZeros(); + if ((NTZ < AndCst->getBitWidth()) && + APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV)) return new ICmpInst(ICmpInst::ICMP_NE, LHSI, Constant::getNullValue(RHS->getType())); } @@ -1277,6 +1438,15 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return Res; } } + + // X & -C == -C -> X > u ~C + // X & -C != -C -> X <= u ~C + // iff C is a power of 2 + if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2()) + return new ICmpInst( + ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT + : ICmpInst::ICMP_ULE, + LHSI->getOperand(0), SubOne(RHS)); break; case Instruction::Or: { @@ -1340,16 +1510,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned RHSLog2 = RHSV.logBase2(); // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) > 2147483648 -> X > 31 -> false - // (1 << X) <= 2147483648 -> X <= 31 -> true // (1 << X) < 2147483648 -> X < 31 -> X != 31 if (RHSLog2 == TypeBits-1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_UGT) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - else if (Pred == ICmpInst::ICMP_ULE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); else if (Pred == ICmpInst::ICMP_ULT) Pred = ICmpInst::ICMP_NE; } @@ -1384,10 +1548,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (RHSVIsPowerOf2) return new ICmpInst( Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - - return ReplaceInstUsesWith( - ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse() - : Builder->getTrue()); } } break; @@ -1526,7 +1686,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Builder->CreateOr(LHSI->getOperand(1), RHSV - 1), LHSC); - // C1-X >u C2 -> (X|C2) == C1 + // C1-X >u C2 -> (X|C2) != C1 // iff C1 & C2 == C2 // C2+1 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && @@ -1573,7 +1733,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Builder->CreateAnd(LHSI->getOperand(0), -RHSV), ConstantExpr::getNeg(LHSC)); - // X-C1 >u C2 -> (X & ~C2) == C1 + // X-C1 >u C2 -> (X & ~C2) != C1 // iff C1 & C2 == 0 // C2+1 is a power of 2 if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && @@ -1744,7 +1904,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, } } } - return 0; + return nullptr; } /// visitICmpInstWithCastAndCast - Handle icmp (cast x to y), (cast/cst). @@ -1759,18 +1919,20 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the // integer type is the same size as the pointer type. - if (TD && LHSCI->getOpcode() == Instruction::PtrToInt && - TD->getPointerSizeInBits() == - cast(DestTy)->getBitWidth()) { - Value *RHSOp = 0; - if (Constant *RHSC = dyn_cast(ICI.getOperand(1))) { + if (LHSCI->getOpcode() == Instruction::PtrToInt && + DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { + Value *RHSOp = nullptr; + if (PtrToIntOperator *RHSC = dyn_cast(ICI.getOperand(1))) { + Value *RHSCIOp = RHSC->getOperand(0); + if (RHSCIOp->getType()->getPointerAddressSpace() == + LHSCIOp->getType()->getPointerAddressSpace()) { + RHSOp = RHSC->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (LHSCIOp->getType() != RHSOp->getType()) + RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); + } + } else if (Constant *RHSC = dyn_cast(ICI.getOperand(1))) RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); - } else if (PtrToIntInst *RHSC = dyn_cast(ICI.getOperand(1))) { - RHSOp = RHSC->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (LHSCIOp->getType() != RHSOp->getType()) - RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); - } if (RHSOp) return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp); @@ -1780,7 +1942,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // Enforce this. if (LHSCI->getOpcode() != Instruction::ZExt && LHSCI->getOpcode() != Instruction::SExt) - return 0; + return nullptr; bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; bool isSignedCmp = ICI.isSigned(); @@ -1789,12 +1951,12 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // Not an extension from the same type? RHSCIOp = CI->getOperand(0); if (RHSCIOp->getType() != LHSCIOp->getType()) - return 0; + return nullptr; // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. if (CI->getOpcode() != LHSCI->getOpcode()) - return 0; + return nullptr; // Deal with equality cases early. if (ICI.isEquality()) @@ -1812,7 +1974,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // If we aren't dealing with a constant on the RHS, exit early ConstantInt *CI = dyn_cast(ICI.getOperand(1)); if (!CI) - return 0; + return nullptr; // Compute the constant that would happen if we truncated to SrcTy then // reextended to DestTy. @@ -1841,7 +2003,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // by SimplifyICmpInst, so only deal with the tricky case. if (isSignedCmp || !isSignedExt) - return 0; + return nullptr; // Evaluate the comparison for LT (we invert for GT below). LE and GE cases // should have been folded away previously and not enter in here. @@ -1877,12 +2039,12 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // In order to eliminate the add-with-constant, the compare can be its only // use. Instruction *AddWithCst = cast(I.getOperand(0)); - if (!AddWithCst->hasOneUse()) return 0; + if (!AddWithCst->hasOneUse()) return nullptr; // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. - if (!CI2->getValue().isPowerOf2()) return 0; + if (!CI2->getValue().isPowerOf2()) return nullptr; unsigned NewWidth = CI2->getValue().countTrailingZeros(); - if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return 0; + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; // The width of the new add formed is 1 more than the bias. ++NewWidth; @@ -1890,33 +2052,32 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // Check to see that CI1 is an all-ones value with NewWidth bits. if (CI1->getBitWidth() == NewWidth || CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) - return 0; + return nullptr; // This is only really a signed overflow check if the inputs have been // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A) < NeededSignBits || - IC.ComputeNumSignBits(B) < NeededSignBits) - return 0; + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) + return nullptr; // In order to replace the original add with a narrower // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant // and truncates that discard the high bits of the add. Verify that this is // the case. Instruction *OrigAdd = cast(AddWithCst->getOperand(0)); - for (Value::use_iterator UI = OrigAdd->use_begin(), E = OrigAdd->use_end(); - UI != E; ++UI) { - if (*UI == AddWithCst) continue; + for (User *U : OrigAdd->users()) { + if (U == AddWithCst) continue; // Only accept truncates for now. We would really like a nice recursive // predicate like SimplifyDemandedBits, but which goes downwards the use-def // chain to see which bits of a value are actually demanded. If the // original add had another add which was then immediately truncated, we // could still do the transformation. - TruncInst *TI = dyn_cast(*UI); - if (TI == 0 || - TI->getType()->getPrimitiveSizeInBits() > NewWidth) return 0; + TruncInst *TI = dyn_cast(U); + if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) + return nullptr; } // If the pattern matches, truncate the inputs to the narrower type and @@ -1936,7 +2097,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); - CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB, "sadd"); + CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); @@ -1948,33 +2109,329 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } -static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, - InstCombiner &IC) { +bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, + Value *RHS, Instruction &OrigI, + Value *&Result, Constant *&Overflow) { + assert((!OrigI.isCommutative() || + !(isa(LHS) && !isa(RHS))) && + "call with a constant RHS if possible!"); + + auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) { + Result = OpResult; + Overflow = OverflowVal; + if (ReuseName) + Result->takeName(&OrigI); + return true; + }; + + switch (OCF) { + case OCF_INVALID: + llvm_unreachable("bad overflow check kind!"); + + case OCF_UNSIGNED_ADD: { + OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI); + if (OR == OverflowResult::NeverOverflows) + return SetResult(Builder->CreateNUWAdd(LHS, RHS), Builder->getFalse(), + true); + + if (OR == OverflowResult::AlwaysOverflows) + return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true); + } + // FALL THROUGH uadd into sadd + case OCF_SIGNED_ADD: { + // X + 0 -> {X, false} + if (match(RHS, m_Zero())) + return SetResult(LHS, Builder->getFalse(), false); + + // We can strength reduce this signed add into a regular add if we can prove + // that it will never overflow. + if (OCF == OCF_SIGNED_ADD) + if (WillNotOverflowSignedAdd(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNSWAdd(LHS, RHS), Builder->getFalse(), + true); + break; + } + + case OCF_UNSIGNED_SUB: + case OCF_SIGNED_SUB: { + // X - 0 -> {X, false} + if (match(RHS, m_Zero())) + return SetResult(LHS, Builder->getFalse(), false); + + if (OCF == OCF_SIGNED_SUB) { + if (WillNotOverflowSignedSub(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNSWSub(LHS, RHS), Builder->getFalse(), + true); + } else { + if (WillNotOverflowUnsignedSub(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNUWSub(LHS, RHS), Builder->getFalse(), + true); + } + break; + } + + case OCF_UNSIGNED_MUL: { + OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI); + if (OR == OverflowResult::NeverOverflows) + return SetResult(Builder->CreateNUWMul(LHS, RHS), Builder->getFalse(), + true); + if (OR == OverflowResult::AlwaysOverflows) + return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true); + } // FALL THROUGH + case OCF_SIGNED_MUL: + // X * undef -> undef + if (isa(RHS)) + return SetResult(RHS, UndefValue::get(Builder->getInt1Ty()), false); + + // X * 0 -> {0, false} + if (match(RHS, m_Zero())) + return SetResult(RHS, Builder->getFalse(), false); + + // X * 1 -> {X, false} + if (match(RHS, m_One())) + return SetResult(LHS, Builder->getFalse(), false); + + if (OCF == OCF_SIGNED_MUL) + if (WillNotOverflowSignedMul(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNSWMul(LHS, RHS), Builder->getFalse(), + true); + } + + return false; +} + +/// \brief Recognize and process idiom involving test for multiplication +/// overflow. +/// +/// The caller has matched a pattern of the form: +/// I = cmp u (mul(zext A, zext B), V +/// The function checks if this is a test for overflow and if so replaces +/// multiplication with call to 'mul.with.overflow' intrinsic. +/// +/// \param I Compare instruction. +/// \param MulVal Result of 'mult' instruction. It is one of the arguments of +/// the compare instruction. Must be of integer type. +/// \param OtherVal The other argument of compare instruction. +/// \returns Instruction which must replace the compare instruction, NULL if no +/// replacement required. +static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, + Value *OtherVal, InstCombiner &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. - if (!isa(OrigAddV->getType())) return 0; + if (!isa(MulVal->getType())) + return nullptr; + + assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal); + assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal); + Instruction *MulInstr = cast(MulVal); + assert(MulInstr->getOpcode() == Instruction::Mul); + + auto *LHS = cast(MulInstr->getOperand(0)), + *RHS = cast(MulInstr->getOperand(1)); + assert(LHS->getOpcode() == Instruction::ZExt); + assert(RHS->getOpcode() == Instruction::ZExt); + Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); + + // Calculate type and width of the result produced by mul.with.overflow. + Type *TyA = A->getType(), *TyB = B->getType(); + unsigned WidthA = TyA->getPrimitiveSizeInBits(), + WidthB = TyB->getPrimitiveSizeInBits(); + unsigned MulWidth; + Type *MulType; + if (WidthB > WidthA) { + MulWidth = WidthB; + MulType = TyB; + } else { + MulWidth = WidthA; + MulType = TyA; + } - // If the add is a constant expr, then we don't bother transforming it. - Instruction *OrigAdd = dyn_cast(OrigAddV); - if (OrigAdd == 0) return 0; + // In order to replace the original mul with a narrower mul.with.overflow, + // all uses must ignore upper bits of the product. The number of used low + // bits must be not greater than the width of mul.with.overflow. + if (MulVal->hasNUsesOrMore(2)) + for (User *U : MulVal->users()) { + if (U == &I) + continue; + if (TruncInst *TI = dyn_cast(U)) { + // Check if truncation ignores bits above MulWidth. + unsigned TruncWidth = TI->getType()->getPrimitiveSizeInBits(); + if (TruncWidth > MulWidth) + return nullptr; + } else if (BinaryOperator *BO = dyn_cast(U)) { + // Check if AND ignores bits above MulWidth. + if (BO->getOpcode() != Instruction::And) + return nullptr; + if (ConstantInt *CI = dyn_cast(BO->getOperand(1))) { + const APInt &CVal = CI->getValue(); + if (CVal.getBitWidth() - CVal.countLeadingZeros() > MulWidth) + return nullptr; + } + } else { + // Other uses prohibit this transformation. + return nullptr; + } + } - Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1); + // Recognize patterns + switch (I.getPredicate()) { + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp eq/neq mulval, zext trunc mulval + if (ZExtInst *Zext = dyn_cast(OtherVal)) + if (Zext->hasOneUse()) { + Value *ZextArg = Zext->getOperand(0); + if (TruncInst *Trunc = dyn_cast(ZextArg)) + if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth) + break; //Recognized + } - // Put the new code above the original add, in case there are any uses of the - // add between the add and the compare. - InstCombiner::BuilderTy *Builder = IC.Builder; - Builder->SetInsertPoint(OrigAdd); + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits. + ConstantInt *CI; + Value *ValToMask; + if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) { + if (ValToMask != MulVal) + return nullptr; + const APInt &CVal = CI->getValue() + 1; + if (CVal.isPowerOf2()) { + unsigned MaskWidth = CVal.logBase2(); + if (MaskWidth == MulWidth) + break; // Recognized + } + } + return nullptr; + + case ICmpInst::ICMP_UGT: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp ugt mulval, max + if (ConstantInt *CI = dyn_cast(OtherVal)) { + APInt MaxVal = APInt::getMaxValue(MulWidth); + MaxVal = MaxVal.zext(CI->getBitWidth()); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + case ICmpInst::ICMP_UGE: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp uge mulval, max+1 + if (ConstantInt *CI = dyn_cast(OtherVal)) { + APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + case ICmpInst::ICMP_ULE: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp ule mulval, max + if (ConstantInt *CI = dyn_cast(OtherVal)) { + APInt MaxVal = APInt::getMaxValue(MulWidth); + MaxVal = MaxVal.zext(CI->getBitWidth()); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + case ICmpInst::ICMP_ULT: + // Recognize pattern: + // mulval = mul(zext A, zext B) + // cmp ule mulval, max + 1 + if (ConstantInt *CI = dyn_cast(OtherVal)) { + APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth); + if (MaxVal.eq(CI->getValue())) + break; // Recognized + } + return nullptr; + + default: + return nullptr; + } + + InstCombiner::BuilderTy *Builder = IC.Builder; + Builder->SetInsertPoint(MulInstr); Module *M = I.getParent()->getParent()->getParent(); - Type *Ty = LHS->getType(); - Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, Ty); - CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd"); - Value *Add = Builder->CreateExtractValue(Call, 0); - IC.ReplaceInstUsesWith(*OrigAdd, Add); + // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B) + Value *MulA = A, *MulB = B; + if (WidthA < MulWidth) + MulA = Builder->CreateZExt(A, MulType); + if (WidthB < MulWidth) + MulB = Builder->CreateZExt(B, MulType); + Value *F = + Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType); + CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul"); + IC.Worklist.Add(MulInstr); + + // If there are uses of mul result other than the comparison, we know that + // they are truncation or binary AND. Change them to use result of + // mul.with.overflow and adjust properly mask/size. + if (MulVal->hasNUsesOrMore(2)) { + Value *Mul = Builder->CreateExtractValue(Call, 0, "umul.value"); + for (User *U : MulVal->users()) { + if (U == &I || U == OtherVal) + continue; + if (TruncInst *TI = dyn_cast(U)) { + if (TI->getType()->getPrimitiveSizeInBits() == MulWidth) + IC.ReplaceInstUsesWith(*TI, Mul); + else + TI->setOperand(0, Mul); + } else if (BinaryOperator *BO = dyn_cast(U)) { + assert(BO->getOpcode() == Instruction::And); + // Replace (mul & mask) --> zext (mul.with.overflow & short_mask) + ConstantInt *CI = cast(BO->getOperand(1)); + APInt ShortMask = CI->getValue().trunc(MulWidth); + Value *ShortAnd = Builder->CreateAnd(Mul, ShortMask); + Instruction *Zext = + cast(Builder->CreateZExt(ShortAnd, BO->getType())); + IC.Worklist.Add(Zext); + IC.ReplaceInstUsesWith(*BO, Zext); + } else { + llvm_unreachable("Unexpected Binary operation"); + } + IC.Worklist.Add(cast(U)); + } + } + if (isa(OtherVal)) + IC.Worklist.Add(cast(OtherVal)); - // The original icmp gets replaced with the overflow value. - return ExtractValueInst::Create(Call, 1, "uadd.overflow"); + // The original icmp gets replaced with the overflow value, maybe inverted + // depending on predicate. + bool Inverse = false; + switch (I.getPredicate()) { + case ICmpInst::ICMP_NE: + break; + case ICmpInst::ICMP_EQ: + Inverse = true; + break; + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (I.getOperand(0) == MulVal) + break; + Inverse = true; + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + if (I.getOperand(1) == MulVal) + break; + Inverse = true; + break; + default: + llvm_unreachable("Unexpected predicate"); + } + if (Inverse) { + Value *Res = Builder->CreateExtractValue(Call, 1); + return BinaryOperator::CreateNot(Res); + } + + return ExtractValueInst::Create(Call, 1); } // DemandedBitsLHSMask - When performing a comparison against a constant, @@ -2014,20 +2471,181 @@ static APInt DemandedBitsLHSMask(ICmpInst &I, } +/// \brief Check if the order of \p Op0 and \p Op1 as operand in an ICmpInst +/// should be swapped. +/// The decision is based on how many times these two operands are reused +/// as subtract operands and their positions in those instructions. +/// The rational is that several architectures use the same instruction for +/// both subtract and cmp, thus it is better if the order of those operands +/// match. +/// \return true if Op0 and Op1 should be swapped. +static bool swapMayExposeCSEOpportunities(const Value * Op0, + const Value * Op1) { + // Filter out pointer value as those cannot appears directly in subtract. + // FIXME: we may want to go through inttoptrs or bitcasts. + if (Op0->getType()->isPointerTy()) + return false; + // Count every uses of both Op0 and Op1 in a subtract. + // Each time Op0 is the first operand, count -1: swapping is bad, the + // subtract has already the same layout as the compare. + // Each time Op0 is the second operand, count +1: swapping is good, the + // subtract has a different layout as the compare. + // At the end, if the benefit is greater than 0, Op0 should come second to + // expose more CSE opportunities. + int GlobalSwapBenefits = 0; + for (const User *U : Op0->users()) { + const BinaryOperator *BinOp = dyn_cast(U); + if (!BinOp || BinOp->getOpcode() != Instruction::Sub) + continue; + // If Op0 is the first argument, this is not beneficial to swap the + // arguments. + int LocalSwapBenefits = -1; + unsigned Op1Idx = 1; + if (BinOp->getOperand(Op1Idx) == Op0) { + Op1Idx = 0; + LocalSwapBenefits = 1; + } + if (BinOp->getOperand(Op1Idx) != Op1) + continue; + GlobalSwapBenefits += LocalSwapBenefits; + } + return GlobalSwapBenefits > 0; +} + +/// \brief Check that one use is in the same block as the definition and all +/// other uses are in blocks dominated by a given block +/// +/// \param DI Definition +/// \param UI Use +/// \param DB Block that must dominate all uses of \p DI outside +/// the parent block +/// \return true when \p UI is the only use of \p DI in the parent block +/// and all other uses of \p DI are in blocks dominated by \p DB. +/// +bool InstCombiner::dominatesAllUses(const Instruction *DI, + const Instruction *UI, + const BasicBlock *DB) const { + assert(DI && UI && "Instruction not defined\n"); + // ignore incomplete definitions + if (!DI->getParent()) + return false; + // DI and UI must be in the same block + if (DI->getParent() != UI->getParent()) + return false; + // Protect from self-referencing blocks + if (DI->getParent() == DB) + return false; + // DominatorTree available? + if (!DT) + return false; + for (const User *U : DI->users()) { + auto *Usr = cast(U); + if (Usr != UI && !DT->dominates(DB, Usr->getParent())) + return false; + } + return true; +} + +/// +/// true when the instruction sequence within a block is select-cmp-br. +/// +static bool isChainSelectCmpBranch(const SelectInst *SI) { + const BasicBlock *BB = SI->getParent(); + if (!BB) + return false; + auto *BI = dyn_cast_or_null(BB->getTerminator()); + if (!BI || BI->getNumSuccessors() != 2) + return false; + auto *IC = dyn_cast(BI->getCondition()); + if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI)) + return false; + return true; +} + +/// +/// \brief True when a select result is replaced by one of its operands +/// in select-icmp sequence. This will eventually result in the elimination +/// of the select. +/// +/// \param SI Select instruction +/// \param Icmp Compare instruction +/// \param SIOpd Operand that replaces the select +/// +/// Notes: +/// - The replacement is global and requires dominator information +/// - The caller is responsible for the actual replacement +/// +/// Example: +/// +/// entry: +/// %4 = select i1 %3, %C* %0, %C* null +/// %5 = icmp eq %C* %4, null +/// br i1 %5, label %9, label %7 +/// ... +/// ;