X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineCompares.cpp;h=9ce1c3fc41b302126f33943cd0d1610a0c3b8749;hp=d803904ed368275e2a3993a230ce9335db97c985;hb=92c2acd055885e7e329fca25b9233fadd0da0dd0;hpb=8d7221ccf5012e7ece93aa976bf2603789b31441 diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index d803904ed36..9ce1c3fc41b 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -11,7 +11,9 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" @@ -20,12 +22,20 @@ #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Target/TargetLibraryInfo.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Analysis/TargetLibraryInfo.h" + using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "instcombine" +// How many times is a select replaced by one of its operands? +STATISTIC(NumSel, "Number of select opts"); + +// Initialization Routines + static ConstantInt *getOne(Constant *C) { return ConstantInt::get(cast(C->getType()), 1); } @@ -219,10 +229,6 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, Instruction *InstCombiner:: FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, CmpInst &ICI, ConstantInt *AndCst) { - // We need TD information to know the pointer size unless this is inbounds. - if (!GEP->isInBounds() && !DL) - return nullptr; - Constant *Init = GV->getInitializer(); if (!isa(Init) && !isa(Init)) return nullptr; @@ -293,7 +299,6 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // the array, this will fully represent all the comparison results. uint64_t MagicBitvector = 0; - // Scan the array and see if one of our patterns matches. Constant *CompareRHS = cast(ICI.getOperand(1)); for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) { @@ -388,7 +393,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // index down like the GEP would do implicitly. We don't have to do this for // an inbounds GEP because the index can't be out of range. if (!GEP->isInBounds()) { - Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); + Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize) Idx = Builder->CreateTrunc(Idx, IntPtrTy); @@ -477,10 +482,8 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // - Default to i32 if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth()) Ty = Idx->getType(); - else if (DL) - Ty = DL->getSmallestLegalIntType(Init->getContext(), ArrayElementCount); - else if (ArrayElementCount <= 32) - Ty = Type::getInt32Ty(Init->getContext()); + else + Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount); if (Ty) { Value *V = Builder->CreateIntCast(Idx, Ty, false); @@ -504,8 +507,8 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, /// /// If we can't emit an optimized form for this expression, this returns null. /// -static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) { - const DataLayout &DL = *IC.getDataLayout(); +static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, + const DataLayout &DL) { gep_type_iterator GTI = gep_type_begin(GEP); // Check to see if this gep only has a single variable index. If so, and if @@ -612,17 +615,18 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, if (ICmpInst::isSigned(Cond)) return nullptr; - // Look through bitcasts. - if (BitCastInst *BCI = dyn_cast(RHS)) - RHS = BCI->getOperand(0); + // Look through bitcasts and addrspacecasts. We do not however want to remove + // 0 GEPs. + if (!isa(RHS)) + RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - if (DL && PtrBase == RHS && GEPLHS->isInBounds()) { + if (PtrBase == RHS && GEPLHS->isInBounds()) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can // output an optimized form. - Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this); + Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this, DL); // If not, synthesize the offset the hard way. if (!Offset) @@ -650,14 +654,29 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If we're comparing GEPs with two base pointers that only differ in type // and both GEPs have only constant indices or just one use, then fold // the compare with the adjusted indices. - if (DL && GEPLHS->isInBounds() && GEPRHS->isInBounds() && + if (GEPLHS->isInBounds() && GEPRHS->isInBounds() && (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) && PtrBase->stripPointerCasts() == - GEPRHS->getOperand(0)->stripPointerCasts()) { + GEPRHS->getOperand(0)->stripPointerCasts()) { + Value *LOffset = EmitGEPOffset(GEPLHS); + Value *ROffset = EmitGEPOffset(GEPRHS); + + // If we looked through an addrspacecast between different sized address + // spaces, the LHS and RHS pointers are different sized + // integers. Truncate to the smaller one. + Type *LHSIndexTy = LOffset->getType(); + Type *RHSIndexTy = ROffset->getType(); + if (LHSIndexTy != RHSIndexTy) { + if (LHSIndexTy->getPrimitiveSizeInBits() < + RHSIndexTy->getPrimitiveSizeInBits()) { + ROffset = Builder->CreateTrunc(ROffset, LHSIndexTy); + } else + LOffset = Builder->CreateTrunc(LOffset, RHSIndexTy); + } + Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond), - EmitGEPOffset(GEPLHS), - EmitGEPOffset(GEPRHS)); + LOffset, ROffset); return ReplaceInstUsesWith(I, Cmp); } @@ -667,26 +686,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } // If one of the GEPs has all zero indices, recurse. - bool AllZeros = true; - for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i) - if (!isa(GEPLHS->getOperand(i)) || - !cast(GEPLHS->getOperand(i))->isNullValue()) { - AllZeros = false; - break; - } - if (AllZeros) + if (GEPLHS->hasAllZeroIndices()) return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. - AllZeros = true; - for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) - if (!isa(GEPRHS->getOperand(i)) || - !cast(GEPRHS->getOperand(i))->isNullValue()) { - AllZeros = false; - break; - } - if (AllZeros) + if (GEPRHS->hasAllZeroIndices()) return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); @@ -721,9 +726,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! - if (DL && - GEPsInBounds && - (isa(GEPLHS) || GEPLHS->hasOneUse()) && + if (GEPsInBounds && (isa(GEPLHS) || GEPLHS->hasOneUse()) && (isa(GEPRHS) || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) Value *L = EmitGEPOffset(GEPLHS); @@ -738,21 +741,6 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { - // If we have X+0, exit early (simplifying logic below) and let it get folded - // elsewhere. icmp X+0, X -> icmp X, X - if (CI->isZero()) { - bool isTrue = ICmpInst::isTrueWhenEqual(Pred); - return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue)); - } - - // (X+4) == X -> false. - if (Pred == ICmpInst::ICMP_EQ) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - - // (X+4) != X -> true. - if (Pred == ICmpInst::ICMP_NE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); - // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1042,6 +1030,111 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, return nullptr; } +/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> +/// (icmp eq/ne A, Log2(const2/const1)) -> +/// (icmp eq/ne A, Log2(const2) - Log2(const1)). +Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + bool IsAShr = isa(Op); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the highest bit that's set. + int Shift; + // Both the constants are negative, take their positive to calculate log. + if (IsAShr && AP1.isNegative()) + // Get the ones' complement of AP2 and AP1 when computing the distance. + Shift = (~AP2).logBase2() - (~AP1).logBase2(); + else + Shift = AP2.logBase2() - AP1.logBase2(); + + if (Shift > 0) { + if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift)) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + // Shifting const2 will never be equal to const1. + return getConstant(false); +} + +/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" -> +/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). +Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, + ConstantInt *CI1, + ConstantInt *CI2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getConstant = [&I, this](bool IsTrue) { + if (I.getPredicate() == I.ICMP_NE) + IsTrue = !IsTrue; + return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); + }; + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + APInt AP1 = CI1->getValue(); + APInt AP2 = CI2->getValue(); + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp(I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + return getConstant(false); +} /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)". /// @@ -1058,7 +1151,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - ComputeMaskedBits(LHSI->getOperand(0), KnownZero, KnownOne); + computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); // If all the high bits are known, we can do this xform. if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { @@ -1280,6 +1373,48 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return &ICI; } + // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> + // (icmp pred (and X, (or (shl 1, Y), 1), 0)) + // + // iff pred isn't signed + { + Value *X, *Y, *LShr; + if (!ICI.isSigned() && RHSV == 0) { + if (match(LHSI->getOperand(1), m_One())) { + Constant *One = cast(LHSI->getOperand(1)); + Value *Or = LHSI->getOperand(0); + if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && + match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { + unsigned UsesRemoved = 0; + if (LHSI->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + Value *NewOr = nullptr; + // Compute X & ((1 << Y) | 1) + if (auto *C = dyn_cast(Y)) { + if (UsesRemoved >= 1) + NewOr = + ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, + LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); + ICI.setOperand(0, NewAnd); + return &ICI; + } + } + } + } + } + // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any // bit set in (X & AndCst) will produce a result greater than RHSV. if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { @@ -1312,6 +1447,23 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE, LHSI->getOperand(0), SubOne(RHS)); + + // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1) + // iff C is a power of 2 + if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) { + if (auto *CI = dyn_cast(LHSI->getOperand(1))) { + const APInt &AI = CI->getValue(); + int32_t ExactLogBase2 = AI.exactLogBase2(); + if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { + Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1); + Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy); + return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_SGE + : ICmpInst::ICMP_SLT, + Trunc, Constant::getNullValue(NTy)); + } + } + } break; case Instruction::Or: { @@ -1375,16 +1527,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, unsigned RHSLog2 = RHSV.logBase2(); // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) > 2147483648 -> X > 31 -> false - // (1 << X) <= 2147483648 -> X <= 31 -> true // (1 << X) < 2147483648 -> X < 31 -> X != 31 if (RHSLog2 == TypeBits-1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_UGT) - return ReplaceInstUsesWith(ICI, Builder->getFalse()); - else if (Pred == ICmpInst::ICMP_ULE) - return ReplaceInstUsesWith(ICI, Builder->getTrue()); else if (Pred == ICmpInst::ICMP_ULT) Pred = ICmpInst::ICMP_NE; } @@ -1419,10 +1565,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, if (RHSVIsPowerOf2) return new ICmpInst( Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - - return ReplaceInstUsesWith( - ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse() - : Builder->getTrue()); } } break; @@ -1794,17 +1936,20 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) { // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the // integer type is the same size as the pointer type. - if (DL && LHSCI->getOpcode() == Instruction::PtrToInt && - DL->getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { + if (LHSCI->getOpcode() == Instruction::PtrToInt && + DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) { Value *RHSOp = nullptr; - if (Constant *RHSC = dyn_cast(ICI.getOperand(1))) { + if (PtrToIntOperator *RHSC = dyn_cast(ICI.getOperand(1))) { + Value *RHSCIOp = RHSC->getOperand(0); + if (RHSCIOp->getType()->getPointerAddressSpace() == + LHSCIOp->getType()->getPointerAddressSpace()) { + RHSOp = RHSC->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (LHSCIOp->getType() != RHSOp->getType()) + RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); + } + } else if (Constant *RHSC = dyn_cast(ICI.getOperand(1))) RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); - } else if (PtrToIntInst *RHSC = dyn_cast(ICI.getOperand(1))) { - RHSOp = RHSC->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (LHSCIOp->getType() != RHSOp->getType()) - RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType()); - } if (RHSOp) return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp); @@ -1930,8 +2075,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A) < NeededSignBits || - IC.ComputeNumSignBits(B) < NeededSignBits) + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) return nullptr; // In order to replace the original add with a narrower @@ -1969,7 +2114,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); - CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB, "sadd"); + CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); @@ -1981,33 +2126,95 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } -static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, - InstCombiner &IC) { - // Don't bother doing this transformation for pointers, don't do it for - // vectors. - if (!isa(OrigAddV->getType())) return nullptr; - - // If the add is a constant expr, then we don't bother transforming it. - Instruction *OrigAdd = dyn_cast(OrigAddV); - if (!OrigAdd) return nullptr; +bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, + Value *RHS, Instruction &OrigI, + Value *&Result, Constant *&Overflow) { + if (OrigI.isCommutative() && isa(LHS) && !isa(RHS)) + std::swap(LHS, RHS); + + auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) { + Result = OpResult; + Overflow = OverflowVal; + if (ReuseName) + Result->takeName(&OrigI); + return true; + }; + + switch (OCF) { + case OCF_INVALID: + llvm_unreachable("bad overflow check kind!"); + + case OCF_UNSIGNED_ADD: { + OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI); + if (OR == OverflowResult::NeverOverflows) + return SetResult(Builder->CreateNUWAdd(LHS, RHS), Builder->getFalse(), + true); + + if (OR == OverflowResult::AlwaysOverflows) + return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true); + } + // FALL THROUGH uadd into sadd + case OCF_SIGNED_ADD: { + // X + 0 -> {X, false} + if (match(RHS, m_Zero())) + return SetResult(LHS, Builder->getFalse(), false); + + // We can strength reduce this signed add into a regular add if we can prove + // that it will never overflow. + if (OCF == OCF_SIGNED_ADD) + if (WillNotOverflowSignedAdd(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNSWAdd(LHS, RHS), Builder->getFalse(), + true); + break; + } - Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1); + case OCF_UNSIGNED_SUB: + case OCF_SIGNED_SUB: { + // X - 0 -> {X, false} + if (match(RHS, m_Zero())) + return SetResult(LHS, Builder->getFalse(), false); - // Put the new code above the original add, in case there are any uses of the - // add between the add and the compare. - InstCombiner::BuilderTy *Builder = IC.Builder; - Builder->SetInsertPoint(OrigAdd); - - Module *M = I.getParent()->getParent()->getParent(); - Type *Ty = LHS->getType(); - Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, Ty); - CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd"); - Value *Add = Builder->CreateExtractValue(Call, 0); + if (OCF == OCF_SIGNED_SUB) { + if (WillNotOverflowSignedSub(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNSWSub(LHS, RHS), Builder->getFalse(), + true); + } else { + if (WillNotOverflowUnsignedSub(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNUWSub(LHS, RHS), Builder->getFalse(), + true); + } + break; + } - IC.ReplaceInstUsesWith(*OrigAdd, Add); + case OCF_UNSIGNED_MUL: { + OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI); + if (OR == OverflowResult::NeverOverflows) + return SetResult(Builder->CreateNUWMul(LHS, RHS), Builder->getFalse(), + true); + if (OR == OverflowResult::AlwaysOverflows) + return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true); + } // FALL THROUGH + case OCF_SIGNED_MUL: + // X * undef -> undef + if (isa(RHS)) + return SetResult(RHS, UndefValue::get(Builder->getInt1Ty()), false); + + // X * 0 -> {0, false} + if (match(RHS, m_Zero())) + return SetResult(RHS, Builder->getFalse(), false); + + // X * 1 -> {X, false} + if (match(RHS, m_One())) + return SetResult(LHS, Builder->getFalse(), false); + + if (OCF == OCF_SIGNED_MUL) + if (WillNotOverflowSignedMul(LHS, RHS, OrigI)) + return SetResult(Builder->CreateNSWMul(LHS, RHS), Builder->getFalse(), + true); + break; + } - // The original icmp gets replaced with the overflow value. - return ExtractValueInst::Create(Call, 1, "uadd.overflow"); + return false; } /// \brief Recognize and process idiom involving test for multiplication @@ -2026,14 +2233,18 @@ static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV, /// replacement required. static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, Value *OtherVal, InstCombiner &IC) { + // Don't bother doing this transformation for pointers, don't do it for + // vectors. + if (!isa(MulVal->getType())) + return nullptr; + assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal); assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal); - assert(isa(MulVal->getType())); Instruction *MulInstr = cast(MulVal); assert(MulInstr->getOpcode() == Instruction::Mul); - Instruction *LHS = cast(MulInstr->getOperand(0)), - *RHS = cast(MulInstr->getOperand(1)); + auto *LHS = cast(MulInstr->getOperand(0)), + *RHS = cast(MulInstr->getOperand(1)); assert(LHS->getOpcode() == Instruction::ZExt); assert(RHS->getOpcode() == Instruction::ZExt); Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); @@ -2173,7 +2384,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, MulB = Builder->CreateZExt(B, MulType); Value *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType); - CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul"); + CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul"); IC.Worklist.Add(MulInstr); // If there are uses of mul result other than the comparison, we know that @@ -2318,6 +2529,122 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0, return GlobalSwapBenefits > 0; } +/// \brief Check that one use is in the same block as the definition and all +/// other uses are in blocks dominated by a given block +/// +/// \param DI Definition +/// \param UI Use +/// \param DB Block that must dominate all uses of \p DI outside +/// the parent block +/// \return true when \p UI is the only use of \p DI in the parent block +/// and all other uses of \p DI are in blocks dominated by \p DB. +/// +bool InstCombiner::dominatesAllUses(const Instruction *DI, + const Instruction *UI, + const BasicBlock *DB) const { + assert(DI && UI && "Instruction not defined\n"); + // ignore incomplete definitions + if (!DI->getParent()) + return false; + // DI and UI must be in the same block + if (DI->getParent() != UI->getParent()) + return false; + // Protect from self-referencing blocks + if (DI->getParent() == DB) + return false; + // DominatorTree available? + if (!DT) + return false; + for (const User *U : DI->users()) { + auto *Usr = cast(U); + if (Usr != UI && !DT->dominates(DB, Usr->getParent())) + return false; + } + return true; +} + +/// +/// true when the instruction sequence within a block is select-cmp-br. +/// +static bool isChainSelectCmpBranch(const SelectInst *SI) { + const BasicBlock *BB = SI->getParent(); + if (!BB) + return false; + auto *BI = dyn_cast_or_null(BB->getTerminator()); + if (!BI || BI->getNumSuccessors() != 2) + return false; + auto *IC = dyn_cast(BI->getCondition()); + if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI)) + return false; + return true; +} + +/// +/// \brief True when a select result is replaced by one of its operands +/// in select-icmp sequence. This will eventually result in the elimination +/// of the select. +/// +/// \param SI Select instruction +/// \param Icmp Compare instruction +/// \param SIOpd Operand that replaces the select +/// +/// Notes: +/// - The replacement is global and requires dominator information +/// - The caller is responsible for the actual replacement +/// +/// Example: +/// +/// entry: +/// %4 = select i1 %3, %C* %0, %C* null +/// %5 = icmp eq %C* %4, null +/// br i1 %5, label %9, label %7 +/// ... +/// ;