X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FAnalysis%2FValueTracking.cpp;h=f4824aebe525610a5197aa5bc102e0f88bc8193b;hp=bb4220d6164d5d0e2f3e2261ac69b609119c17dc;hb=18e290023d9be76e117ce4f030306bad3fa9cfea;hpb=ba3f3a65e64fe2cf1f492d90499928270fc1a426 diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index bb4220d6164..f4824aebe52 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ValueTracking.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -366,26 +367,30 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, } void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, - APInt &KnownZero) { + APInt &KnownZero, + APInt &KnownOne) { unsigned BitWidth = KnownZero.getBitWidth(); unsigned NumRanges = Ranges.getNumOperands() / 2; assert(NumRanges >= 1); - // Use the high end of the ranges to find leading zeros. - unsigned MinLeadingZeros = BitWidth; + KnownZero.setAllBits(); + KnownOne.setAllBits(); + for (unsigned i = 0; i < NumRanges; ++i) { ConstantInt *Lower = mdconst::extract(Ranges.getOperand(2 * i + 0)); ConstantInt *Upper = mdconst::extract(Ranges.getOperand(2 * i + 1)); ConstantRange Range(Lower->getValue(), Upper->getValue()); - if (Range.isWrappedSet()) - MinLeadingZeros = 0; // -1 has no zeros - unsigned LeadingZeros = (Upper->getValue() - 1).countLeadingZeros(); - MinLeadingZeros = std::min(LeadingZeros, MinLeadingZeros); - } - KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros); + // The first CommonPrefixBits of all values in Range are equal. + unsigned CommonPrefixBits = + (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros(); + + APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits); + KnownOne &= Range.getUnsignedMax() & Mask; + KnownZero &= ~Range.getUnsignedMax() & Mask; + } } static bool isEphemeralValueOf(Instruction *I, const Value *E) { @@ -1004,9 +1009,18 @@ static void computeKnownBitsFromShiftOperator(Operator *I, // calculation. Reusing the APInts here to prevent unnecessary allocations. KnownZero.clearAllBits(), KnownOne.clearAllBits(); + // If we know the shifter operand is nonzero, we can sometimes infer more + // known bits. However this is expensive to compute, so be lazy about it and + // only compute it when absolutely necessary. + Optional ShifterOperandIsNonZero; + // Early exit if we can't constrain any well-defined shift amount. - if (!(ShiftAmtKZ & (BitWidth-1)) && !(ShiftAmtKO & (BitWidth-1))) - return; + if (!(ShiftAmtKZ & (BitWidth - 1)) && !(ShiftAmtKO & (BitWidth - 1))) { + ShifterOperandIsNonZero = + isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q); + if (!*ShifterOperandIsNonZero) + return; + } computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q); @@ -1018,6 +1032,16 @@ static void computeKnownBitsFromShiftOperator(Operator *I, continue; if ((ShiftAmt | ShiftAmtKO) != ShiftAmt) continue; + // If we know the shifter is nonzero, we may be able to infer more known + // bits. This check is sunk down as far as possible to avoid the expensive + // call to isKnownNonZero if the cheaper checks above fail. + if (ShiftAmt == 0) { + if (!ShifterOperandIsNonZero.hasValue()) + ShifterOperandIsNonZero = + isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q); + if (*ShifterOperandIsNonZero) + continue; + } KnownZero &= KZF(KnownZero2, ShiftAmt); KnownOne &= KOF(KnownOne2, ShiftAmt); @@ -1042,7 +1066,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, default: break; case Instruction::Load: if (MDNode *MD = cast(I)->getMetadata(LLVMContext::MD_range)) - computeKnownBitsFromRangeMetadata(*MD, KnownZero); + computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne); break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. @@ -1434,7 +1458,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, case Instruction::Call: case Instruction::Invoke: if (MDNode *MD = cast(I)->getMetadata(LLVMContext::MD_range)) - computeKnownBitsFromRangeMetadata(*MD, KnownZero); + computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne); // If a range metadata is attached to this IntrinsicInst, intersect the // explicit range specified by the metadata and the implicit range of // the intrinsic. @@ -4057,3 +4081,89 @@ ConstantRange llvm::getConstantRangeFromMetadata(MDNode &Ranges) { return CR; } + +/// Return true if "icmp Pred LHS RHS" is always true. +static bool isTruePredicate(CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS) + return true; + + switch (Pred) { + default: + return false; + + case CmpInst::ICMP_SLT: + case CmpInst::ICMP_SLE: { + ConstantInt *CI; + + // LHS s< LHS +_{nsw} C if C > 0 + // LHS s<= LHS +_{nsw} C if C >= 0 + if (match(RHS, m_NSWAdd(m_Specific(LHS), m_ConstantInt(CI)))) { + if (Pred == CmpInst::ICMP_SLT) + return CI->getValue().isStrictlyPositive(); + return !CI->isNegative(); + } + return false; + } + + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_ULE: { + ConstantInt *CI; + + // LHS u< LHS +_{nuw} C if C != 0 + // LHS u<= LHS +_{nuw} C + if (match(RHS, m_NUWAdd(m_Specific(LHS), m_ConstantInt(CI)))) { + if (Pred == CmpInst::ICMP_ULT) + return !CI->isZero(); + return true; + } + return false; + } + } +} + +/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred +/// ALHS ARHS" is true. +static bool isImpliedCondOperands(CmpInst::Predicate Pred, Value *ALHS, + Value *ARHS, Value *BLHS, Value *BRHS) { + switch (Pred) { + default: + return false; + + case CmpInst::ICMP_SLT: + case CmpInst::ICMP_SLE: + return isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS) && + isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS); + + case CmpInst::ICMP_ULT: + case CmpInst::ICMP_ULE: + return isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS) && + isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS); + } +} + +bool llvm::isImpliedCondition(Value *LHS, Value *RHS) { + assert(LHS->getType() == RHS->getType() && "mismatched type"); + Type *OpTy = LHS->getType(); + assert(OpTy->getScalarType()->isIntegerTy(1)); + + // LHS ==> RHS by definition + if (LHS == RHS) return true; + + if (OpTy->isVectorTy()) + // TODO: extending the code below to handle vectors + return false; + assert(OpTy->isIntegerTy(1) && "implied by above"); + + ICmpInst::Predicate APred, BPred; + Value *ALHS, *ARHS; + Value *BLHS, *BRHS; + + if (!match(LHS, m_ICmp(APred, m_Value(ALHS), m_Value(ARHS))) || + !match(RHS, m_ICmp(BPred, m_Value(BLHS), m_Value(BRHS)))) + return false; + + if (APred == BPred) + return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS); + + return false; +}