diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index f7075dfc674..ab0d1b10c24 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -11,7 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "InstCombine.h"
+#include "InstCombineInternal.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ConstantFolding.h"
@@ -24,7 +24,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/Target/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -229,10 +229,6 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero,
 Instruction *InstCombiner::
 FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
                              CmpInst &ICI, ConstantInt *AndCst) {
-  // We need TD information to know the pointer size unless this is inbounds.
-  if (!GEP->isInBounds() && !DL)
-    return nullptr;
-
   Constant *Init = GV->getInitializer();
   if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
     return nullptr;
@@ -303,7 +299,6 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // the array, this will fully represent all the comparison results.
   uint64_t MagicBitvector = 0;
 
-
   // Scan the array and see if one of our patterns matches.
   Constant *CompareRHS = cast<Constant>(ICI.getOperand(1));
   for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) {
@@ -398,7 +393,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // index down like the GEP would do implicitly.  We don't have to do this for
   // an inbounds GEP because the index can't be out of range.
   if (!GEP->isInBounds()) {
-    Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
+    Type *IntPtrTy = DL.getIntPtrType(GEP->getType());
     unsigned PtrSize = IntPtrTy->getIntegerBitWidth();
     if (Idx->getType()->getPrimitiveSizeInBits() > PtrSize)
       Idx = Builder->CreateTrunc(Idx, IntPtrTy);
@@ -487,10 +482,8 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   //   - Default to i32
   if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth())
     Ty = Idx->getType();
-  else if (DL)
-    Ty = DL->getSmallestLegalIntType(Init->getContext(), ArrayElementCount);
-  else if (ArrayElementCount <= 32)
-    Ty = Type::getInt32Ty(Init->getContext());
+  else
+    Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount);
 
   if (Ty) {
     Value *V = Builder->CreateIntCast(Idx, Ty, false);
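FoldCmpLoadFromIndexedGlobal turns a compare of a load from a constant global array into a bit test on the index via the precomputed MagicBitvector; the hunks above only drop the null-DataLayout early-outs now that a DataLayout is always available. A source-level sketch of the input pattern (hypothetical table, not taken from this patch):

    // Comparing a load from a constant table against a constant: InstCombine
    // precomputes which elements satisfy "== 20" and tests the corresponding
    // bit of a "magic bitvector" indexed by i, instead of emitting the load.
    static const int Table[4] = {10, 20, 30, 20};
    bool hasTwenty(unsigned i) { return Table[i] == 20; }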
@@ -514,8 +507,8 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 ///
 /// If we can't emit an optimized form for this expression, this returns null.
 ///
-static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
-  const DataLayout &DL = *IC.getDataLayout();
+static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC,
+                                          const DataLayout &DL) {
   gep_type_iterator GTI = gep_type_begin(GEP);
 
   // Check to see if this gep only has a single variable index.  If so, and if
@@ -628,12 +621,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     RHS = RHS->stripPointerCasts();
 
   Value *PtrBase = GEPLHS->getOperand(0);
-  if (DL && PtrBase == RHS && GEPLHS->isInBounds()) {
+  if (PtrBase == RHS && GEPLHS->isInBounds()) {
     // ((gep Ptr, OFFSET) cmp Ptr)   ---> (OFFSET cmp 0).
     // This transformation (ignoring the base and scales) is valid because we
     // know pointers can't overflow since the gep is inbounds.  See if we can
     // output an optimized form.
-    Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this);
+    Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this, DL);
 
     // If not, synthesize the offset the hard way.
     if (!Offset)
@@ -661,11 +654,11 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     // If we're comparing GEPs with two base pointers that only differ in type
     // and both GEPs have only constant indices or just one use, then fold
     // the compare with the adjusted indices.
-    if (DL && GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
+    if (GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
         (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
         (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
         PtrBase->stripPointerCasts() ==
-        GEPRHS->getOperand(0)->stripPointerCasts()) {
+            GEPRHS->getOperand(0)->stripPointerCasts()) {
       Value *LOffset = EmitGEPOffset(GEPLHS);
       Value *ROffset = EmitGEPOffset(GEPRHS);
 
@@ -733,9 +726,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 
     // Only lower this if the icmp is the only user of the GEP or if we expect
     // the result to fold to a constant!
-    if (DL &&
-        GEPsInBounds &&
-        (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
+    if (GEPsInBounds && (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) &&
         (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) {
       // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2)  --->  (OFFSET1 cmp OFFSET2)
       Value *L = EmitGEPOffset(GEPLHS);
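The FoldGEPICmp hunks above delete the `DL &&` guards; the folds themselves are unchanged. The effect of the inbounds fold, sketched at the source level (illustrative function, not from this patch):

    // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0): inbounds pointer
    // arithmetic cannot wrap, so only the offset decides the compare.
    bool atBase(int *p, long i) { return p + i == p; }  // folds to: i == 0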
@@ -1928,17 +1919,20 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
 
   // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
   // integer type is the same size as the pointer type.
-  if (DL && LHSCI->getOpcode() == Instruction::PtrToInt &&
-      DL->getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) {
+  if (LHSCI->getOpcode() == Instruction::PtrToInt &&
+      DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth()) {
     Value *RHSOp = nullptr;
-    if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) {
+    if (PtrToIntOperator *RHSC = dyn_cast<PtrToIntOperator>(ICI.getOperand(1))) {
+      Value *RHSCIOp = RHSC->getOperand(0);
+      if (RHSCIOp->getType()->getPointerAddressSpace() ==
+          LHSCIOp->getType()->getPointerAddressSpace()) {
+        RHSOp = RHSC->getOperand(0);
+        // If the pointer types don't match, insert a bitcast.
+        if (LHSCIOp->getType() != RHSOp->getType())
+          RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType());
+      }
+    } else if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1)))
       RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy);
-    } else if (PtrToIntInst *RHSC = dyn_cast<PtrToIntInst>(ICI.getOperand(1))) {
-      RHSOp = RHSC->getOperand(0);
-      // If the pointer types don't match, insert a bitcast.
-      if (LHSCIOp->getType() != RHSOp->getType())
-        RHSOp = Builder->CreateBitCast(RHSOp, LHSCIOp->getType());
-    }
 
     if (RHSOp)
       return new ICmpInst(ICI.getPredicate(), LHSCIOp, RHSOp);
@@ -2103,7 +2097,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
 
   Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc");
   Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc");
-  CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB, "sadd");
+  CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd");
   Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result");
   Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType());
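Switching the cast fold from PtrToIntInst to PtrToIntOperator lets it also match ptrtoint constant expressions, and the new address-space check keeps it from equating pointers whose integer widths may differ. A source-level sketch of the pattern it handles (hypothetical helper, assuming uintptr_t is pointer-sized):

    #include <cstdint>
    // Round-tripping both pointers through pointer-sized integers folds back
    // to a direct pointer compare (with a bitcast if the types differ).
    bool samePlace(int *p, float *q) {
      return (std::uintptr_t)p == (std::uintptr_t)q;  // -> icmp eq on pointers
    }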
@@ -2115,33 +2109,95 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
-static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV,
-                                     InstCombiner &IC) {
-  // Don't bother doing this transformation for pointers, don't do it for
-  // vectors.
-  if (!isa<IntegerType>(OrigAddV->getType())) return nullptr;
+bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
+                                         Value *RHS, Instruction &OrigI,
+                                         Value *&Result, Constant *&Overflow) {
+  assert((!OrigI.isCommutative() ||
+          !(isa<Constant>(LHS) && !isa<Constant>(RHS))) &&
+         "call with a constant RHS if possible!");
+
+  auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) {
+    Result = OpResult;
+    Overflow = OverflowVal;
+    if (ReuseName)
+      Result->takeName(&OrigI);
+    return true;
+  };
 
-  // If the add is a constant expr, then we don't bother transforming it.
-  Instruction *OrigAdd = dyn_cast<Instruction>(OrigAddV);
-  if (!OrigAdd) return nullptr;
+  switch (OCF) {
+  case OCF_INVALID:
+    llvm_unreachable("bad overflow check kind!");
 
-  Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1);
+  case OCF_UNSIGNED_ADD: {
+    OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI);
+    if (OR == OverflowResult::NeverOverflows)
+      return SetResult(Builder->CreateNUWAdd(LHS, RHS), Builder->getFalse(),
+                       true);
 
-  // Put the new code above the original add, in case there are any uses of the
-  // add between the add and the compare.
-  InstCombiner::BuilderTy *Builder = IC.Builder;
-  Builder->SetInsertPoint(OrigAdd);
+    if (OR == OverflowResult::AlwaysOverflows)
+      return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true);
+  }
+  // FALL THROUGH uadd into sadd
+  case OCF_SIGNED_ADD: {
+    // X + 0 -> {X, false}
+    if (match(RHS, m_Zero()))
+      return SetResult(LHS, Builder->getFalse(), false);
+
+    // We can strength reduce this signed add into a regular add if we can prove
+    // that it will never overflow.
+    if (OCF == OCF_SIGNED_ADD)
+      if (WillNotOverflowSignedAdd(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNSWAdd(LHS, RHS), Builder->getFalse(),
+                         true);
+    break;
+  }
 
-  Module *M = I.getParent()->getParent()->getParent();
-  Type *Ty = LHS->getType();
-  Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, Ty);
-  CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd");
-  Value *Add = Builder->CreateExtractValue(Call, 0);
+  case OCF_UNSIGNED_SUB:
+  case OCF_SIGNED_SUB: {
+    // X - 0 -> {X, false}
+    if (match(RHS, m_Zero()))
+      return SetResult(LHS, Builder->getFalse(), false);
 
-  IC.ReplaceInstUsesWith(*OrigAdd, Add);
+    if (OCF == OCF_SIGNED_SUB) {
+      if (WillNotOverflowSignedSub(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNSWSub(LHS, RHS), Builder->getFalse(),
+                         true);
+    } else {
+      if (WillNotOverflowUnsignedSub(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNUWSub(LHS, RHS), Builder->getFalse(),
+                         true);
+    }
+    break;
+  }
 
-  // The original icmp gets replaced with the overflow value.
-  return ExtractValueInst::Create(Call, 1, "uadd.overflow");
+  case OCF_UNSIGNED_MUL: {
+    OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI);
+    if (OR == OverflowResult::NeverOverflows)
+      return SetResult(Builder->CreateNUWMul(LHS, RHS), Builder->getFalse(),
+                       true);
+    if (OR == OverflowResult::AlwaysOverflows)
+      return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true);
+  } // FALL THROUGH
+  case OCF_SIGNED_MUL:
+    // X * undef -> undef
+    if (isa<UndefValue>(RHS))
+      return SetResult(RHS, UndefValue::get(Builder->getInt1Ty()), false);
+
+    // X * 0 -> {0, false}
+    if (match(RHS, m_Zero()))
+      return SetResult(RHS, Builder->getFalse(), false);
+
+    // X * 1 -> {X, false}
+    if (match(RHS, m_One()))
+      return SetResult(LHS, Builder->getFalse(), false);
+
+    if (OCF == OCF_SIGNED_MUL)
+      if (WillNotOverflowSignedMul(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNSWMul(LHS, RHS), Builder->getFalse(),
+                         true);
+  }
+
+  return false;
 }
 
 /// \brief Recognize and process idiom involving test for multiplication
@@ -2311,7 +2367,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     MulB = Builder->CreateZExt(B, MulType);
   Value *F = Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow,
                                        MulType);
-  CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
+  CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul");
   IC.Worklist.Add(MulInstr);
 
   // If there are uses of mul result other than the comparison, we know that
@@ -2657,8 +2713,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   unsigned BitWidth = 0;
   if (Ty->isIntOrIntVectorTy())
     BitWidth = Ty->getScalarSizeInBits();
-  else if (DL)  // Pointers require DL info to get their size.
-    BitWidth = DL->getTypeSizeInBits(Ty->getScalarType());
+  else // Get pointer size.
+    BitWidth = DL.getTypeSizeInBits(Ty->getScalarType());
 
   bool isSignBit = false;
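OptimizeOverflowCheck replaces the old uadd-only ProcessUAddIdiom with a single entry point for signed and unsigned add, sub, and mul: it folds the overflow bit to a constant when ValueTracking can decide it, and otherwise may still strength-reduce the arithmetic by adding nuw/nsw flags. A sketch of a check it can fold away completely (hypothetical function; the m_UAddWithOverflow caller that feeds it appears further down in this patch):

    // 'sum < addend' is the unsigned-add overflow idiom. Both addends here
    // are below 2^31, so the 32-bit sum cannot wrap: the compare folds to
    // false and the add becomes 'add nuw'.
    bool wraps(unsigned a, unsigned b) {
      unsigned lo_a = a >> 1, lo_b = b >> 1;
      return lo_a + lo_b < lo_a;  // provably false
    }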
@@ -2686,33 +2742,35 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         return Res;
     }
 
-    // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)
-    if (I.isEquality() && CI->isZero() &&
-        match(Op0, m_Sub(m_Value(A), m_Value(B)))) {
-      // (icmp cond A B) if cond is equality
-      return new ICmpInst(I.getPredicate(), A, B);
+    // The following transforms are only 'worth it' if the only user of the
+    // subtraction is the icmp.
+    if (Op0->hasOneUse()) {
+      // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B)
+      if (I.isEquality() && CI->isZero() &&
+          match(Op0, m_Sub(m_Value(A), m_Value(B))))
+        return new ICmpInst(I.getPredicate(), A, B);
+
+      // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SGE, A, B);
+
+      // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SGT, A, B);
+
+      // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, A, B);
+
+      // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B)
+      if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() &&
+          match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
+        return new ICmpInst(ICmpInst::ICMP_SLE, A, B);
     }
 
-    // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B)
-    if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() &&
-        match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-      return new ICmpInst(ICmpInst::ICMP_SGE, A, B);
-
-    // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B)
-    if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() &&
-        match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-      return new ICmpInst(ICmpInst::ICMP_SGT, A, B);
-
-    // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B)
-    if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() &&
-        match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-      return new ICmpInst(ICmpInst::ICMP_SLT, A, B);
-
-    // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B)
-    if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() &&
-        match(Op0, m_NSWSub(m_Value(A), m_Value(B))))
-      return new ICmpInst(ICmpInst::ICMP_SLE, A, B);
-
     // If we have an icmp le or icmp ge instruction, turn it into the
     // appropriate icmp lt or icmp gt instruction.  This allows us to rely on
     // them being folded in the code below.  The SimplifyICmpInst code has
@@ -2769,8 +2827,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                              Op0KnownZero, Op0KnownOne, 0))
       return &I;
     if (SimplifyDemandedBits(I.getOperandUse(1),
-                             APInt::getAllOnesValue(BitWidth),
-                             Op1KnownZero, Op1KnownOne, 0))
+                             APInt::getAllOnesValue(BitWidth), Op1KnownZero,
+                             Op1KnownOne, 0))
       return &I;
 
     // Given the known and unknown bits, compute a range that the LHS could be
@@ -3089,9 +3147,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       }
     case Instruction::IntToPtr:
       // icmp pred inttoptr(X), null -> icmp pred X, 0
-      if (RHSC->isNullValue() && DL &&
-          DL->getIntPtrType(RHSC->getType()) ==
-              LHSI->getOperand(0)->getType())
+      if (RHSC->isNullValue() &&
+          DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType())
         return new ICmpInst(I.getPredicate(), LHSI->getOperand(0),
                             Constant::getNullValue(LHSI->getOperand(0)->getType()));
       break;
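Wrapping the sub-vs-constant folds in Op0->hasOneUse() means they fire only when the icmp is the subtraction's sole user, so the sub actually disappears afterwards. Source-level shape of the nsw folds (a sketch; signed overflow is UB in C++, which is what lets the frontend emit the sub as nsw):

    // Comparing a no-wrap difference against -1/0/1 becomes a direct
    // signed compare of the operands.
    bool ge(int a, int b) { return a - b > -1; }  // -> a >= b
    bool lt(int a, int b) { return a - b < 0; }   // -> a < b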
@@ -3423,7 +3480,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   // if A is a power of 2.
   if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
       match(Op1, m_Zero()) &&
-      isKnownToBeAPowerOfTwo(A, false, 0, AC, &I, DT) && I.isEquality())
+      isKnownToBeAPowerOfTwo(A, DL, false, 0, AC, &I, DT) && I.isEquality())
     return new ICmpInst(I.getInversePredicate(),
                         Builder->CreateAnd(A, B),
                         Op1);
@@ -3437,21 +3494,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     return new ICmpInst(I.getPredicate(), ConstantExpr::getNot(RHSC), A);
   }
 
-  // (a+b) <u a  --> llvm.uadd.with.overflow.
-  // (a+b) <u b  --> llvm.uadd.with.overflow.
-  if (I.getPredicate() == ICmpInst::ICMP_ULT &&
-      match(Op0, m_Add(m_Value(A), m_Value(B))) &&
-      (Op1 == A || Op1 == B))
-    if (Instruction *R = ProcessUAddIdiom(I, Op0, *this))
-      return R;
-
-  // a >u (a+b)  --> llvm.uadd.with.overflow.
-  // b >u (a+b)  --> llvm.uadd.with.overflow.
-  if (I.getPredicate() == ICmpInst::ICMP_UGT &&
-      match(Op1, m_Add(m_Value(A), m_Value(B))) &&
-      (Op0 == A || Op0 == B))
-    if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
-      return R;
+  Instruction *AddI = nullptr;
+  if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B),
+                                   m_Instruction(AddI))) &&
+      isa<IntegerType>(A->getType())) {
+    Value *Result;
+    Constant *Overflow;
+    if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result,
+                              Overflow)) {
+      ReplaceInstUsesWith(*AddI, Result);
+      return ReplaceInstUsesWith(I, Overflow);
+    }
+  }
 
   // (zext a) * (zext b)  --> llvm.umul.with.overflow.
   if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
@@ -3558,6 +3612,21 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     }
   }
 
+  // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0
+  if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) &&
+      match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) {
+    unsigned TypeBits = Cst1->getBitWidth();
+    unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits);
+    if (ShAmt < TypeBits && ShAmt != 0) {
+      Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted");
+      APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt);
+      Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal),
+                                      I.getName() + ".mask");
+      return new ICmpInst(I.getPredicate(), And,
+                          Constant::getNullValue(Cst1->getType()));
+    }
+  }
+
   // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to
   // "icmp (and X, mask), cst"
   uint64_t ShAmt = 0;
@@ -3884,6 +3953,19 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     }
   }
 
+  // Test if the FCmpInst instruction is used exclusively by a select as
+  // part of a minimum or maximum operation. If so, refrain from doing
+  // any other folding. This helps out other analyses which understand
+  // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
+  // and CodeGen. And in this case, at least one of the comparison
+  // operands has at least one user besides the compare (the select),
+  // which would often largely negate the benefit of folding anyway.
+  if (I.hasOneUse())
+    if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
+      if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
+          (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+        return nullptr;
+
   // Handle fcmp with constant RHS
   if (Constant *RHSC = dyn_cast<Constant>(Op1)) {
     if (Instruction *LHSI = dyn_cast<Instruction>(Op0))
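Of the additions at the end, the shl-equality fold is the arithmetic one: shifting both sides left by the same amount discards the same high bits, so only the surviving low bits need comparing. A concrete instance (32-bit unsigned, shift of 3, so the mask keeps the low 29 bits):

    bool eqShifted(unsigned a, unsigned b) {
      // (a << 3) == (b << 3)  -->  ((a ^ b) & 0x1fffffffu) == 0
      return (a << 3) == (b << 3);
    }

The fcmp change, by contrast, deliberately folds nothing: an fcmp whose only user is a min/max-style select is left intact so that ScalarEvolution and CodeGen can still recognize the idiom.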