//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
}
void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
- APInt &KnownZero) {
+ APInt &KnownZero,
+ APInt &KnownOne) {
unsigned BitWidth = KnownZero.getBitWidth();
unsigned NumRanges = Ranges.getNumOperands() / 2;
assert(NumRanges >= 1);
- // Use the high end of the ranges to find leading zeros.
- unsigned MinLeadingZeros = BitWidth;
+ KnownZero.setAllBits();
+ KnownOne.setAllBits();
+
for (unsigned i = 0; i < NumRanges; ++i) {
ConstantInt *Lower =
mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0));
ConstantInt *Upper =
mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
ConstantRange Range(Lower->getValue(), Upper->getValue());
- if (Range.isWrappedSet())
- MinLeadingZeros = 0; // -1 has no zeros
- unsigned LeadingZeros = (Upper->getValue() - 1).countLeadingZeros();
- MinLeadingZeros = std::min(LeadingZeros, MinLeadingZeros);
- }
- KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros);
+ // The first CommonPrefixBits of all values in Range are equal.
+ unsigned CommonPrefixBits =
+ (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros();
+
+ APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
+ KnownOne &= Range.getUnsignedMax() & Mask;
+ KnownZero &= ~Range.getUnsignedMax() & Mask;
+ }
}
static bool isEphemeralValueOf(Instruction *I, const Value *E) {
// calculation. Reusing the APInts here to prevent unnecessary allocations.
KnownZero.clearAllBits(), KnownOne.clearAllBits();
+ // If we know the shifter operand is nonzero, we can sometimes infer more
+ // known bits. However this is expensive to compute, so be lazy about it and
+ // only compute it when absolutely necessary.
+ Optional<bool> ShifterOperandIsNonZero;
+
// Early exit if we can't constrain any well-defined shift amount.
- if (!(ShiftAmtKZ & (BitWidth-1)) && !(ShiftAmtKO & (BitWidth-1)))
- return;
+ if (!(ShiftAmtKZ & (BitWidth - 1)) && !(ShiftAmtKO & (BitWidth - 1))) {
+ ShifterOperandIsNonZero =
+ isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q);
+ if (!*ShifterOperandIsNonZero)
+ return;
+ }
computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q);
continue;
if ((ShiftAmt | ShiftAmtKO) != ShiftAmt)
continue;
+ // If we know the shifter is nonzero, we may be able to infer more known
+ // bits. This check is sunk down as far as possible to avoid the expensive
+ // call to isKnownNonZero if the cheaper checks above fail.
+ if (ShiftAmt == 0) {
+ if (!ShifterOperandIsNonZero.hasValue())
+ ShifterOperandIsNonZero =
+ isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q);
+ if (*ShifterOperandIsNonZero)
+ continue;
+ }
KnownZero &= KZF(KnownZero2, ShiftAmt);
KnownOne &= KOF(KnownOne2, ShiftAmt);
default: break;
case Instruction::Load:
if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range))
- computeKnownBitsFromRangeMetadata(*MD, KnownZero);
+ computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne);
break;
case Instruction::And: {
// If either the LHS or the RHS are Zero, the result is zero.
case Instruction::Call:
case Instruction::Invoke:
if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range))
- computeKnownBitsFromRangeMetadata(*MD, KnownZero);
+ computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne);
// If a range metadata is attached to this IntrinsicInst, intersect the
// explicit range specified by the metadata and the implicit range of
// the intrinsic.
return CR;
}
+
+bool llvm::isImpliedCondition(Value *LHS, Value *RHS) {
+ assert(LHS->getType() == RHS->getType() && "mismatched type");
+ Type *OpTy = LHS->getType();
+ assert(OpTy->getScalarType()->isIntegerTy(1));
+
+ // LHS ==> RHS by definition
+ if (LHS == RHS) return true;
+
+ if (OpTy->isVectorTy())
+ // TODO: extending the code below to handle vectors
+ return false;
+ assert(OpTy->isIntegerTy(1) && "implied by above");
+
+ ICmpInst::Predicate APred, BPred;
+ Value *I;
+ Value *L;
+ ConstantInt *CI;
+ // i +_{nsw} C_{>0} <s L ==> i <s L
+ if (match(LHS, m_ICmp(APred,
+ m_NSWAdd(m_Value(I), m_ConstantInt(CI)),
+ m_Value(L))) &&
+ APred == ICmpInst::ICMP_SLT &&
+ !CI->isNegative() &&
+ match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
+ BPred == ICmpInst::ICMP_SLT)
+ return true;
+
+ // i +_{nuw} C_{>0} <u L ==> i <u L
+ if (match(LHS, m_ICmp(APred,
+ m_NUWAdd(m_Value(I), m_ConstantInt(CI)),
+ m_Value(L))) &&
+ APred == ICmpInst::ICMP_ULT &&
+ !CI->isNegative() &&
+ match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
+ BPred == ICmpInst::ICMP_ULT)
+ return true;
+
+ return false;
+}