#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
// Implementation of the SCEV class.
//
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD
void SCEV::dump() const {
print(dbgs());
dbgs() << '\n';
}
-#endif
void SCEV::print(raw_ostream &OS) const {
switch (static_cast<SCEVTypes>(getSCEVType())) {
if (!SC) return false;
// Return true if the value is negative, this matches things like (-42 * V).
- return SC->getValue()->getValue().isNegative();
+ return SC->getAPInt().isNegative();
}
SCEVCouldNotCompute::SCEVCouldNotCompute() :
//===----------------------------------------------------------------------===//
namespace {
- /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
- /// than the complexity of the RHS. This comparator is used to canonicalize
- /// expressions.
- class SCEVComplexityCompare {
- const LoopInfo *const LI;
- public:
- explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
-
- // Return true or false if LHS is less than, or at least RHS, respectively.
- bool operator()(const SCEV *LHS, const SCEV *RHS) const {
- return compare(LHS, RHS) < 0;
- }
-
- // Return negative, zero, or positive, if LHS is less than, equal to, or
- // greater than RHS, respectively. A three-way result allows recursive
- // comparisons to be more efficient.
- int compare(const SCEV *LHS, const SCEV *RHS) const {
- // Fast-path: SCEVs are uniqued so we can do a quick equality check.
- if (LHS == RHS)
- return 0;
-
- // Primarily, sort the SCEVs by their getSCEVType().
- unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
- if (LType != RType)
- return (int)LType - (int)RType;
-
- // Aside from the getSCEVType() ordering, the particular ordering
- // isn't very important except that it's beneficial to be consistent,
- // so that (a + b) and (b + a) don't end up as different expressions.
- switch (static_cast<SCEVTypes>(LType)) {
- case scUnknown: {
- const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
- const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
-
- // Sort SCEVUnknown values with some loose heuristics. TODO: This is
- // not as complete as it could be.
- const Value *LV = LU->getValue(), *RV = RU->getValue();
-
- // Order pointer values after integer values. This helps SCEVExpander
- // form GEPs.
- bool LIsPointer = LV->getType()->isPointerTy(),
- RIsPointer = RV->getType()->isPointerTy();
- if (LIsPointer != RIsPointer)
- return (int)LIsPointer - (int)RIsPointer;
-
- // Compare getValueID values.
- unsigned LID = LV->getValueID(),
- RID = RV->getValueID();
- if (LID != RID)
- return (int)LID - (int)RID;
-
- // Sort arguments by their position.
- if (const Argument *LA = dyn_cast<Argument>(LV)) {
- const Argument *RA = cast<Argument>(RV);
- unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
- return (int)LArgNo - (int)RArgNo;
- }
-
- // For instructions, compare their loop depth, and their operand
- // count. This is pretty loose.
- if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
- const Instruction *RInst = cast<Instruction>(RV);
-
- // Compare loop depths.
- const BasicBlock *LParent = LInst->getParent(),
- *RParent = RInst->getParent();
- if (LParent != RParent) {
- unsigned LDepth = LI->getLoopDepth(LParent),
- RDepth = LI->getLoopDepth(RParent);
- if (LDepth != RDepth)
- return (int)LDepth - (int)RDepth;
- }
-
- // Compare the number of operands.
- unsigned LNumOps = LInst->getNumOperands(),
- RNumOps = RInst->getNumOperands();
- return (int)LNumOps - (int)RNumOps;
- }
+/// SCEVComplexityCompare - Return true if the complexity of the LHS is less
+/// than the complexity of the RHS. This comparator is used to canonicalize
+/// expressions.
+class SCEVComplexityCompare {
+ const LoopInfo *const LI;
+public:
+ explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
- return 0;
- }
+ // Return true or false if LHS is less than, or at least RHS, respectively.
+ bool operator()(const SCEV *LHS, const SCEV *RHS) const {
+ return compare(LHS, RHS) < 0;
+ }
- case scConstant: {
- const SCEVConstant *LC = cast<SCEVConstant>(LHS);
- const SCEVConstant *RC = cast<SCEVConstant>(RHS);
-
- // Compare constant values.
- const APInt &LA = LC->getValue()->getValue();
- const APInt &RA = RC->getValue()->getValue();
- unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
- if (LBitWidth != RBitWidth)
- return (int)LBitWidth - (int)RBitWidth;
- return LA.ult(RA) ? -1 : 1;
+ // Return negative, zero, or positive, if LHS is less than, equal to, or
+ // greater than RHS, respectively. A three-way result allows recursive
+ // comparisons to be more efficient.
+ int compare(const SCEV *LHS, const SCEV *RHS) const {
+ // Fast-path: SCEVs are uniqued so we can do a quick equality check.
+ if (LHS == RHS)
+ return 0;
+
+ // Primarily, sort the SCEVs by their getSCEVType().
+ unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
+ if (LType != RType)
+ return (int)LType - (int)RType;
+
+ // Aside from the getSCEVType() ordering, the particular ordering
+ // isn't very important except that it's beneficial to be consistent,
+ // so that (a + b) and (b + a) don't end up as different expressions.
+ switch (static_cast<SCEVTypes>(LType)) {
+ case scUnknown: {
+ const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
+ const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
+
+ // Sort SCEVUnknown values with some loose heuristics. TODO: This is
+ // not as complete as it could be.
+ const Value *LV = LU->getValue(), *RV = RU->getValue();
+
+ // Order pointer values after integer values. This helps SCEVExpander
+ // form GEPs.
+ bool LIsPointer = LV->getType()->isPointerTy(),
+ RIsPointer = RV->getType()->isPointerTy();
+ if (LIsPointer != RIsPointer)
+ return (int)LIsPointer - (int)RIsPointer;
+
+ // Compare getValueID values.
+ unsigned LID = LV->getValueID(),
+ RID = RV->getValueID();
+ if (LID != RID)
+ return (int)LID - (int)RID;
+
+ // Sort arguments by their position.
+ if (const Argument *LA = dyn_cast<Argument>(LV)) {
+ const Argument *RA = cast<Argument>(RV);
+ unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
+ return (int)LArgNo - (int)RArgNo;
}
- case scAddRecExpr: {
- const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
- const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
-
- // Compare addrec loop depths.
- const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
- if (LLoop != RLoop) {
- unsigned LDepth = LLoop->getLoopDepth(),
- RDepth = RLoop->getLoopDepth();
+ // For instructions, compare their loop depth, and their operand
+ // count. This is pretty loose.
+ if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
+ const Instruction *RInst = cast<Instruction>(RV);
+
+ // Compare loop depths.
+ const BasicBlock *LParent = LInst->getParent(),
+ *RParent = RInst->getParent();
+ if (LParent != RParent) {
+ unsigned LDepth = LI->getLoopDepth(LParent),
+ RDepth = LI->getLoopDepth(RParent);
if (LDepth != RDepth)
return (int)LDepth - (int)RDepth;
}
- // Addrec complexity grows with operand count.
- unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
- if (LNumOps != RNumOps)
- return (int)LNumOps - (int)RNumOps;
+ // Compare the number of operands.
+ unsigned LNumOps = LInst->getNumOperands(),
+ RNumOps = RInst->getNumOperands();
+ return (int)LNumOps - (int)RNumOps;
+ }
- // Lexicographically compare.
- for (unsigned i = 0; i != LNumOps; ++i) {
- long X = compare(LA->getOperand(i), RA->getOperand(i));
- if (X != 0)
- return X;
- }
+ return 0;
+ }
+
+ case scConstant: {
+ const SCEVConstant *LC = cast<SCEVConstant>(LHS);
+ const SCEVConstant *RC = cast<SCEVConstant>(RHS);
- return 0;
+ // Compare constant values.
+ const APInt &LA = LC->getAPInt();
+ const APInt &RA = RC->getAPInt();
+ unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
+ if (LBitWidth != RBitWidth)
+ return (int)LBitWidth - (int)RBitWidth;
+ return LA.ult(RA) ? -1 : 1;
+ }
+
+ case scAddRecExpr: {
+ const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
+ const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
+
+ // Compare addrec loop depths.
+ const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
+ if (LLoop != RLoop) {
+ unsigned LDepth = LLoop->getLoopDepth(),
+ RDepth = RLoop->getLoopDepth();
+ if (LDepth != RDepth)
+ return (int)LDepth - (int)RDepth;
}
- case scAddExpr:
- case scMulExpr:
- case scSMaxExpr:
- case scUMaxExpr: {
- const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
- const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
-
- // Lexicographically compare n-ary expressions.
- unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
- if (LNumOps != RNumOps)
- return (int)LNumOps - (int)RNumOps;
-
- for (unsigned i = 0; i != LNumOps; ++i) {
- if (i >= RNumOps)
- return 1;
- long X = compare(LC->getOperand(i), RC->getOperand(i));
- if (X != 0)
- return X;
- }
+ // Addrec complexity grows with operand count.
+ unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
+ if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
+
+ // Lexicographically compare.
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ long X = compare(LA->getOperand(i), RA->getOperand(i));
+ if (X != 0)
+ return X;
}
- case scUDivExpr: {
- const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
- const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
+ return 0;
+ }
+
+ case scAddExpr:
+ case scMulExpr:
+ case scSMaxExpr:
+ case scUMaxExpr: {
+ const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
+ const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
- // Lexicographically compare udiv expressions.
- long X = compare(LC->getLHS(), RC->getLHS());
+ // Lexicographically compare n-ary expressions.
+ unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
+
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ if (i >= RNumOps)
+ return 1;
+ long X = compare(LC->getOperand(i), RC->getOperand(i));
if (X != 0)
return X;
- return compare(LC->getRHS(), RC->getRHS());
}
+ return (int)LNumOps - (int)RNumOps;
+ }
- case scTruncate:
- case scZeroExtend:
- case scSignExtend: {
- const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
- const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+ case scUDivExpr: {
+ const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
+ const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
- // Compare cast expressions by operand.
- return compare(LC->getOperand(), RC->getOperand());
- }
+ // Lexicographically compare udiv expressions.
+ long X = compare(LC->getLHS(), RC->getLHS());
+ if (X != 0)
+ return X;
+ return compare(LC->getRHS(), RC->getRHS());
+ }
- case scCouldNotCompute:
- llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
- }
- llvm_unreachable("Unknown SCEV kind!");
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend: {
+ const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
+ const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+
+ // Compare cast expressions by operand.
+ return compare(LC->getOperand(), RC->getOperand());
}
- };
-}
+
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+ }
+};
+} // end anonymous namespace
/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
}
}
-namespace {
-struct FindSCEVSize {
- int Size;
- FindSCEVSize() : Size(0) {}
-
- bool follow(const SCEV *S) {
- ++Size;
- // Keep looking at all operands of S.
- return true;
- }
- bool isDone() const {
- return false;
- }
-};
-}
-
// Returns the size of the SCEV S.
static inline int sizeOfSCEV(const SCEV *S) {
+ struct FindSCEVSize {
+ int Size;
+ FindSCEVSize() : Size(0) {}
+
+ bool follow(const SCEV *S) {
+ ++Size;
+ // Keep looking at all operands of S.
+ return true;
+ }
+ bool isDone() const {
+ return false;
+ }
+ };
+
FindSCEVSize F;
SCEVTraversal<FindSCEVSize> ST(F);
ST.visitAll(S);
void visitConstant(const SCEVConstant *Numerator) {
if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
- APInt NumeratorVal = Numerator->getValue()->getValue();
- APInt DenominatorVal = D->getValue()->getValue();
+ APInt NumeratorVal = Numerator->getAPInt();
+ APInt DenominatorVal = D->getAPInt();
uint32_t NumeratorBW = NumeratorVal.getBitWidth();
uint32_t DenominatorBW = DenominatorVal.getBitWidth();
// `Step`:
// 1. NSW/NUW flags on the step increment.
- const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
+ auto PreStartFlags =
+ ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
+ const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
if (!StartC)
return false;
- APInt StartAI = StartC->getValue()->getValue();
+ APInt StartAI = StartC->getAPInt();
for (unsigned Delta : {-2, -1, 1, 2}) {
const SCEV *PreStart = getConstant(StartAI - Delta);
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddRecExpr);
+ ID.AddPointer(PreStart);
+ ID.AddPointer(Step);
+ ID.AddPointer(L);
+ void *IP = nullptr;
+ const auto *PreAR =
+ static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+
// Give up if we don't already have the add recurrence we need because
// actually constructing an add recurrence is relatively expensive.
- const SCEVAddRecExpr *PreAR = [&]() {
- FoldingSetNodeID ID;
- ID.AddInteger(scAddRecExpr);
- ID.AddPointer(PreStart);
- ID.AddPointer(Step);
- ID.AddPointer(L);
- void *IP = nullptr;
- return static_cast<SCEVAddRecExpr *>(
- this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
- }();
-
if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
}
}
+ if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
+ // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
+ if (SA->getNoWrapFlags(SCEV::FlagNUW)) {
+ // If the addition does not unsign overflow then we can, by definition,
+ // commute the zero extension with the addition operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SA->operands())
+ Ops.push_back(getZeroExtendExpr(Op, Ty));
+ return getAddExpr(Ops, SCEV::FlagNUW);
+ }
+ }
+
// The cast wasn't folded; create an explicit cast node.
// Recompute the insert position, as it may have been invalidated.
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
if (SMul && SC1) {
if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
- const APInt &C1 = SC1->getValue()->getValue();
- const APInt &C2 = SC2->getValue()->getValue();
+ const APInt &C1 = SC1->getAPInt();
+ const APInt &C2 = SC2->getAPInt();
if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
C2.ugt(C1) && C2.isPowerOf2())
return getAddExpr(getSignExtendExpr(SC1, Ty),
auto *SC1 = dyn_cast<SCEVConstant>(Start);
auto *SC2 = dyn_cast<SCEVConstant>(Step);
if (SC1 && SC2) {
- const APInt &C1 = SC1->getValue()->getValue();
- const APInt &C2 = SC2->getValue()->getValue();
+ const APInt &C1 = SC1->getAPInt();
+ const APInt &C2 = SC2->getAPInt();
if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
C2.isPowerOf2()) {
Start = getSignExtendExpr(Start, Ty);
// Sign-extend negative constants.
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
- if (SC->getValue()->getValue().isNegative())
+ if (SC->getAPInt().isNegative())
return getSignExtendExpr(Op, Ty);
// Peel off a truncate cast.
// Pull a buried constant out to the outside.
if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
Interesting = true;
- AccumulatedConstant += Scale * C->getValue()->getValue();
+ AccumulatedConstant += Scale * C->getAPInt();
}
// Next comes everything else. We're especially interested in multiplies
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
APInt NewScale =
- Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
+ Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
// A multiplication of a constant with another add; recurse.
const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
return Interesting;
}
-namespace {
- struct APIntCompare {
- bool operator()(const APInt &LHS, const APInt &RHS) const {
- return LHS.ult(RHS);
- }
- };
-}
-
// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
// can't-overflow flags for the operation if possible.
ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
// If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
- auto IsKnownNonNegative =
- std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+ auto IsKnownNonNegative = [&](const SCEV *S) {
+ return SE->isKnownNonNegative(S);
+ };
- if (SignOrUnsignWrap == SCEV::FlagNSW &&
- std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
+ if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
Flags =
ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
// (A + C) --> (A + C)<nsw> if the addition does not sign overflow
// (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
- const APInt &C = cast<SCEVConstant>(Ops[0])->getValue()->getValue();
+ const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
auto NSWRegion =
ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap);
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- Ops[0] = getConstant(LHSC->getValue()->getValue() +
- RHSC->getValue()->getValue());
+ Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
if (Ops.size() == 2) return Ops[0];
Ops.erase(Ops.begin()+1); // Erase the folded element
LHSC = cast<SCEVConstant>(Ops[0]);
if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
Ops.data(), Ops.size(),
APInt(BitWidth, 1), *this)) {
+ struct APIntCompare {
+ bool operator()(const APInt &LHS, const APInt &RHS) const {
+ return LHS.ult(RHS);
+ }
+ };
+
// Some interesting folding opportunity is present, so its worthwhile to
// re-generate the operands list. Group the operands by constant scale,
// to avoid multiplying by the same constant scale multiple times.
std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
- for (SmallVectorImpl<const SCEV *>::const_iterator I = NewOps.begin(),
- E = NewOps.end(); I != E; ++I)
- MulOpLists[M.find(*I)->second].push_back(*I);
+ for (const SCEV *NewOp : NewOps)
+ MulOpLists[M.find(NewOp)->second].push_back(NewOp);
// Re-generate the operands list.
Ops.clear();
if (AccumulatedConstant != 0)
Ops.push_back(getConstant(AccumulatedConstant));
- for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
- I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
- if (I->first != 0)
- Ops.push_back(getMulExpr(getConstant(I->first),
- getAddExpr(I->second)));
+ for (auto &MulOp : MulOpLists)
+ if (MulOp.first != 0)
+ Ops.push_back(getMulExpr(getConstant(MulOp.first),
+ getAddExpr(MulOp.second)));
if (Ops.empty())
return getZero(Ty);
if (Ops.size() == 1)
AddRec->op_end());
for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx)
- if (const SCEVAddRecExpr *OtherAddRec =
- dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
+ if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
if (OtherAddRec->getLoop() == AddRecLoop) {
for (unsigned i = 0, e = OtherAddRec->getNumOperands();
i != e; ++i) {
++Idx;
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- ConstantInt *Fold = ConstantInt::get(getContext(),
- LHSC->getValue()->getValue() *
- RHSC->getValue()->getValue());
+ ConstantInt *Fold =
+ ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt());
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
SmallVector<const SCEV *, 4> NewOps;
bool AnyFolded = false;
- for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
- E = Add->op_end(); I != E; ++I) {
- const SCEV *Mul = getMulExpr(Ops[0], *I);
+ for (const SCEV *AddOp : Add->operands()) {
+ const SCEV *Mul = getMulExpr(Ops[0], AddOp);
if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
NewOps.push_back(Mul);
}
} else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
// Negation preserves a recurrence's no self-wrap property.
SmallVector<const SCEV *, 4> Operands;
- for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
- E = AddRec->op_end(); I != E; ++I) {
- Operands.push_back(getMulExpr(Ops[0], *I));
- }
+ for (const SCEV *AddRecOp : AddRec->operands())
+ Operands.push_back(getMulExpr(Ops[0], AddRecOp));
+
return getAddRecExpr(Operands, AddRec->getLoop(),
AddRec->getNoWrapFlags(SCEV::FlagNW));
}
// its operands.
// TODO: Generalize this to non-constants by using known-bits information.
Type *Ty = LHS->getType();
- unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
+ unsigned LZ = RHSC->getAPInt().countLeadingZeros();
unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
// For non-power-of-two values, effectively round the value up to the
// nearest power of two.
- if (!RHSC->getValue()->getValue().isPowerOf2())
+ if (!RHSC->getAPInt().isPowerOf2())
++MaxShiftAmt;
IntegerType *ExtTy =
IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
if (const SCEVConstant *Step =
dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
// {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
- const APInt &StepInt = Step->getValue()->getValue();
- const APInt &DivInt = RHSC->getValue()->getValue();
+ const APInt &StepInt = Step->getAPInt();
+ const APInt &DivInt = RHSC->getAPInt();
if (!StepInt.urem(DivInt) &&
getZeroExtendExpr(AR, ExtTy) ==
getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
getZeroExtendExpr(Step, ExtTy),
AR->getLoop(), SCEV::FlagAnyWrap)) {
- const APInt &StartInt = StartC->getValue()->getValue();
+ const APInt &StartInt = StartC->getAPInt();
const APInt &StartRem = StartInt.urem(StepInt);
if (StartRem != 0)
LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
}
static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
- APInt A = C1->getValue()->getValue().abs();
- APInt B = C2->getValue()->getValue().abs();
+ APInt A = C1->getAPInt().abs();
+ APInt B = C2->getAPInt().abs();
uint32_t ABW = A.getBitWidth();
uint32_t BBW = B.getBitWidth();
// check.
APInt Factor = gcd(LHSCst, RHSCst);
if (!Factor.isIntN(1)) {
- LHSCst = cast<SCEVConstant>(
- getConstant(LHSCst->getValue()->getValue().udiv(Factor)));
- RHSCst = cast<SCEVConstant>(
- getConstant(RHSCst->getValue()->getValue().udiv(Factor)));
+ LHSCst =
+ cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
+ RHSCst =
+ cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
SmallVector<const SCEV *, 2> Operands;
Operands.push_back(LHSCst);
Operands.append(Mul->op_begin() + 1, Mul->op_end());
// AddRecs require their operands be loop-invariant with respect to their
// loops. Don't perform this transformation if it would break this
// requirement.
- bool AllInvariant =
- std::all_of(Operands.begin(), Operands.end(),
- [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+ bool AllInvariant = all_of(
+ Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
if (AllInvariant) {
// Create a recurrence for the outer loop with the same step size.
maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
- AllInvariant = std::all_of(
- NestedOperands.begin(), NestedOperands.end(),
- [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); });
+ AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+ return isLoopInvariant(Op, NestedLoop);
+ });
if (AllInvariant) {
// Ok, both add recurrences are valid after the transformation.
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- ConstantInt *Fold = ConstantInt::get(getContext(),
- APIntOps::smax(LHSC->getValue()->getValue(),
- RHSC->getValue()->getValue()));
+ ConstantInt *Fold = ConstantInt::get(
+ getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt()));
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- ConstantInt *Fold = ConstantInt::get(getContext(),
- APIntOps::umax(LHSC->getValue()->getValue(),
- RHSC->getValue()->getValue()));
+ ConstantInt *Fold = ConstantInt::get(
+ getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt()));
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
// We can bypass creating a target-independent
// constant expression and then folding it back into a ConstantInt.
// This is just a compile-time optimization.
- return getConstant(IntTy,
- F.getParent()->getDataLayout().getTypeAllocSize(AllocTy));
+ return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
}
const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
// constant expression and then folding it back into a ConstantInt.
// This is just a compile-time optimization.
return getConstant(
- IntTy,
- F.getParent()->getDataLayout().getStructLayout(STy)->getElementOffset(
- FieldNo));
+ IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
}
const SCEV *ScalarEvolution::getUnknown(Value *V) {
/// for which isSCEVable must return true.
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
- return F.getParent()->getDataLayout().getTypeSizeInBits(Ty);
+ return getDataLayout().getTypeSizeInBits(Ty);
}
/// getEffectiveSCEVType - Return a type with the same bitwidth as
// The only other support type is pointer.
assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
- return F.getParent()->getDataLayout().getIntPtrType(Ty);
+ return getDataLayout().getIntPtrType(Ty);
}
const SCEV *ScalarEvolution::getCouldNotCompute() {
return CouldNotCompute.get();
}
-namespace {
+
+bool ScalarEvolution::checkValidity(const SCEV *S) const {
// Helper class working with SCEVTraversal to figure out if a SCEV contains
// a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
// is set iff if find such SCEVUnknown.
}
bool isDone() const { return FindOne; }
};
-}
-bool ScalarEvolution::checkValidity(const SCEV *S) const {
FindInvalidSCEVUnknown F;
SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
ST.visitAll(S);
return getPointerBase(Cast->getOperand());
} else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
const SCEV *PtrOp = nullptr;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- if ((*I)->getType()->isPointerTy()) {
+ for (const SCEV *NAryOp : NAry->operands()) {
+ if (NAryOp->getType()->isPointerTy()) {
// Cannot find the base of an expression with multiple pointer operands.
if (PtrOp)
return V;
- PtrOp = *I;
+ PtrOp = NAryOp;
}
}
if (!PtrOp)
}
}
+namespace {
+class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ ScalarEvolution &SE) {
+ SCEVInitRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(Scev);
+ return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+ }
+
+ SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+ Valid = false;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ // Only allow AddRecExprs for this loop.
+ if (Expr->getLoop() == L)
+ return Expr->getStart();
+ Valid = false;
+ return Expr;
+ }
+
+ bool isValid() { return Valid; }
+
+private:
+ const Loop *L;
+ bool Valid;
+};
+
+class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ ScalarEvolution &SE) {
+ SCEVShiftRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(Scev);
+ return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+ }
+
+ SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ // Only allow AddRecExprs for this loop.
+ if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+ Valid = false;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ if (Expr->getLoop() == L && Expr->isAffine())
+ return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
+ Valid = false;
+ return Expr;
+ }
+ bool isValid() { return Valid; }
+
+private:
+ const Loop *L;
+ bool Valid;
+};
+} // end anonymous namespace
+
const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
const Loop *L = LI.getLoopFor(PN->getParent());
if (!L || L->getHeader() != PN->getParent())
return PHISCEV;
}
}
- } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
+ } else {
// Otherwise, this could be a loop like this:
// i = 0; for (j = 1; ..; ++j) { .... i = j; }
// In this case, j = {1,+,1} and BEValue is j.
// Because the other in-value of i (0) fits the evolution of BEValue
// i really is an addrec evolution.
- if (AddRec->getLoop() == L && AddRec->isAffine()) {
+ //
+ // We can generalize this saying that i is the shifted value of BEValue
+ // by one iteration:
+ // PHI(f(0), f({1,+,1})) --> f({0,+,1})
+ const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
+ const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
+ if (Shifted != getCouldNotCompute() &&
+ Start != getCouldNotCompute()) {
const SCEV *StartVal = getSCEV(StartValueV);
-
- // If StartVal = j.start - j.stride, we can use StartVal as the
- // initial step of the addrec evolution.
- if (StartVal ==
- getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) {
- // FIXME: For constant StartVal, we should be able to infer
- // no-wrap flags.
- const SCEV *PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1),
- L, SCEV::FlagAnyWrap);
-
+ if (Start == StartVal) {
// Okay, for the entire analysis of this edge we assumed the PHI
// to be symbolic. We now need to go back and purge all of the
// entries for the scalars that use the symbolic expression.
ForgetSymbolicName(PN, SymbolicName);
- ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
- return PHISCEV;
+ ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+ return Shifted;
}
}
}
switch (S->getSCEVType()) {
case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
- // These expressions are available if their operand(s) is/are.
- return true;
+ // These expressions are available if their operand(s) is/are.
+ return true;
case scAddRecExpr: {
// We allow add recurrences that are on the loop BB is in, or some
if (PN->getNumIncomingValues() == 2) {
const Loop *L = LI.getLoopFor(PN->getParent());
+ // We don't want to break LCSSA, even in a SCEV expression tree.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
+ return nullptr;
+
// Try to match
//
// br %cond, label %left, label %right
// PHI's incoming blocks are in a different loop, in which case doing so
// risks breaking LCSSA form. Instcombine would normally zap these, but
// it doesn't have DominatorTree information, so it may miss cases.
- if (Value *V = SimplifyInstruction(PN, F.getParent()->getDataLayout(), &TLI,
- &DT, &AC))
+ if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC))
if (LI.replacementPreservesLCSSAForm(PN, V))
return getSCEV(V);
uint32_t
ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
- return C->getValue()->getValue().countTrailingZeros();
+ return C->getAPInt().countTrailingZeros();
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
return std::min(GetMinTrailingZeros(T->getOperand()),
// For a SCEVUnknown, ask ValueTracking.
unsigned BitWidth = getTypeSizeInBits(U->getType());
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
- computeKnownBits(U->getValue(), Zeros, Ones, F.getParent()->getDataLayout(),
- 0, &AC, nullptr, &DT);
+ computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC,
+ nullptr, &DT);
return Zeros.countTrailingOnes();
}
/// GetRangeFromMetadata - Helper method to assign a range to V from
/// metadata present in the IR.
static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
- if (Instruction *I = dyn_cast<Instruction>(V)) {
- if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
- ConstantRange TotalRange(
- cast<IntegerType>(I->getType())->getBitWidth(), false);
-
- unsigned NumRanges = MD->getNumOperands() / 2;
- assert(NumRanges >= 1);
-
- for (unsigned i = 0; i < NumRanges; ++i) {
- ConstantInt *Lower =
- mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
- ConstantInt *Upper =
- mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
- ConstantRange Range(Lower->getValue(), Upper->getValue());
- TotalRange = TotalRange.unionWith(Range);
- }
-
- return TotalRange;
- }
- }
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
+ return getConstantRangeFromMetadata(*MD);
return None;
}
return I->second;
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
- return setRange(C, SignHint, ConstantRange(C->getValue()->getValue()));
+ return setRange(C, SignHint, ConstantRange(C->getAPInt()));
unsigned BitWidth = getTypeSizeInBits(S->getType());
ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
if (!C->getValue()->isZero())
- ConservativeResult =
- ConservativeResult.intersectWith(
- ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(C->getAPInt(), APInt(BitWidth, 0)));
// If there's no signed wrap, and all the operands have the same sign or
// zero, the value won't ever change sign.
// Split here to avoid paying the compile-time cost of calling both
// computeKnownBits and ComputeNumSignBits. This restriction can be lifted
// if needed.
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
// For a SCEVUnknown, ask ValueTracking.
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
unsigned TZ = A.countTrailingZeros();
unsigned BitWidth = A.getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
- computeKnownBits(U->getOperand(0), KnownZero, KnownOne,
- F.getParent()->getDataLayout(), 0, &AC, nullptr, &DT);
+ computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(),
+ 0, &AC, nullptr, &DT);
APInt EffectiveMask =
APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
return ItCnt;
}
+ ExitLimit ShiftEL = computeShiftCompareExitLimit(
+ ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond);
+ if (ShiftEL.hasAnyInfo())
+ return ShiftEL;
+
const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
if (AddRec->getLoop() == L) {
// Form the constant range.
ConstantRange CompRange(
- ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
+ ICmpInst::makeConstantRange(Cond, RHSC->getAPInt()));
const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
break;
}
default:
-#if 0
- dbgs() << "computeBackedgeTakenCount ";
- if (ExitCond->getOperand(0)->getType()->isUnsigned())
- dbgs() << "[unsigned] ";
- dbgs() << *LHS << " "
- << Instruction::getOpcodeName(Instruction::ICmp)
- << " " << *RHS << "\n";
-#endif
break;
}
return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
Result = ConstantExpr::getICmp(predicate, Result, RHS);
if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure
if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
-#if 0
- dbgs() << "\n***\n*** Computed loop count " << *ItCst
- << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
- << "***\n";
-#endif
++NumArrayLenItCounts;
return getConstant(ItCst); // Found terminating iteration!
}
return getCouldNotCompute();
}
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+ Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+ ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+ if (!RHS)
+ return getCouldNotCompute();
+
+ const BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return getCouldNotCompute();
+
+ const BasicBlock *Predecessor = L->getLoopPredecessor();
+ if (!Predecessor)
+ return getCouldNotCompute();
+
+ // Return true if V is of the form "LHS `shift_op` <positive constant>".
+ // Return LHS in OutLHS and shift_opt in OutOpCode.
+ auto MatchPositiveShift =
+ [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+ using namespace PatternMatch;
+
+ ConstantInt *ShiftAmt;
+ if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::LShr;
+ else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::AShr;
+ else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::Shl;
+ else
+ return false;
+
+ return ShiftAmt->getValue().isStrictlyPositive();
+ };
+
+ // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
+ //
+ // loop:
+ // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+ // %iv.shifted = lshr i32 %iv, <positive constant>
+ //
+ // Return true on a succesful match. Return the corresponding PHI node (%iv
+ // above) in PNOut and the opcode of the shift operation in OpCodeOut.
+ auto MatchShiftRecurrence =
+ [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+ Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+ {
+ Instruction::BinaryOps OpC;
+ Value *V;
+
+ // If we encounter a shift instruction, "peel off" the shift operation,
+ // and remember that we did so. Later when we inspect %iv's backedge
+ // value, we will make sure that the backedge value uses the same
+ // operation.
+ //
+ // Note: the peeled shift operation does not have to be the same
+ // instruction as the one feeding into the PHI's backedge value. We only
+ // really care about it being the same *kind* of shift instruction --
+ // that's all that is required for our later inferences to hold.
+ if (MatchPositiveShift(LHS, V, OpC)) {
+ PostShiftOpCode = OpC;
+ LHS = V;
+ }
+ }
+
+ PNOut = dyn_cast<PHINode>(LHS);
+ if (!PNOut || PNOut->getParent() != L->getHeader())
+ return false;
+
+ Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
+ Value *OpLHS;
+
+ return
+ // The backedge value for the PHI node must be a shift by a positive
+ // amount
+ MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
+
+ // of the PHI node itself
+ OpLHS == PNOut &&
+
+ // and the kind of shift should be match the kind of shift we peeled
+ // off, if any.
+ (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
+ };
+
+ PHINode *PN;
+ Instruction::BinaryOps OpCode;
+ if (!MatchShiftRecurrence(LHS, PN, OpCode))
+ return getCouldNotCompute();
+
+ const DataLayout &DL = getDataLayout();
+
+ // The key rationale for this optimization is that for some kinds of shift
+ // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
+ // within a finite number of iterations. If the condition guarding the
+ // backedge (in the sense that the backedge is taken if the condition is true)
+ // is false for the value the shift recurrence stabilizes to, then we know
+ // that the backedge is taken only a finite number of times.
+
+ ConstantInt *StableValue = nullptr;
+ switch (OpCode) {
+ default:
+ llvm_unreachable("Impossible case!");
+
+ case Instruction::AShr: {
+ // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
+ // bitwidth(K) iterations.
+ Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
+ bool KnownZero, KnownOne;
+ ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr,
+ Predecessor->getTerminator(), &DT);
+ auto *Ty = cast<IntegerType>(RHS->getType());
+ if (KnownZero)
+ StableValue = ConstantInt::get(Ty, 0);
+ else if (KnownOne)
+ StableValue = ConstantInt::get(Ty, -1, true);
+ else
+ return getCouldNotCompute();
+
+ break;
+ }
+ case Instruction::LShr:
+ case Instruction::Shl:
+ // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
+ // stabilize to 0 in at most bitwidth(K) iterations.
+ StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
+ break;
+ }
+
+ auto *Result =
+ ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
+ assert(Result->getType()->isIntegerTy(1) &&
+ "Otherwise cannot be an operand to a branch instruction");
+
+ if (Result->isZeroValue()) {
+ unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+ const SCEV *UpperBound =
+ getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
+ return ExitLimit(getCouldNotCompute(), UpperBound);
+ }
+
+ return getCouldNotCompute();
+}
/// CanConstantFold - Return true if we can constant fold an instruction of the
/// specified type, assuming that all operands were constants.
// Otherwise, we can evaluate this instruction if all of its operands are
// constant or derived from a PHI node themselves.
PHINode *PHI = nullptr;
- for (Instruction::op_iterator OpI = UseInst->op_begin(),
- OpE = UseInst->op_end(); OpI != OpE; ++OpI) {
+ for (Value *Op : UseInst->operands()) {
+ if (isa<Constant>(Op)) continue;
- if (isa<Constant>(*OpI)) continue;
-
- Instruction *OpInst = dyn_cast<Instruction>(*OpI);
+ Instruction *OpInst = dyn_cast<Instruction>(Op);
if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
PHINode *P = dyn_cast<PHINode>(OpInst);
TLI);
}
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+ Constant *IncomingVal = nullptr;
+
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ if (PN->getIncomingBlock(i) == BB)
+ continue;
+
+ auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+ if (!CurrentVal)
+ return nullptr;
+
+ if (IncomingVal != CurrentVal) {
+ if (IncomingVal)
+ return nullptr;
+ IncomingVal = CurrentVal;
+ }
+ }
+
+ return IncomingVal;
+}
+
/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
if (!Latch)
return nullptr;
- // Since the loop has one latch, the PHI node must have two entries. One
- // entry must be a constant (coming in from outside of the loop), and the
- // second must be derived from the same PHI.
-
- BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
- ? PN->getIncomingBlock(1)
- : PN->getIncomingBlock(0);
-
- assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!");
-
- // Note: not all PHI nodes in the same block have to have their incoming
- // values in the same order, so we use the basic block to look up the incoming
- // value, not an index.
-
for (auto &I : *Header) {
PHINode *PHI = dyn_cast<PHINode>(&I);
if (!PHI) break;
- auto *StartCST =
- dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+ auto *StartCST = getOtherIncomingValue(PHI, Latch);
if (!StartCST) continue;
CurrentIterVals[PHI] = StartCST;
}
unsigned NumIterations = BEs.getZExtValue(); // must be in range
unsigned IterationNum = 0;
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
for (; ; ++IterationNum) {
if (IterationNum == NumIterations)
return RetVal = CurrentIterVals[PN]; // Got exit value!
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Should follow from NumIncomingValues == 2!");
- // NonLatch is the preheader, or something equivalent.
- BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
- ? PN->getIncomingBlock(1)
- : PN->getIncomingBlock(0);
-
- // Note: not all PHI nodes in the same block have to have their incoming
- // values in the same order, so we use the basic block to look up the incoming
- // value, not an index.
-
for (auto &I : *Header) {
PHINode *PHI = dyn_cast<PHINode>(&I);
if (!PHI)
break;
- auto *StartCST =
- dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+ auto *StartCST = getOtherIncomingValue(PHI, Latch);
if (!StartCST) continue;
CurrentIterVals[PHI] = StartCST;
}
// the loop symbolically to determine when the condition gets a value of
// "ExitWhen".
unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
auto *CondVal = dyn_cast_or_null<ConstantInt>(
EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
/// In the case that a relevant loop exit value cannot be computed, the
/// original value V is returned.
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+ SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+ ValuesAtScopes[V];
// Check to see if we've folded this expression at this loop before.
- SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
- for (unsigned u = 0; u < Values.size(); u++) {
- if (Values[u].first == L)
- return Values[u].second ? Values[u].second : V;
- }
- Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
+ for (auto &LS : Values)
+ if (LS.first == L)
+ return LS.second ? LS.second : V;
+
+ Values.emplace_back(L, nullptr);
+
// Otherwise compute it.
const SCEV *C = computeSCEVAtScope(V, L);
- SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
- for (unsigned u = Values2.size(); u > 0; u--) {
- if (Values2[u - 1].first == L) {
- Values2[u - 1].second = C;
+ for (auto &LS : reverse(ValuesAtScopes[V]))
+ if (LS.first == L) {
+ LS.second = C;
break;
}
- }
return C;
}
// Okay, we know how many times the containing loop executes. If
// this is a constant evolving PHI node, get the final value at
// the specified iteration number.
- Constant *RV = getConstantEvolutionLoopExitValue(PN,
- BTCC->getValue()->getValue(),
- LI);
+ Constant *RV =
+ getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI);
if (RV) return getSCEV(RV);
}
}
// Check to see if getSCEVAtScope actually made an improvement.
if (MadeImprovement) {
Constant *C = nullptr;
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
if (const CmpInst *CI = dyn_cast<CmpInst>(I))
C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
Operands[1], DL, &TLI);
return std::make_pair(CNC, CNC);
}
- uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
- const APInt &L = LC->getValue()->getValue();
- const APInt &M = MC->getValue()->getValue();
- const APInt &N = NC->getValue()->getValue();
+ uint32_t BitWidth = LC->getAPInt().getBitWidth();
+ const APInt &L = LC->getAPInt();
+ const APInt &M = MC->getAPInt();
+ const APInt &N = NC->getAPInt();
APInt Two(BitWidth, 2);
APInt Four(BitWidth, 4);
const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
if (R1 && R2) {
-#if 0
- dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
- << " sol#2: " << *R2 << "\n";
-#endif
// Pick the smallest positive root value.
if (ConstantInt *CB =
dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
// For negative steps (counting down to zero):
// N = Start/-Step
// First compute the unsigned distance from zero in the direction of Step.
- bool CountDown = StepC->getValue()->getValue().isNegative();
+ bool CountDown = StepC->getAPInt().isNegative();
const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
// Handle unitary steps, which cannot wraparound.
// done by counting and comparing the number of trailing zeros of Step and
// Distance.
if (!CountDown) {
- const APInt &StepV = StepC->getValue()->getValue();
+ const APInt &StepV = StepC->getAPInt();
// StepV.isPowerOf2() returns true if StepV is an positive power of two. It
// also returns true if StepV is maximally negative (eg, INT_MIN), but that
// case is not handled as this code is guarded by !CountDown.
// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
- return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
- -StartC->getValue()->getValue(),
+ return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(),
*this);
return getCouldNotCompute();
}
// If there's a constant operand, canonicalize comparisons with boundary
// cases, and canonicalize *-or-equal comparisons to regular comparisons.
if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
- const APInt &RA = RC->getValue()->getValue();
+ const APInt &RA = RC->getAPInt();
switch (Pred) {
default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
case ICmpInst::ICMP_EQ:
Pred = ICmpInst::ICMP_ULT;
Changed = true;
} else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
- LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
- SCEV::FlagNUW);
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
Pred = ICmpInst::ICMP_ULT;
Changed = true;
}
break;
case ICmpInst::ICMP_UGE:
if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
- RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
- SCEV::FlagNUW);
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
Pred = ICmpInst::ICMP_UGT;
Changed = true;
} else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
return false;
}
+bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+
+ // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
+ // Return Y via OutY.
+ auto MatchBinaryAddToConst =
+ [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+ SCEV::NoWrapFlags ExpectedFlags) {
+ const SCEV *NonConstOp, *ConstOp;
+ SCEV::NoWrapFlags FlagsPresent;
+
+ if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+ !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+ return false;
+
+ OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
+ return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+ };
+
+ APInt C;
+
+ switch (Pred) {
+ default:
+ break;
+
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ case ICmpInst::ICMP_SLE:
+ // X s<= (X + C)<nsw> if C >= 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+ return true;
+
+ // (X + C)<nsw> s<= X if C <= 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+ !C.isStrictlyPositive())
+ return true;
+ break;
+
+ case ICmpInst::ICMP_SGT:
+ std::swap(LHS, RHS);
+ case ICmpInst::ICMP_SLT:
+ // X s< (X + C)<nsw> if C > 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+ C.isStrictlyPositive())
+ return true;
+
+ // (X + C)<nsw> s< X if C < 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+ return true;
+ break;
+ }
+
+ return false;
+}
+
bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
const SCEV *LHS,
const SCEV *RHS) {
// expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
// interesting cases seen in practice. We can consider "upgrading" L >= 0 to
// use isKnownPredicate later if needed.
- if (isKnownNonNegative(RHS) &&
- isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
- isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS))
- return true;
-
- return false;
+ return isKnownNonNegative(RHS) &&
+ isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+ isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
return false;
}
+namespace {
/// RAII wrapper to prevent recursive application of isImpliedCond.
/// ScalarEvolution's PendingLoopPredicates set must be empty unless we are
/// currently evaluating isImpliedCond.
LoopPreds.erase(Cond);
}
};
+} // end anonymous namespace
/// isImpliedCond - Test whether the condition described by Pred, LHS,
/// and RHS is true whenever the given Cond value evaluates to true.
RHS, LHS, FoundLHS, FoundRHS);
}
+ // Unsigned comparison is the same as signed comparison when both the operands
+ // are non-negative.
+ if (CmpInst::isUnsigned(FoundPred) &&
+ CmpInst::getSignedPredicate(FoundPred) == Pred &&
+ isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
// Check if we can make progress by sharpening ranges.
if (FoundPred == ICmpInst::ICMP_NE &&
(isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
APInt Min = ICmpInst::isSigned(Pred) ?
getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
- if (Min == C->getValue()->getValue()) {
+ if (Min == C->getAPInt()) {
// Given (V >= Min && V != Min) we conclude V >= (Min + 1).
// This is true even if (Min + 1) wraps around -- in case of
// wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
}
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
- const auto &M = cast<SCEVConstant>(More)->getValue()->getValue();
- const auto &L = cast<SCEVConstant>(Less)->getValue()->getValue();
+ const auto &M = cast<SCEVConstant>(More)->getAPInt();
+ const auto &L = cast<SCEVConstant>(Less)->getAPInt();
C = M - L;
return true;
}
if (splitBinaryAdd(Less, L, R, Flags))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
if (R == More) {
- C = -(LC->getValue()->getValue());
+ C = -(LC->getAPInt());
return true;
}
if (splitBinaryAdd(More, L, R, Flags))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
if (R == Less) {
- C = LC->getValue()->getValue();
+ C = LC->getAPInt();
return true;
}
const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
if (!MaxExpr) return false;
- auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
- return It != MaxExpr->op_end();
+ return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
}
auto IsKnownPredicateFull =
[this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
- IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
- IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS);
+ IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+ IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+ isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
};
switch (Pred) {
!isa<SCEVConstant>(AddLHS->getOperand(0)))
return false;
- APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getValue()->getValue();
+ APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
// `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
// antecedent "`FoundLHS` `Pred` `FoundRHS`".
// Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range
// for `LHS`:
- APInt Addend =
- cast<SCEVConstant>(AddLHS->getOperand(0))->getValue()->getValue();
+ APInt Addend = cast<SCEVConstant>(AddLHS->getOperand(0))->getAPInt();
ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend));
// We can also compute the range of values for `LHS` that satisfy the
// consequent, "`LHS` `Pred` `RHS`":
- APInt ConstRHS = cast<SCEVConstant>(RHS)->getValue()->getValue();
+ APInt ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
ConstantRange SatisfyingLHSRange =
ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
// overflow, in which case if RHS - Start is a constant, we don't need to
// do a max operation since we can just figure it out statically
if (NoWrap && isa<SCEVConstant>(Diff)) {
- APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
+ APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
if (D.isNegative())
End = Start;
} else
// overflow, in which case if RHS - Start is a constant, we don't need to
// do a max operation since we can just figure it out statically
if (NoWrap && isa<SCEVConstant>(Diff)) {
- APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
+ APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
if (!D.isNegative())
End = Start;
} else
getNoWrapFlags(FlagNW));
if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
return ShiftedAddRec->getNumIterationsInRange(
- Range.subtract(SC->getValue()->getValue()), SE);
+ Range.subtract(SC->getAPInt()), SE);
// This is strange and shouldn't happen.
return SE.getCouldNotCompute();
}
// The only time we can solve this is when we have all constant indices.
// Otherwise, we cannot determine the overflow conditions.
- if (std::any_of(op_begin(), op_end(),
- [](const SCEV *Op) { return !isa<SCEVConstant>(Op);}))
+ if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
return SE.getCouldNotCompute();
// Okay at this point we know that all elements of the chrec are constants and
// If A is negative then the lower of the range is the last possible loop
// value. Also note that we already checked for a full range.
APInt One(BitWidth,1);
- APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
+ APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
// The exit value should be (End+A)/A.
FlagAnyWrap);
// Next, solve the constructed addrec
- std::pair<const SCEV *,const SCEV *> Roots =
- SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
+ auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
if (R1) {
// Pick the smallest positive root value.
- if (ConstantInt *CB =
- dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
- R1->getValue(), R2->getValue()))) {
+ if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+ ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
if (!CB->getZExtValue())
std::swap(R1, R2); // R1 is the minimum root now.
if (Range.contains(R1Val->getValue())) {
// The next iteration must be out of the range...
ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
+ ConstantInt::get(SE.getContext(), R1->getAPInt() + 1);
R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
if (!Range.contains(R1Val->getValue()))
// If R1 was not in the range, then it is a good return value. Make
// sure that R1-1 WAS in the range though, just in case.
ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
+ ConstantInt::get(SE.getContext(), R1->getAPInt() - 1);
R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
if (Range.contains(R1Val->getValue()))
return R1;
return true;
}
-namespace {
-struct FindParameter {
- bool FoundParameter;
- FindParameter() : FoundParameter(false) {}
-
- bool follow(const SCEV *S) {
- if (isa<SCEVUnknown>(S)) {
- FoundParameter = true;
- // Stop recursion: we found a parameter.
- return false;
- }
- // Keep looking.
- return true;
- }
- bool isDone() const {
- // Stop recursion if we have found a parameter.
- return FoundParameter;
- }
-};
-}
-
// Returns true when S contains at least a SCEVUnknown parameter.
static inline bool
containsParameters(const SCEV *S) {
+ struct FindParameter {
+ bool FoundParameter;
+ FindParameter() : FoundParameter(false) {}
+
+ bool follow(const SCEV *S) {
+ if (isa<SCEVUnknown>(S)) {
+ FoundParameter = true;
+ // Stop recursion: we found a parameter.
+ return false;
+ }
+ // Keep looking.
+ return true;
+ }
+ bool isDone() const {
+ // Stop recursion if we have found a parameter.
+ return FoundParameter;
+ }
+ };
+
FindParameter F;
SCEVTraversal<FindParameter> ST(F);
ST.visitAll(S);
UnsignedRanges(std::move(Arg.UnsignedRanges)),
SignedRanges(std::move(Arg.SignedRanges)),
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+ UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
// This recurrence is variant w.r.t. L if any of its operands
// are variant.
- for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
- I != E; ++I)
- if (!isLoopInvariant(*I, L))
+ for (auto *Op : AR->operands())
+ if (!isLoopInvariant(Op, L))
return LoopVariant;
// Otherwise it's loop-invariant.
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr: {
- const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
bool HasVarying = false;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- LoopDisposition D = getLoopDisposition(*I, L);
+ for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+ LoopDisposition D = getLoopDisposition(Op, L);
if (D == LoopVariant)
return LoopVariant;
if (D == LoopComputable)
// invariant if they are not contained in the specified loop.
// Instructions are never considered invariant in the function body
// (null loop) because they are defined within the "loop".
- if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+ if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
return LoopInvariant;
case scCouldNotCompute:
case scSMaxExpr: {
const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
bool Proper = true;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- BlockDisposition D = getBlockDisposition(*I, BB);
+ for (const SCEV *NAryOp : NAry->operands()) {
+ BlockDisposition D = getBlockDisposition(NAryOp, BB);
if (D == DoesNotDominateBlock)
return DoesNotDominateBlock;
if (D == DominatesBlock)
return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
}
-namespace {
-// Search for a SCEV expression node within an expression tree.
-// Implements SCEVTraversal::Visitor.
-struct SCEVSearch {
- const SCEV *Node;
- bool IsFound;
+bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
+ // Search for a SCEV expression node within an expression tree.
+ // Implements SCEVTraversal::Visitor.
+ struct SCEVSearch {
+ const SCEV *Node;
+ bool IsFound;
- SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
+ SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
- bool follow(const SCEV *S) {
- IsFound |= (S == Node);
- return !IsFound;
- }
- bool isDone() const { return IsFound; }
-};
-}
+ bool follow(const SCEV *S) {
+ IsFound |= (S == Node);
+ return !IsFound;
+ }
+ bool isDone() const { return IsFound; }
+ };
-bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
SCEVSearch Search(Op);
visitAll(S, Search);
return Search.IsFound;
/// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis.
static void
getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
- for (Loop::reverse_iterator I = L->rbegin(), E = L->rend(); I != E; ++I) {
- getLoopBackedgeTakenCounts(*I, Map, SE); // recurse.
+ std::string &S = Map[L];
+ if (S.empty()) {
+ raw_string_ostream OS(S);
+ SE.getBackedgeTakenCount(L)->print(OS);
- std::string &S = Map[L];
- if (S.empty()) {
- raw_string_ostream OS(S);
- SE.getBackedgeTakenCount(L)->print(OS);
-
- // false and 0 are semantically equivalent. This can happen in dead loops.
- replaceSubString(OS.str(), "false", "0");
- // Remove wrap flags, their use in SCEV is highly fragile.
- // FIXME: Remove this when SCEV gets smarter about them.
- replaceSubString(OS.str(), "<nw>", "");
- replaceSubString(OS.str(), "<nsw>", "");
- replaceSubString(OS.str(), "<nuw>", "");
- }
+ // false and 0 are semantically equivalent. This can happen in dead loops.
+ replaceSubString(OS.str(), "false", "0");
+ // Remove wrap flags, their use in SCEV is highly fragile.
+ // FIXME: Remove this when SCEV gets smarter about them.
+ replaceSubString(OS.str(), "<nw>", "");
+ replaceSubString(OS.str(), "<nsw>", "");
+ replaceSubString(OS.str(), "<nuw>", "");
}
+
+ for (auto *R : reverse(*L))
+ getLoopBackedgeTakenCounts(R, Map, SE); // recurse.
}
void ScalarEvolution::verify() const {
AU.addRequiredTransitive<DominatorTreeWrapperPass>();
AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+ const SCEVConstant *RHS) {
+ FoldingSetNodeID ID;
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Equal);
+ ID.AddPointer(LHS);
+ ID.AddPointer(RHS);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ SCEVEqualPredicate *Eq = new (SCEVAllocator)
+ SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+ UniquePreds.InsertNode(Eq, IP);
+ return Eq;
+}
+
+namespace {
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+ SCEVUnionPredicate &A) {
+ SCEVPredicateRewriter Rewriter(SE, A);
+ return Rewriter.visit(Scev);
+ }
+
+ SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+ : SCEVRewriteVisitor(SE), P(P) {}
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ auto ExprPreds = P.getPredicatesForExpr(Expr);
+ for (auto *Pred : ExprPreds)
+ if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+ if (IPred->getLHS() == Expr)
+ return IPred->getRHS();
+
+ return Expr;
+ }
+
+private:
+ SCEVUnionPredicate &P;
+};
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+ SCEVUnionPredicate &Preds) {
+ return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+ SCEVPredicateKind Kind)
+ : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVUnknown *LHS,
+ const SCEVConstant *RHS)
+ : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+
+ if (!Op)
+ return false;
+
+ return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+ return all_of(Preds,
+ [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+ auto I = SCEVToPreds.find(Expr);
+ if (I == SCEVToPreds.end())
+ return ArrayRef<const SCEVPredicate *>();
+ return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+ if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+ return all_of(Set->Preds,
+ [this](const SCEVPredicate *I) { return this->implies(I); });
+
+ auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+ if (ScevPredsIt == SCEVToPreds.end())
+ return false;
+ auto &SCEVPreds = ScevPredsIt->second;
+
+ return any_of(SCEVPreds,
+ [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ for (auto Pred : Preds)
+ Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+ if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+ for (auto Pred : Set->Preds)
+ add(Pred);
+ return;
+ }
+
+ if (implies(N))
+ return;
+
+ const SCEV *Key = N->getExpr();
+ assert(Key && "Only SCEVUnionPredicate doesn't have an "
+ " associated expression!");
+
+ SCEVToPreds[Key].push_back(N);
+ Preds.push_back(N);
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
+ : SE(SE), Generation(0) {}
+
+const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
+ const SCEV *Expr = SE.getSCEV(V);
+ RewriteEntry &Entry = RewriteMap[Expr];
+
+ // If we already have an entry and the version matches, return it.
+ if (Entry.second && Generation == Entry.first)
+ return Entry.second;
+
+ // We found an entry but it's stale. Rewrite the stale entry
+ // acording to the current predicate.
+ if (Entry.second)
+ Expr = Entry.second;
+
+ const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+ Entry = {Generation, NewSCEV};
+
+ return NewSCEV;
+}
+
+void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
+ if (Preds.implies(&Pred))
+ return;
+ Preds.add(&Pred);
+ updateGeneration();
+}
+
+const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
+ return Preds;
+}
+
+void PredicatedScalarEvolution::updateGeneration() {
+ // If the generation number wrapped recompute everything.
+ if (++Generation == 0) {
+ for (auto &II : RewriteMap) {
+ const SCEV *Rewritten = II.second.second;
+ II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+ }
+ }
+}