#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
// Implementation of the SCEV class.
//
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD
void SCEV::dump() const {
print(dbgs());
dbgs() << '\n';
}
-#endif
void SCEV::print(raw_ostream &OS) const {
switch (static_cast<SCEVTypes>(getSCEVType())) {
if (!SC) return false;
// Return true if the value is negative, this matches things like (-42 * V).
- return SC->getValue()->getValue().isNegative();
+ return SC->getAPInt().isNegative();
}
SCEVCouldNotCompute::SCEVCouldNotCompute() :
//===----------------------------------------------------------------------===//
namespace {
- /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
- /// than the complexity of the RHS. This comparator is used to canonicalize
- /// expressions.
- class SCEVComplexityCompare {
- const LoopInfo *const LI;
- public:
- explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
-
- // Return true or false if LHS is less than, or at least RHS, respectively.
- bool operator()(const SCEV *LHS, const SCEV *RHS) const {
- return compare(LHS, RHS) < 0;
- }
-
- // Return negative, zero, or positive, if LHS is less than, equal to, or
- // greater than RHS, respectively. A three-way result allows recursive
- // comparisons to be more efficient.
- int compare(const SCEV *LHS, const SCEV *RHS) const {
- // Fast-path: SCEVs are uniqued so we can do a quick equality check.
- if (LHS == RHS)
- return 0;
-
- // Primarily, sort the SCEVs by their getSCEVType().
- unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
- if (LType != RType)
- return (int)LType - (int)RType;
-
- // Aside from the getSCEVType() ordering, the particular ordering
- // isn't very important except that it's beneficial to be consistent,
- // so that (a + b) and (b + a) don't end up as different expressions.
- switch (static_cast<SCEVTypes>(LType)) {
- case scUnknown: {
- const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
- const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
-
- // Sort SCEVUnknown values with some loose heuristics. TODO: This is
- // not as complete as it could be.
- const Value *LV = LU->getValue(), *RV = RU->getValue();
-
- // Order pointer values after integer values. This helps SCEVExpander
- // form GEPs.
- bool LIsPointer = LV->getType()->isPointerTy(),
- RIsPointer = RV->getType()->isPointerTy();
- if (LIsPointer != RIsPointer)
- return (int)LIsPointer - (int)RIsPointer;
-
- // Compare getValueID values.
- unsigned LID = LV->getValueID(),
- RID = RV->getValueID();
- if (LID != RID)
- return (int)LID - (int)RID;
-
- // Sort arguments by their position.
- if (const Argument *LA = dyn_cast<Argument>(LV)) {
- const Argument *RA = cast<Argument>(RV);
- unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
- return (int)LArgNo - (int)RArgNo;
- }
-
- // For instructions, compare their loop depth, and their operand
- // count. This is pretty loose.
- if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
- const Instruction *RInst = cast<Instruction>(RV);
-
- // Compare loop depths.
- const BasicBlock *LParent = LInst->getParent(),
- *RParent = RInst->getParent();
- if (LParent != RParent) {
- unsigned LDepth = LI->getLoopDepth(LParent),
- RDepth = LI->getLoopDepth(RParent);
- if (LDepth != RDepth)
- return (int)LDepth - (int)RDepth;
- }
-
- // Compare the number of operands.
- unsigned LNumOps = LInst->getNumOperands(),
- RNumOps = RInst->getNumOperands();
- return (int)LNumOps - (int)RNumOps;
- }
+/// SCEVComplexityCompare - Return true if the complexity of the LHS is less
+/// than the complexity of the RHS. This comparator is used to canonicalize
+/// expressions.
+class SCEVComplexityCompare {
+ const LoopInfo *const LI;
+public:
+ explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
- return 0;
- }
+ // Return true or false if LHS is less than, or at least RHS, respectively.
+ bool operator()(const SCEV *LHS, const SCEV *RHS) const {
+ return compare(LHS, RHS) < 0;
+ }
- case scConstant: {
- const SCEVConstant *LC = cast<SCEVConstant>(LHS);
- const SCEVConstant *RC = cast<SCEVConstant>(RHS);
-
- // Compare constant values.
- const APInt &LA = LC->getValue()->getValue();
- const APInt &RA = RC->getValue()->getValue();
- unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
- if (LBitWidth != RBitWidth)
- return (int)LBitWidth - (int)RBitWidth;
- return LA.ult(RA) ? -1 : 1;
+ // Return negative, zero, or positive, if LHS is less than, equal to, or
+ // greater than RHS, respectively. A three-way result allows recursive
+ // comparisons to be more efficient.
+ int compare(const SCEV *LHS, const SCEV *RHS) const {
+ // Fast-path: SCEVs are uniqued so we can do a quick equality check.
+ if (LHS == RHS)
+ return 0;
+
+ // Primarily, sort the SCEVs by their getSCEVType().
+ unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
+ if (LType != RType)
+ return (int)LType - (int)RType;
+
+ // Aside from the getSCEVType() ordering, the particular ordering
+ // isn't very important except that it's beneficial to be consistent,
+ // so that (a + b) and (b + a) don't end up as different expressions.
+ switch (static_cast<SCEVTypes>(LType)) {
+ case scUnknown: {
+ const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
+ const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
+
+ // Sort SCEVUnknown values with some loose heuristics. TODO: This is
+ // not as complete as it could be.
+ const Value *LV = LU->getValue(), *RV = RU->getValue();
+
+ // Order pointer values after integer values. This helps SCEVExpander
+ // form GEPs.
+ bool LIsPointer = LV->getType()->isPointerTy(),
+ RIsPointer = RV->getType()->isPointerTy();
+ if (LIsPointer != RIsPointer)
+ return (int)LIsPointer - (int)RIsPointer;
+
+ // Compare getValueID values.
+ unsigned LID = LV->getValueID(),
+ RID = RV->getValueID();
+ if (LID != RID)
+ return (int)LID - (int)RID;
+
+ // Sort arguments by their position.
+ if (const Argument *LA = dyn_cast<Argument>(LV)) {
+ const Argument *RA = cast<Argument>(RV);
+ unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
+ return (int)LArgNo - (int)RArgNo;
}
- case scAddRecExpr: {
- const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
- const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
-
- // Compare addrec loop depths.
- const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
- if (LLoop != RLoop) {
- unsigned LDepth = LLoop->getLoopDepth(),
- RDepth = RLoop->getLoopDepth();
+ // For instructions, compare their loop depth, and their operand
+ // count. This is pretty loose.
+ if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
+ const Instruction *RInst = cast<Instruction>(RV);
+
+ // Compare loop depths.
+ const BasicBlock *LParent = LInst->getParent(),
+ *RParent = RInst->getParent();
+ if (LParent != RParent) {
+ unsigned LDepth = LI->getLoopDepth(LParent),
+ RDepth = LI->getLoopDepth(RParent);
if (LDepth != RDepth)
return (int)LDepth - (int)RDepth;
}
- // Addrec complexity grows with operand count.
- unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
- if (LNumOps != RNumOps)
- return (int)LNumOps - (int)RNumOps;
+ // Compare the number of operands.
+ unsigned LNumOps = LInst->getNumOperands(),
+ RNumOps = RInst->getNumOperands();
+ return (int)LNumOps - (int)RNumOps;
+ }
+
+ return 0;
+ }
- // Lexicographically compare.
- for (unsigned i = 0; i != LNumOps; ++i) {
- long X = compare(LA->getOperand(i), RA->getOperand(i));
- if (X != 0)
- return X;
- }
+ case scConstant: {
+ const SCEVConstant *LC = cast<SCEVConstant>(LHS);
+ const SCEVConstant *RC = cast<SCEVConstant>(RHS);
- return 0;
+ // Compare constant values.
+ const APInt &LA = LC->getAPInt();
+ const APInt &RA = RC->getAPInt();
+ unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
+ if (LBitWidth != RBitWidth)
+ return (int)LBitWidth - (int)RBitWidth;
+ return LA.ult(RA) ? -1 : 1;
+ }
+
+ case scAddRecExpr: {
+ const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
+ const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
+
+ // Compare addrec loop depths.
+ const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
+ if (LLoop != RLoop) {
+ unsigned LDepth = LLoop->getLoopDepth(),
+ RDepth = RLoop->getLoopDepth();
+ if (LDepth != RDepth)
+ return (int)LDepth - (int)RDepth;
}
- case scAddExpr:
- case scMulExpr:
- case scSMaxExpr:
- case scUMaxExpr: {
- const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
- const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
-
- // Lexicographically compare n-ary expressions.
- unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
- if (LNumOps != RNumOps)
- return (int)LNumOps - (int)RNumOps;
-
- for (unsigned i = 0; i != LNumOps; ++i) {
- if (i >= RNumOps)
- return 1;
- long X = compare(LC->getOperand(i), RC->getOperand(i));
- if (X != 0)
- return X;
- }
+ // Addrec complexity grows with operand count.
+ unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
+ if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
+
+ // Lexicographically compare.
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ long X = compare(LA->getOperand(i), RA->getOperand(i));
+ if (X != 0)
+ return X;
}
- case scUDivExpr: {
- const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
- const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
+ return 0;
+ }
- // Lexicographically compare udiv expressions.
- long X = compare(LC->getLHS(), RC->getLHS());
+ case scAddExpr:
+ case scMulExpr:
+ case scSMaxExpr:
+ case scUMaxExpr: {
+ const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
+ const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
+
+ // Lexicographically compare n-ary expressions.
+ unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
+ if (LNumOps != RNumOps)
+ return (int)LNumOps - (int)RNumOps;
+
+ for (unsigned i = 0; i != LNumOps; ++i) {
+ if (i >= RNumOps)
+ return 1;
+ long X = compare(LC->getOperand(i), RC->getOperand(i));
if (X != 0)
return X;
- return compare(LC->getRHS(), RC->getRHS());
}
+ return (int)LNumOps - (int)RNumOps;
+ }
- case scTruncate:
- case scZeroExtend:
- case scSignExtend: {
- const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
- const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+ case scUDivExpr: {
+ const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
+ const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
- // Compare cast expressions by operand.
- return compare(LC->getOperand(), RC->getOperand());
- }
+ // Lexicographically compare udiv expressions.
+ long X = compare(LC->getLHS(), RC->getLHS());
+ if (X != 0)
+ return X;
+ return compare(LC->getRHS(), RC->getRHS());
+ }
- case scCouldNotCompute:
- llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
- }
- llvm_unreachable("Unknown SCEV kind!");
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend: {
+ const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
+ const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+
+ // Compare cast expressions by operand.
+ return compare(LC->getOperand(), RC->getOperand());
}
- };
-}
+
+ case scCouldNotCompute:
+ llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+ }
+};
+} // end anonymous namespace
/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
}
}
-namespace {
-struct FindSCEVSize {
- int Size;
- FindSCEVSize() : Size(0) {}
-
- bool follow(const SCEV *S) {
- ++Size;
- // Keep looking at all operands of S.
- return true;
- }
- bool isDone() const {
- return false;
- }
-};
-}
-
// Returns the size of the SCEV S.
static inline int sizeOfSCEV(const SCEV *S) {
+ struct FindSCEVSize {
+ int Size;
+ FindSCEVSize() : Size(0) {}
+
+ bool follow(const SCEV *S) {
+ ++Size;
+ // Keep looking at all operands of S.
+ return true;
+ }
+ bool isDone() const {
+ return false;
+ }
+ };
+
FindSCEVSize F;
SCEVTraversal<FindSCEVSize> ST(F);
ST.visitAll(S);
void visitConstant(const SCEVConstant *Numerator) {
if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
- APInt NumeratorVal = Numerator->getValue()->getValue();
- APInt DenominatorVal = D->getValue()->getValue();
+ APInt NumeratorVal = Numerator->getAPInt();
+ APInt DenominatorVal = D->getAPInt();
uint32_t NumeratorBW = NumeratorVal.getBitWidth();
uint32_t DenominatorBW = DenominatorVal.getBitWidth();
// If the input value is a chrec scev, truncate the chrec's operands.
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
- Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
+ for (const SCEV *Op : AddRec->operands())
+ Operands.push_back(getTruncateExpr(Op, Ty));
return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
}
// `Step`:
// 1. NSW/NUW flags on the step increment.
- const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
+ auto PreStartFlags =
+ ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
+ const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
if (OverflowLimit &&
- SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
+ SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
return PreStart;
- }
+
return nullptr;
}
if (!StartC)
return false;
- APInt StartAI = StartC->getValue()->getValue();
+ APInt StartAI = StartC->getAPInt();
for (unsigned Delta : {-2, -1, 1, 2}) {
const SCEV *PreStart = getConstant(StartAI - Delta);
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddRecExpr);
+ ID.AddPointer(PreStart);
+ ID.AddPointer(Step);
+ ID.AddPointer(L);
+ void *IP = nullptr;
+ const auto *PreAR =
+ static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+
// Give up if we don't already have the add recurrence we need because
// actually constructing an add recurrence is relatively expensive.
- const SCEVAddRecExpr *PreAR = [&]() {
- FoldingSetNodeID ID;
- ID.AddInteger(scAddRecExpr);
- ID.AddPointer(PreStart);
- ID.AddPointer(Step);
- ID.AddPointer(L);
- void *IP = nullptr;
- return static_cast<SCEVAddRecExpr *>(
- this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
- }();
-
if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
}
}
+ if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
+ // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
+ if (SA->getNoWrapFlags(SCEV::FlagNUW)) {
+ // If the addition does not unsign overflow then we can, by definition,
+ // commute the zero extension with the addition operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SA->operands())
+ Ops.push_back(getZeroExtendExpr(Op, Ty));
+ return getAddExpr(Ops, SCEV::FlagNUW);
+ }
+ }
+
// The cast wasn't folded; create an explicit cast node.
// Recompute the insert position, as it may have been invalidated.
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
}
// sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
- if (auto SA = dyn_cast<SCEVAddExpr>(Op)) {
+ if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
if (SA->getNumOperands() == 2) {
- auto SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
- auto SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
+ auto *SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
+ auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
if (SMul && SC1) {
- if (auto SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
- const APInt &C1 = SC1->getValue()->getValue();
- const APInt &C2 = SC2->getValue()->getValue();
+ if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
+ const APInt &C1 = SC1->getAPInt();
+ const APInt &C2 = SC2->getAPInt();
if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
C2.ugt(C1) && C2.isPowerOf2())
return getAddExpr(getSignExtendExpr(SC1, Ty),
}
}
}
+
+ // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+ if (SA->getNoWrapFlags(SCEV::FlagNSW)) {
+ // If the addition does not sign overflow then we can, by definition,
+ // commute the sign extension with the addition operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SA->operands())
+ Ops.push_back(getSignExtendExpr(Op, Ty));
+ return getAddExpr(Ops, SCEV::FlagNSW);
+ }
}
// If the input value is a chrec scev, and we can prove that the value
// did not overflow the old, smaller, value, we can sign extend all of the
// If Start and Step are constants, check if we can apply this
// transformation:
// sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
- auto SC1 = dyn_cast<SCEVConstant>(Start);
- auto SC2 = dyn_cast<SCEVConstant>(Step);
+ auto *SC1 = dyn_cast<SCEVConstant>(Start);
+ auto *SC2 = dyn_cast<SCEVConstant>(Step);
if (SC1 && SC2) {
- const APInt &C1 = SC1->getValue()->getValue();
- const APInt &C2 = SC2->getValue()->getValue();
+ const APInt &C1 = SC1->getAPInt();
+ const APInt &C2 = SC2->getAPInt();
if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
C2.isPowerOf2()) {
Start = getSignExtendExpr(Start, Ty);
// Sign-extend negative constants.
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
- if (SC->getValue()->getValue().isNegative())
+ if (SC->getAPInt().isNegative())
return getSignExtendExpr(Op, Ty);
// Peel off a truncate cast.
// Pull a buried constant out to the outside.
if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
Interesting = true;
- AccumulatedConstant += Scale * C->getValue()->getValue();
+ AccumulatedConstant += Scale * C->getAPInt();
}
// Next comes everything else. We're especially interested in multiplies
const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
APInt NewScale =
- Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
+ Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
// A multiplication of a constant with another add; recurse.
const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
return Interesting;
}
-namespace {
- struct APIntCompare {
- bool operator()(const APInt &LHS, const APInt &RHS) const {
- return LHS.ult(RHS);
- }
- };
-}
-
// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
// can't-overflow flags for the operation if possible.
static SCEV::NoWrapFlags
StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
const SmallVectorImpl<const SCEV *> &Ops,
- SCEV::NoWrapFlags OldFlags) {
+ SCEV::NoWrapFlags Flags) {
using namespace std::placeholders;
+ typedef OverflowingBinaryOperator OBO;
bool CanAnalyze =
Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
SCEV::NoWrapFlags SignOrUnsignWrap =
- ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
+ ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
// If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
- auto IsKnownNonNegative =
- std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+ auto IsKnownNonNegative = [&](const SCEV *S) {
+ return SE->isKnownNonNegative(S);
+ };
+
+ if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
+ Flags =
+ ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
- if (SignOrUnsignWrap == SCEV::FlagNSW &&
- std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
- return ScalarEvolution::setFlags(OldFlags,
- (SCEV::NoWrapFlags)SignOrUnsignMask);
+ SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
- return OldFlags;
+ if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
+ Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
+
+ // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
+ // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
+
+ const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
+ if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
+ auto NSWRegion =
+ ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap);
+ if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
+ Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
+ }
+ if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
+ auto NUWRegion =
+ ConstantRange::makeNoWrapRegion(Instruction::Add, C,
+ OBO::NoUnsignedWrap);
+ if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
+ Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+ }
+ }
+
+ return Flags;
}
/// getAddExpr - Get a canonical add expression, or something simpler if
"SCEVAddExpr operand types don't match!");
#endif
- Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
-
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, &LI);
+ Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
+
// If there are any constants, fold them together.
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- Ops[0] = getConstant(LHSC->getValue()->getValue() +
- RHSC->getValue()->getValue());
+ Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
if (Ops.size() == 2) return Ops[0];
Ops.erase(Ops.begin()+1); // Erase the folded element
LHSC = cast<SCEVConstant>(Ops[0]);
break;
}
LargeMulOps.push_back(T->getOperand());
- } else if (const SCEVConstant *C =
- dyn_cast<SCEVConstant>(M->getOperand(j))) {
+ } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
} else {
Ok = false;
if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
Ops.data(), Ops.size(),
APInt(BitWidth, 1), *this)) {
+ struct APIntCompare {
+ bool operator()(const APInt &LHS, const APInt &RHS) const {
+ return LHS.ult(RHS);
+ }
+ };
+
// Some interesting folding opportunity is present, so its worthwhile to
// re-generate the operands list. Group the operands by constant scale,
// to avoid multiplying by the same constant scale multiple times.
std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
- for (SmallVectorImpl<const SCEV *>::const_iterator I = NewOps.begin(),
- E = NewOps.end(); I != E; ++I)
- MulOpLists[M.find(*I)->second].push_back(*I);
+ for (const SCEV *NewOp : NewOps)
+ MulOpLists[M.find(NewOp)->second].push_back(NewOp);
// Re-generate the operands list.
Ops.clear();
if (AccumulatedConstant != 0)
Ops.push_back(getConstant(AccumulatedConstant));
- for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
- I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
- if (I->first != 0)
- Ops.push_back(getMulExpr(getConstant(I->first),
- getAddExpr(I->second)));
+ for (auto &MulOp : MulOpLists)
+ if (MulOp.first != 0)
+ Ops.push_back(getMulExpr(getConstant(MulOp.first),
+ getAddExpr(MulOp.second)));
if (Ops.empty())
return getZero(Ty);
if (Ops.size() == 1)
AddRec->op_end());
for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx)
- if (const SCEVAddRecExpr *OtherAddRec =
- dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
+ if (const auto *OtherAddRec = dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
if (OtherAddRec->getLoop() == AddRecLoop) {
for (unsigned i = 0, e = OtherAddRec->getNumOperands();
i != e; ++i) {
"SCEVMulExpr operand types don't match!");
#endif
- Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
-
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, &LI);
+ Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
+
// If there are any constants, fold them together.
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
++Idx;
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- ConstantInt *Fold = ConstantInt::get(getContext(),
- LHSC->getValue()->getValue() *
- RHSC->getValue()->getValue());
+ ConstantInt *Fold =
+ ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt());
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
SmallVector<const SCEV *, 4> NewOps;
bool AnyFolded = false;
- for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
- E = Add->op_end(); I != E; ++I) {
- const SCEV *Mul = getMulExpr(Ops[0], *I);
+ for (const SCEV *AddOp : Add->operands()) {
+ const SCEV *Mul = getMulExpr(Ops[0], AddOp);
if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
NewOps.push_back(Mul);
}
if (AnyFolded)
return getAddExpr(NewOps);
- }
- else if (const SCEVAddRecExpr *
- AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
+ } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
// Negation preserves a recurrence's no self-wrap property.
SmallVector<const SCEV *, 4> Operands;
- for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
- E = AddRec->op_end(); I != E; ++I) {
- Operands.push_back(getMulExpr(Ops[0], *I));
- }
+ for (const SCEV *AddRecOp : AddRec->operands())
+ Operands.push_back(getMulExpr(Ops[0], AddRecOp));
+
return getAddRecExpr(Operands, AddRec->getLoop(),
AddRec->getNoWrapFlags(SCEV::FlagNW));
}
// its operands.
// TODO: Generalize this to non-constants by using known-bits information.
Type *Ty = LHS->getType();
- unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
+ unsigned LZ = RHSC->getAPInt().countLeadingZeros();
unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
// For non-power-of-two values, effectively round the value up to the
// nearest power of two.
- if (!RHSC->getValue()->getValue().isPowerOf2())
+ if (!RHSC->getAPInt().isPowerOf2())
++MaxShiftAmt;
IntegerType *ExtTy =
IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
if (const SCEVConstant *Step =
dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
// {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
- const APInt &StepInt = Step->getValue()->getValue();
- const APInt &DivInt = RHSC->getValue()->getValue();
+ const APInt &StepInt = Step->getAPInt();
+ const APInt &DivInt = RHSC->getAPInt();
if (!StepInt.urem(DivInt) &&
getZeroExtendExpr(AR, ExtTy) ==
getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
getZeroExtendExpr(Step, ExtTy),
AR->getLoop(), SCEV::FlagAnyWrap)) {
SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
- Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
- return getAddRecExpr(Operands, AR->getLoop(),
- SCEV::FlagNW);
+ for (const SCEV *Op : AR->operands())
+ Operands.push_back(getUDivExpr(Op, RHS));
+ return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
}
/// Get a canonical UDivExpr for a recurrence.
/// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
getZeroExtendExpr(Step, ExtTy),
AR->getLoop(), SCEV::FlagAnyWrap)) {
- const APInt &StartInt = StartC->getValue()->getValue();
+ const APInt &StartInt = StartC->getAPInt();
const APInt &StartRem = StartInt.urem(StepInt);
if (StartRem != 0)
LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
// (A*B)/C --> A*(B/C) if safe and B/C can be folded.
if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
- Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
+ for (const SCEV *Op : M->operands())
+ Operands.push_back(getZeroExtendExpr(Op, ExtTy));
if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
// Find an operand that's safely divisible.
for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
// (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
SmallVector<const SCEV *, 4> Operands;
- for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
- Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
+ for (const SCEV *Op : A->operands())
+ Operands.push_back(getZeroExtendExpr(Op, ExtTy));
if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
Operands.clear();
for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
}
static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
- APInt A = C1->getValue()->getValue().abs();
- APInt B = C2->getValue()->getValue().abs();
+ APInt A = C1->getAPInt().abs();
+ APInt B = C2->getAPInt().abs();
uint32_t ABW = A.getBitWidth();
uint32_t BBW = B.getBitWidth();
if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
// If the mulexpr multiplies by a constant, then that constant must be the
// first element of the mulexpr.
- if (const SCEVConstant *LHSCst =
- dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
+ if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
if (LHSCst == RHSCst) {
SmallVector<const SCEV *, 2> Operands;
Operands.append(Mul->op_begin() + 1, Mul->op_end());
// check.
APInt Factor = gcd(LHSCst, RHSCst);
if (!Factor.isIntN(1)) {
- LHSCst = cast<SCEVConstant>(
- getConstant(LHSCst->getValue()->getValue().udiv(Factor)));
- RHSCst = cast<SCEVConstant>(
- getConstant(RHSCst->getValue()->getValue().udiv(Factor)));
+ LHSCst =
+ cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
+ RHSCst =
+ cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
SmallVector<const SCEV *, 2> Operands;
Operands.push_back(LHSCst);
Operands.append(Mul->op_begin() + 1, Mul->op_end());
// AddRecs require their operands be loop-invariant with respect to their
// loops. Don't perform this transformation if it would break this
// requirement.
- bool AllInvariant = true;
- for (unsigned i = 0, e = Operands.size(); i != e; ++i)
- if (!isLoopInvariant(Operands[i], L)) {
- AllInvariant = false;
- break;
- }
+ bool AllInvariant = all_of(
+ Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
+
if (AllInvariant) {
// Create a recurrence for the outer loop with the same step size.
//
maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
- AllInvariant = true;
- for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
- if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
- AllInvariant = false;
- break;
- }
+ AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
+ return isLoopInvariant(Op, NestedLoop);
+ });
+
if (AllInvariant) {
// Ok, both add recurrences are valid after the transformation.
//
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- ConstantInt *Fold = ConstantInt::get(getContext(),
- APIntOps::smax(LHSC->getValue()->getValue(),
- RHSC->getValue()->getValue()));
+ ConstantInt *Fold = ConstantInt::get(
+ getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt()));
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
assert(Idx < Ops.size());
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// We found two constants, fold them together!
- ConstantInt *Fold = ConstantInt::get(getContext(),
- APIntOps::umax(LHSC->getValue()->getValue(),
- RHSC->getValue()->getValue()));
+ ConstantInt *Fold = ConstantInt::get(
+ getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt()));
Ops[0] = getConstant(Fold);
Ops.erase(Ops.begin()+1); // Erase the folded element
if (Ops.size() == 1) return Ops[0];
// We can bypass creating a target-independent
// constant expression and then folding it back into a ConstantInt.
// This is just a compile-time optimization.
- return getConstant(IntTy,
- F.getParent()->getDataLayout().getTypeAllocSize(AllocTy));
+ return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
}
const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
// constant expression and then folding it back into a ConstantInt.
// This is just a compile-time optimization.
return getConstant(
- IntTy,
- F.getParent()->getDataLayout().getStructLayout(STy)->getElementOffset(
- FieldNo));
+ IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
}
const SCEV *ScalarEvolution::getUnknown(Value *V) {
/// for which isSCEVable must return true.
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
- return F.getParent()->getDataLayout().getTypeSizeInBits(Ty);
+ return getDataLayout().getTypeSizeInBits(Ty);
}
/// getEffectiveSCEVType - Return a type with the same bitwidth as
Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
- if (Ty->isIntegerTy()) {
+ if (Ty->isIntegerTy())
return Ty;
- }
// The only other support type is pointer.
assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
- return F.getParent()->getDataLayout().getIntPtrType(Ty);
+ return getDataLayout().getIntPtrType(Ty);
}
const SCEV *ScalarEvolution::getCouldNotCompute() {
return CouldNotCompute.get();
}
-namespace {
+
+bool ScalarEvolution::checkValidity(const SCEV *S) const {
// Helper class working with SCEVTraversal to figure out if a SCEV contains
// a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne
// is set iff if find such SCEVUnknown.
}
bool isDone() const { return FindOne; }
};
-}
-bool ScalarEvolution::checkValidity(const SCEV *S) const {
FindInvalidSCEVUnknown F;
SCEVTraversal<FindInvalidSCEVUnknown> ST(F);
ST.visitAll(S);
if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
return getPointerBase(Cast->getOperand());
- }
- else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
+ } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
const SCEV *PtrOp = nullptr;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- if ((*I)->getType()->isPointerTy()) {
+ for (const SCEV *NAryOp : NAry->operands()) {
+ if (NAryOp->getType()->isPointerTy()) {
// Cannot find the base of an expression with multiple pointer operands.
if (PtrOp)
return V;
- PtrOp = *I;
+ PtrOp = NAryOp;
}
}
if (!PtrOp)
if (!Visited.insert(I).second)
continue;
- ValueExprMapType::iterator It =
- ValueExprMap.find_as(static_cast<Value *>(I));
+ auto It = ValueExprMap.find_as(static_cast<Value *>(I));
if (It != ValueExprMap.end()) {
const SCEV *Old = It->second;
}
}
-/// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in
-/// a loop header, making it a potential recurrence, or it doesn't.
-///
-const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
- if (const Loop *L = LI.getLoopFor(PN->getParent()))
- if (L->getHeader() == PN->getParent()) {
- // The loop may have multiple entrances or multiple exits; we can analyze
- // this phi as an addrec if it has a unique entry value and a unique
- // backedge value.
- Value *BEValueV = nullptr, *StartValueV = nullptr;
- for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
- Value *V = PN->getIncomingValue(i);
- if (L->contains(PN->getIncomingBlock(i))) {
- if (!BEValueV) {
- BEValueV = V;
- } else if (BEValueV != V) {
- BEValueV = nullptr;
- break;
- }
- } else if (!StartValueV) {
- StartValueV = V;
- } else if (StartValueV != V) {
- StartValueV = nullptr;
- break;
- }
- }
- if (BEValueV && StartValueV) {
- // While we are analyzing this PHI node, handle its value symbolically.
- const SCEV *SymbolicName = getUnknown(PN);
- assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
- "PHI node already processed?");
- ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
-
- // Using this symbolic name for the PHI, analyze the value coming around
- // the back-edge.
- const SCEV *BEValue = getSCEV(BEValueV);
-
- // NOTE: If BEValue is loop invariant, we know that the PHI node just
- // has a special value for the first iteration of the loop.
-
- // If the value coming around the backedge is an add with the symbolic
- // value we just inserted, then we found a simple induction variable!
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
- // If there is a single occurrence of the symbolic value, replace it
- // with a recurrence.
- unsigned FoundIndex = Add->getNumOperands();
- for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
- if (Add->getOperand(i) == SymbolicName)
- if (FoundIndex == e) {
- FoundIndex = i;
- break;
- }
+namespace {
+class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ ScalarEvolution &SE) {
+ SCEVInitRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(Scev);
+ return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+ }
- if (FoundIndex != Add->getNumOperands()) {
- // Create an add with everything but the specified operand.
- SmallVector<const SCEV *, 8> Ops;
- for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
- if (i != FoundIndex)
- Ops.push_back(Add->getOperand(i));
- const SCEV *Accum = getAddExpr(Ops);
-
- // This is not a valid addrec if the step amount is varying each
- // loop iteration, but is not itself an addrec in this loop.
- if (isLoopInvariant(Accum, L) ||
- (isa<SCEVAddRecExpr>(Accum) &&
- cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
- SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
-
- // If the increment doesn't overflow, then neither the addrec nor
- // the post-increment will overflow.
- if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
- if (OBO->getOperand(0) == PN) {
- if (OBO->hasNoUnsignedWrap())
- Flags = setFlags(Flags, SCEV::FlagNUW);
- if (OBO->hasNoSignedWrap())
- Flags = setFlags(Flags, SCEV::FlagNSW);
- }
- } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
- // If the increment is an inbounds GEP, then we know the address
- // space cannot be wrapped around. We cannot make any guarantee
- // about signed or unsigned overflow because pointers are
- // unsigned but we may have a negative index from the base
- // pointer. We can guarantee that no unsigned wrap occurs if the
- // indices form a positive value.
- if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
- Flags = setFlags(Flags, SCEV::FlagNW);
-
- const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
- if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
- Flags = setFlags(Flags, SCEV::FlagNUW);
- }
+ SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
- // We cannot transfer nuw and nsw flags from subtraction
- // operations -- sub nuw X, Y is not the same as add nuw X, -Y
- // for instance.
- }
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+ Valid = false;
+ return Expr;
+ }
- const SCEV *StartVal = getSCEV(StartValueV);
- const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
-
- // Since the no-wrap flags are on the increment, they apply to the
- // post-incremented value as well.
- if (isLoopInvariant(Accum, L))
- (void)getAddRecExpr(getAddExpr(StartVal, Accum),
- Accum, L, Flags);
-
- // Okay, for the entire analysis of this edge we assumed the PHI
- // to be symbolic. We now need to go back and purge all of the
- // entries for the scalars that use the symbolic expression.
- ForgetSymbolicName(PN, SymbolicName);
- ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
- return PHISCEV;
- }
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ // Only allow AddRecExprs for this loop.
+ if (Expr->getLoop() == L)
+ return Expr->getStart();
+ Valid = false;
+ return Expr;
+ }
+
+ bool isValid() { return Valid; }
+
+private:
+ const Loop *L;
+ bool Valid;
+};
+
+class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+ ScalarEvolution &SE) {
+ SCEVShiftRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(Scev);
+ return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+ }
+
+ SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ // Only allow AddRecExprs for this loop.
+ if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+ Valid = false;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ if (Expr->getLoop() == L && Expr->isAffine())
+ return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
+ Valid = false;
+ return Expr;
+ }
+ bool isValid() { return Valid; }
+
+private:
+ const Loop *L;
+ bool Valid;
+};
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
+ const Loop *L = LI.getLoopFor(PN->getParent());
+ if (!L || L->getHeader() != PN->getParent())
+ return nullptr;
+
+ // The loop may have multiple entrances or multiple exits; we can analyze
+ // this phi as an addrec if it has a unique entry value and a unique
+ // backedge value.
+ Value *BEValueV = nullptr, *StartValueV = nullptr;
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ Value *V = PN->getIncomingValue(i);
+ if (L->contains(PN->getIncomingBlock(i))) {
+ if (!BEValueV) {
+ BEValueV = V;
+ } else if (BEValueV != V) {
+ BEValueV = nullptr;
+ break;
+ }
+ } else if (!StartValueV) {
+ StartValueV = V;
+ } else if (StartValueV != V) {
+ StartValueV = nullptr;
+ break;
+ }
+ }
+ if (BEValueV && StartValueV) {
+ // While we are analyzing this PHI node, handle its value symbolically.
+ const SCEV *SymbolicName = getUnknown(PN);
+ assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
+ "PHI node already processed?");
+ ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
+
+ // Using this symbolic name for the PHI, analyze the value coming around
+ // the back-edge.
+ const SCEV *BEValue = getSCEV(BEValueV);
+
+ // NOTE: If BEValue is loop invariant, we know that the PHI node just
+ // has a special value for the first iteration of the loop.
+
+ // If the value coming around the backedge is an add with the symbolic
+ // value we just inserted, then we found a simple induction variable!
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
+ // If there is a single occurrence of the symbolic value, replace it
+ // with a recurrence.
+ unsigned FoundIndex = Add->getNumOperands();
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if (Add->getOperand(i) == SymbolicName)
+ if (FoundIndex == e) {
+ FoundIndex = i;
+ break;
}
- } else if (const SCEVAddRecExpr *AddRec =
- dyn_cast<SCEVAddRecExpr>(BEValue)) {
- // Otherwise, this could be a loop like this:
- // i = 0; for (j = 1; ..; ++j) { .... i = j; }
- // In this case, j = {1,+,1} and BEValue is j.
- // Because the other in-value of i (0) fits the evolution of BEValue
- // i really is an addrec evolution.
- if (AddRec->getLoop() == L && AddRec->isAffine()) {
- const SCEV *StartVal = getSCEV(StartValueV);
-
- // If StartVal = j.start - j.stride, we can use StartVal as the
- // initial step of the addrec evolution.
- if (StartVal == getMinusSCEV(AddRec->getOperand(0),
- AddRec->getOperand(1))) {
- // FIXME: For constant StartVal, we should be able to infer
- // no-wrap flags.
- const SCEV *PHISCEV =
- getAddRecExpr(StartVal, AddRec->getOperand(1), L,
- SCEV::FlagAnyWrap);
-
- // Okay, for the entire analysis of this edge we assumed the PHI
- // to be symbolic. We now need to go back and purge all of the
- // entries for the scalars that use the symbolic expression.
- ForgetSymbolicName(PN, SymbolicName);
- ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
- return PHISCEV;
+
+ if (FoundIndex != Add->getNumOperands()) {
+ // Create an add with everything but the specified operand.
+ SmallVector<const SCEV *, 8> Ops;
+ for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+ if (i != FoundIndex)
+ Ops.push_back(Add->getOperand(i));
+ const SCEV *Accum = getAddExpr(Ops);
+
+ // This is not a valid addrec if the step amount is varying each
+ // loop iteration, but is not itself an addrec in this loop.
+ if (isLoopInvariant(Accum, L) ||
+ (isa<SCEVAddRecExpr>(Accum) &&
+ cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
+ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+
+ // If the increment doesn't overflow, then neither the addrec nor
+ // the post-increment will overflow.
+ if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
+ if (OBO->getOperand(0) == PN) {
+ if (OBO->hasNoUnsignedWrap())
+ Flags = setFlags(Flags, SCEV::FlagNUW);
+ if (OBO->hasNoSignedWrap())
+ Flags = setFlags(Flags, SCEV::FlagNSW);
}
+ } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
+ // If the increment is an inbounds GEP, then we know the address
+ // space cannot be wrapped around. We cannot make any guarantee
+ // about signed or unsigned overflow because pointers are
+ // unsigned but we may have a negative index from the base
+ // pointer. We can guarantee that no unsigned wrap occurs if the
+ // indices form a positive value.
+ if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
+ Flags = setFlags(Flags, SCEV::FlagNW);
+
+ const SCEV *Ptr = getSCEV(GEP->getPointerOperand());
+ if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
+ Flags = setFlags(Flags, SCEV::FlagNUW);
+ }
+
+ // We cannot transfer nuw and nsw flags from subtraction
+ // operations -- sub nuw X, Y is not the same as add nuw X, -Y
+ // for instance.
}
+
+ const SCEV *StartVal = getSCEV(StartValueV);
+ const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
+
+ // Since the no-wrap flags are on the increment, they apply to the
+ // post-incremented value as well.
+ if (isLoopInvariant(Accum, L))
+ (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
+ // Okay, for the entire analysis of this edge we assumed the PHI
+ // to be symbolic. We now need to go back and purge all of the
+ // entries for the scalars that use the symbolic expression.
+ ForgetSymbolicName(PN, SymbolicName);
+ ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
+ return PHISCEV;
+ }
+ }
+ } else {
+ // Otherwise, this could be a loop like this:
+ // i = 0; for (j = 1; ..; ++j) { .... i = j; }
+ // In this case, j = {1,+,1} and BEValue is j.
+ // Because the other in-value of i (0) fits the evolution of BEValue
+ // i really is an addrec evolution.
+ //
+ // We can generalize this saying that i is the shifted value of BEValue
+ // by one iteration:
+ // PHI(f(0), f({1,+,1})) --> f({0,+,1})
+ const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
+ const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
+ if (Shifted != getCouldNotCompute() &&
+ Start != getCouldNotCompute()) {
+ const SCEV *StartVal = getSCEV(StartValueV);
+ if (Start == StartVal) {
+ // Okay, for the entire analysis of this edge we assumed the PHI
+ // to be symbolic. We now need to go back and purge all of the
+ // entries for the scalars that use the symbolic expression.
+ ForgetSymbolicName(PN, SymbolicName);
+ ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+ return Shifted;
}
}
}
+ }
+
+ return nullptr;
+}
+
+// Checks if the SCEV S is available at BB. S is considered available at BB
+// if S can be materialized at BB without introducing a fault.
+static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
+ BasicBlock *BB) {
+ struct CheckAvailable {
+ bool TraversalDone = false;
+ bool Available = true;
+
+ const Loop *L = nullptr; // The loop BB is in (can be nullptr)
+ BasicBlock *BB = nullptr;
+ DominatorTree &DT;
+
+ CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
+ : L(L), BB(BB), DT(DT) {}
+
+ bool setUnavailable() {
+ TraversalDone = true;
+ Available = false;
+ return false;
+ }
+
+ bool follow(const SCEV *S) {
+ switch (S->getSCEVType()) {
+ case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
+ case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
+ // These expressions are available if their operand(s) is/are.
+ return true;
+
+ case scAddRecExpr: {
+ // We allow add recurrences that are on the loop BB is in, or some
+ // outer loop. This guarantees availability because the value of the
+ // add recurrence at BB is simply the "current" value of the induction
+ // variable. We can relax this in the future; for instance an add
+ // recurrence on a sibling dominating loop is also available at BB.
+ const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
+ if (L && (ARLoop == L || ARLoop->contains(L)))
+ return true;
+
+ return setUnavailable();
+ }
+
+ case scUnknown: {
+ // For SCEVUnknown, we check for simple dominance.
+ const auto *SU = cast<SCEVUnknown>(S);
+ Value *V = SU->getValue();
+
+ if (isa<Argument>(V))
+ return false;
+
+ if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
+ return false;
+
+ return setUnavailable();
+ }
+
+ case scUDivExpr:
+ case scCouldNotCompute:
+ // We do not try to smart about these at all.
+ return setUnavailable();
+ }
+ llvm_unreachable("switch should be fully covered!");
+ }
+
+ bool isDone() { return TraversalDone; }
+ };
+
+ CheckAvailable CA(L, BB, DT);
+ SCEVTraversal<CheckAvailable> ST(CA);
+
+ ST.visitAll(S);
+ return CA.Available;
+}
+
+// Try to match a control flow sequence that branches out at BI and merges back
+// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
+// match.
+static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
+ Value *&C, Value *&LHS, Value *&RHS) {
+ C = BI->getCondition();
+
+ BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
+ BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
+
+ if (!LeftEdge.isSingleEdge())
+ return false;
+
+ assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
+
+ Use &LeftUse = Merge->getOperandUse(0);
+ Use &RightUse = Merge->getOperandUse(1);
+
+ if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
+ LHS = LeftUse;
+ RHS = RightUse;
+ return true;
+ }
+
+ if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
+ LHS = RightUse;
+ RHS = LeftUse;
+ return true;
+ }
+
+ return false;
+}
+
+const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
+ if (PN->getNumIncomingValues() == 2) {
+ const Loop *L = LI.getLoopFor(PN->getParent());
+
+ // We don't want to break LCSSA, even in a SCEV expression tree.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
+ return nullptr;
+
+ // Try to match
+ //
+ // br %cond, label %left, label %right
+ // left:
+ // br label %merge
+ // right:
+ // br label %merge
+ // merge:
+ // V = phi [ %x, %left ], [ %y, %right ]
+ //
+ // as "select %cond, %x, %y"
+
+ BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
+ assert(IDom && "At least the entry block should dominate PN");
+
+ auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
+ Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
+
+ if (BI && BI->isConditional() &&
+ BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
+ IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) &&
+ IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent()))
+ return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
+ }
+
+ return nullptr;
+}
+
+const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
+ if (const SCEV *S = createAddRecFromPHI(PN))
+ return S;
+
+ if (const SCEV *S = createNodeFromSelectLikePHI(PN))
+ return S;
// If the PHI has a single incoming value, follow that value, unless the
// PHI's incoming blocks are in a different loop, in which case doing so
// risks breaking LCSSA form. Instcombine would normally zap these, but
// it doesn't have DominatorTree information, so it may miss cases.
- if (Value *V = SimplifyInstruction(PN, F.getParent()->getDataLayout(), &TLI,
- &DT, &AC))
+ if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC))
if (LI.replacementPreservesLCSSAForm(PN, V))
return getSCEV(V);
return getUnknown(PN);
}
+const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,
+ Value *Cond,
+ Value *TrueVal,
+ Value *FalseVal) {
+ // Handle "constant" branch or select. This can occur for instance when a
+ // loop pass transforms an inner loop and moves on to process the outer loop.
+ if (auto *CI = dyn_cast<ConstantInt>(Cond))
+ return getSCEV(CI->isOne() ? TrueVal : FalseVal);
+
+ // Try to match some simple smax or umax patterns.
+ auto *ICI = dyn_cast<ICmpInst>(Cond);
+ if (!ICI)
+ return getUnknown(I);
+
+ Value *LHS = ICI->getOperand(0);
+ Value *RHS = ICI->getOperand(1);
+
+ switch (ICI->getPredicate()) {
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLE:
+ std::swap(LHS, RHS);
+ // fall through
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_SGE:
+ // a >s b ? a+x : b+x -> smax(a, b)+x
+ // a >s b ? b+x : a+x -> smin(a, b)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+ const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType());
+ const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS);
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getSMinExpr(LS, RS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_ULT:
+ case ICmpInst::ICMP_ULE:
+ std::swap(LHS, RHS);
+ // fall through
+ case ICmpInst::ICMP_UGT:
+ case ICmpInst::ICMP_UGE:
+ // a >u b ? a+x : b+x -> umax(a, b)+x
+ // a >u b ? b+x : a+x -> umin(a, b)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) {
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, RS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(LS, RS), LDiff);
+ LDiff = getMinusSCEV(LA, RS);
+ RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMinExpr(LS, RS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_NE:
+ // n != 0 ? n+x : 1+x -> umax(n, 1)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getOne(I->getType());
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, LS);
+ const SCEV *RDiff = getMinusSCEV(RA, One);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(One, LS), LDiff);
+ }
+ break;
+ case ICmpInst::ICMP_EQ:
+ // n == 0 ? 1+x : n+x -> umax(n, 1)+x
+ if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getOne(I->getType());
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType());
+ const SCEV *LA = getSCEV(TrueVal);
+ const SCEV *RA = getSCEV(FalseVal);
+ const SCEV *LDiff = getMinusSCEV(LA, One);
+ const SCEV *RDiff = getMinusSCEV(RA, LS);
+ if (LDiff == RDiff)
+ return getAddExpr(getUMaxExpr(One, LS), LDiff);
+ }
+ break;
+ default:
+ break;
+ }
+
+ return getUnknown(I);
+}
+
/// createNodeForGEP - Expand GEP instructions into add and multiply
/// operations. This allows them to be analyzed by regular SCEV code.
///
uint32_t
ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
- return C->getValue()->getValue().countTrailingZeros();
+ return C->getAPInt().countTrailingZeros();
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
return std::min(GetMinTrailingZeros(T->getOperand()),
// For a SCEVUnknown, ask ValueTracking.
unsigned BitWidth = getTypeSizeInBits(U->getType());
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
- computeKnownBits(U->getValue(), Zeros, Ones, F.getParent()->getDataLayout(),
- 0, &AC, nullptr, &DT);
+ computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC,
+ nullptr, &DT);
return Zeros.countTrailingOnes();
}
/// GetRangeFromMetadata - Helper method to assign a range to V from
/// metadata present in the IR.
static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
- if (Instruction *I = dyn_cast<Instruction>(V)) {
- if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
- ConstantRange TotalRange(
- cast<IntegerType>(I->getType())->getBitWidth(), false);
-
- unsigned NumRanges = MD->getNumOperands() / 2;
- assert(NumRanges >= 1);
-
- for (unsigned i = 0; i < NumRanges; ++i) {
- ConstantInt *Lower =
- mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
- ConstantInt *Upper =
- mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
- ConstantRange Range(Lower->getValue(), Upper->getValue());
- TotalRange = TotalRange.unionWith(Range);
- }
-
- return TotalRange;
- }
- }
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
+ return getConstantRangeFromMetadata(*MD);
return None;
}
return I->second;
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
- return setRange(C, SignHint, ConstantRange(C->getValue()->getValue()));
+ return setRange(C, SignHint, ConstantRange(C->getAPInt()));
unsigned BitWidth = getTypeSizeInBits(S->getType());
ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
if (!C->getValue()->isZero())
- ConservativeResult =
- ConservativeResult.intersectWith(
- ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
+ ConservativeResult = ConservativeResult.intersectWith(
+ ConstantRange(C->getAPInt(), APInt(BitWidth, 0)));
// If there's no signed wrap, and all the operands have the same sign or
// zero, the value won't ever change sign.
// Split here to avoid paying the compile-time cost of calling both
// computeKnownBits and ComputeNumSignBits. This restriction can be lifted
// if needed.
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
// For a SCEVUnknown, ask ValueTracking.
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
unsigned TZ = A.countTrailingZeros();
unsigned BitWidth = A.getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
- computeKnownBits(U->getOperand(0), KnownZero, KnownOne,
- F.getParent()->getDataLayout(), 0, &AC, nullptr, &DT);
+ computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(),
+ 0, &AC, nullptr, &DT);
APInt EffectiveMask =
APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
return createNodeForPHI(cast<PHINode>(U));
case Instruction::Select:
- // This could be a smax or umax that was lowered earlier.
- // Try to recover it.
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
- Value *LHS = ICI->getOperand(0);
- Value *RHS = ICI->getOperand(1);
- switch (ICI->getPredicate()) {
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_SLE:
- std::swap(LHS, RHS);
- // fall through
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_SGE:
- // a >s b ? a+x : b+x -> smax(a, b)+x
- // a >s b ? b+x : a+x -> smin(a, b)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType())) {
- const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
- const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, LS);
- const SCEV *RDiff = getMinusSCEV(RA, RS);
- if (LDiff == RDiff)
- return getAddExpr(getSMaxExpr(LS, RS), LDiff);
- LDiff = getMinusSCEV(LA, RS);
- RDiff = getMinusSCEV(RA, LS);
- if (LDiff == RDiff)
- return getAddExpr(getSMinExpr(LS, RS), LDiff);
- }
- break;
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE:
- std::swap(LHS, RHS);
- // fall through
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE:
- // a >u b ? a+x : b+x -> umax(a, b)+x
- // a >u b ? b+x : a+x -> umin(a, b)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType())) {
- const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
- const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, LS);
- const SCEV *RDiff = getMinusSCEV(RA, RS);
- if (LDiff == RDiff)
- return getAddExpr(getUMaxExpr(LS, RS), LDiff);
- LDiff = getMinusSCEV(LA, RS);
- RDiff = getMinusSCEV(RA, LS);
- if (LDiff == RDiff)
- return getAddExpr(getUMinExpr(LS, RS), LDiff);
- }
- break;
- case ICmpInst::ICMP_NE:
- // n != 0 ? n+x : 1+x -> umax(n, 1)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType()) &&
- isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
- const SCEV *One = getOne(U->getType());
- const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, LS);
- const SCEV *RDiff = getMinusSCEV(RA, One);
- if (LDiff == RDiff)
- return getAddExpr(getUMaxExpr(One, LS), LDiff);
- }
- break;
- case ICmpInst::ICMP_EQ:
- // n == 0 ? 1+x : n+x -> umax(n, 1)+x
- if (getTypeSizeInBits(LHS->getType()) <=
- getTypeSizeInBits(U->getType()) &&
- isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
- const SCEV *One = getOne(U->getType());
- const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
- const SCEV *LA = getSCEV(U->getOperand(1));
- const SCEV *RA = getSCEV(U->getOperand(2));
- const SCEV *LDiff = getMinusSCEV(LA, One);
- const SCEV *RDiff = getMinusSCEV(RA, LS);
- if (LDiff == RDiff)
- return getAddExpr(getUMaxExpr(One, LS), LDiff);
- }
- break;
- default:
- break;
- }
- }
+ // U can also be a select constant expr, which let fall through. Since
+ // createNodeForSelect only works for a condition that is an `ICmpInst`, and
+ // constant expressions cannot have instructions as operands, we'd have
+ // returned getUnknown for a select constant expressions anyway.
+ if (isa<Instruction>(U))
+ return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0),
+ U->getOperand(1), U->getOperand(2));
default: // We cannot analyze this expression.
break;
if (!Pair.second)
return Pair.first->second;
- // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it
+ // computeBackedgeTakenCount may allocate memory for its result. Inserting it
// into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
// must be cleared in this scope.
- BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L);
+ BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
if (Result.getExact(this) != getCouldNotCompute()) {
assert(isLoopInvariant(Result.getExact(this), L) &&
}
// Re-lookup the insert position, since the call to
- // ComputeBackedgeTakenCount above could result in a
+ // computeBackedgeTakenCount above could result in a
// recusive call to getBackedgeTakenInfo (on a different
// loop), which would invalidate the iterator computed
// earlier.
delete[] ExitNotTaken.getNextExit();
}
-/// ComputeBackedgeTakenCount - Compute the number of times the backedge
+/// computeBackedgeTakenCount - Compute the number of times the backedge
/// of the specified loop will execute.
ScalarEvolution::BackedgeTakenInfo
-ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
+ScalarEvolution::computeBackedgeTakenCount(const Loop *L) {
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
// and compute maxBECount.
for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
BasicBlock *ExitBB = ExitingBlocks[i];
- ExitLimit EL = ComputeExitLimit(L, ExitBB);
+ ExitLimit EL = computeExitLimit(L, ExitBB);
// 1. For each exit that can be computed, add an entry to ExitCounts.
// CouldComputeBECount is true only if all exits can be computed.
return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
}
-/// ComputeExitLimit - Compute the number of times the backedge of the specified
-/// loop will execute if it exits via the specified block.
ScalarEvolution::ExitLimit
-ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
+ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
- // Okay, we've chosen an exiting block. See what condition causes us to
- // exit at this block and remember the exit block and whether all other targets
+ // Okay, we've chosen an exiting block. See what condition causes us to exit
+ // at this block and remember the exit block and whether all other targets
// lead to the loop header.
bool MustExecuteLoopHeader = true;
BasicBlock *Exit = nullptr;
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
assert(BI->isConditional() && "If unconditional, it can't be in loop!");
// Proceed to the next level to examine the exit condition expression.
- return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0),
+ return computeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0),
BI->getSuccessor(1),
/*ControlsExit=*/IsOnlyExit);
}
if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
- return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit,
+ return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
/*ControlsExit=*/IsOnlyExit);
return getCouldNotCompute();
}
-/// ComputeExitLimitFromCond - Compute the number of times the
+/// computeExitLimitFromCond - Compute the number of times the
/// backedge of the specified loop will execute if its exit condition
/// were a conditional branch of ExitCond, TBB, and FBB.
///
/// condition is true and can infer that failing to meet the condition prior to
/// integer wraparound results in undefined behavior.
ScalarEvolution::ExitLimit
-ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
+ScalarEvolution::computeExitLimitFromCond(const Loop *L,
Value *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
if (BO->getOpcode() == Instruction::And) {
// Recurse on the operands of the and.
bool EitherMayExit = L->contains(TBB);
- ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
+ ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
ControlsExit && !EitherMayExit);
- ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
+ ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
ControlsExit && !EitherMayExit);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
bool EitherMayExit = L->contains(FBB);
- ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
+ ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
ControlsExit && !EitherMayExit);
- ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
+ ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
ControlsExit && !EitherMayExit);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
// With an icmp, it may be feasible to compute an exact backedge-taken count.
// Proceed to the next level to examine the icmp.
if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
- return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
+ return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
// Check for a constant condition. These are normally stripped out by
// SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
}
// If it's not an integer or pointer comparison then compute it the hard way.
- return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+ return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
}
-/// ComputeExitLimitFromICmp - Compute the number of times the
-/// backedge of the specified loop will execute if its exit condition
-/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
ScalarEvolution::ExitLimit
-ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L,
+ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
ICmpInst *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
ExitLimit ItCnt =
- ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
+ computeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
if (ItCnt.hasAnyInfo())
return ItCnt;
}
+ ExitLimit ShiftEL = computeShiftCompareExitLimit(
+ ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond);
+ if (ShiftEL.hasAnyInfo())
+ return ShiftEL;
+
const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
if (AddRec->getLoop() == L) {
// Form the constant range.
ConstantRange CompRange(
- ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
+ ICmpInst::makeConstantRange(Cond, RHSC->getAPInt()));
const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
break;
}
default:
-#if 0
- dbgs() << "ComputeBackedgeTakenCount ";
- if (ExitCond->getOperand(0)->getType()->isUnsigned())
- dbgs() << "[unsigned] ";
- dbgs() << *LHS << " "
- << Instruction::getOpcodeName(Instruction::ICmp)
- << " " << *RHS << "\n";
-#endif
break;
}
- return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+ return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
}
ScalarEvolution::ExitLimit
-ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L,
+ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
SwitchInst *Switch,
BasicBlock *ExitingBlock,
bool ControlsExit) {
return cast<SCEVConstant>(Val)->getValue();
}
-/// ComputeLoadConstantCompareExitLimit - Given an exit condition of
+/// computeLoadConstantCompareExitLimit - Given an exit condition of
/// 'icmp op load X, cst', try to see if we can compute the backedge
/// execution count.
ScalarEvolution::ExitLimit
-ScalarEvolution::ComputeLoadConstantCompareExitLimit(
+ScalarEvolution::computeLoadConstantCompareExitLimit(
LoadInst *LI,
Constant *RHS,
const Loop *L,
// Form the GEP offset.
Indexes[VarIdxNum] = Val;
- Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
- Indexes);
- if (!Result) break; // Cannot compute!
+ Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
+ Indexes);
+ if (!Result) break; // Cannot compute!
+
+ // Evaluate the condition for this iteration.
+ Result = ConstantExpr::getICmp(predicate, Result, RHS);
+ if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure
+ if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
+ ++NumArrayLenItCounts;
+ return getConstant(ItCst); // Found terminating iteration!
+ }
+ }
+ return getCouldNotCompute();
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+ Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+ ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+ if (!RHS)
+ return getCouldNotCompute();
+
+ const BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return getCouldNotCompute();
+
+ const BasicBlock *Predecessor = L->getLoopPredecessor();
+ if (!Predecessor)
+ return getCouldNotCompute();
+
+ // Return true if V is of the form "LHS `shift_op` <positive constant>".
+ // Return LHS in OutLHS and shift_opt in OutOpCode.
+ auto MatchPositiveShift =
+ [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+ using namespace PatternMatch;
+
+ ConstantInt *ShiftAmt;
+ if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::LShr;
+ else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::AShr;
+ else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+ OutOpCode = Instruction::Shl;
+ else
+ return false;
+
+ return ShiftAmt->getValue().isStrictlyPositive();
+ };
+
+ // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
+ //
+ // loop:
+ // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+ // %iv.shifted = lshr i32 %iv, <positive constant>
+ //
+ // Return true on a succesful match. Return the corresponding PHI node (%iv
+ // above) in PNOut and the opcode of the shift operation in OpCodeOut.
+ auto MatchShiftRecurrence =
+ [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+ Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+ {
+ Instruction::BinaryOps OpC;
+ Value *V;
+
+ // If we encounter a shift instruction, "peel off" the shift operation,
+ // and remember that we did so. Later when we inspect %iv's backedge
+ // value, we will make sure that the backedge value uses the same
+ // operation.
+ //
+ // Note: the peeled shift operation does not have to be the same
+ // instruction as the one feeding into the PHI's backedge value. We only
+ // really care about it being the same *kind* of shift instruction --
+ // that's all that is required for our later inferences to hold.
+ if (MatchPositiveShift(LHS, V, OpC)) {
+ PostShiftOpCode = OpC;
+ LHS = V;
+ }
+ }
+
+ PNOut = dyn_cast<PHINode>(LHS);
+ if (!PNOut || PNOut->getParent() != L->getHeader())
+ return false;
+
+ Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
+ Value *OpLHS;
+
+ return
+ // The backedge value for the PHI node must be a shift by a positive
+ // amount
+ MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
+
+ // of the PHI node itself
+ OpLHS == PNOut &&
+
+ // and the kind of shift should be match the kind of shift we peeled
+ // off, if any.
+ (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut);
+ };
+
+ PHINode *PN;
+ Instruction::BinaryOps OpCode;
+ if (!MatchShiftRecurrence(LHS, PN, OpCode))
+ return getCouldNotCompute();
+
+ const DataLayout &DL = getDataLayout();
+
+ // The key rationale for this optimization is that for some kinds of shift
+ // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
+ // within a finite number of iterations. If the condition guarding the
+ // backedge (in the sense that the backedge is taken if the condition is true)
+ // is false for the value the shift recurrence stabilizes to, then we know
+ // that the backedge is taken only a finite number of times.
+
+ ConstantInt *StableValue = nullptr;
+ switch (OpCode) {
+ default:
+ llvm_unreachable("Impossible case!");
+
+ case Instruction::AShr: {
+ // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
+ // bitwidth(K) iterations.
+ Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
+ bool KnownZero, KnownOne;
+ ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr,
+ Predecessor->getTerminator(), &DT);
+ auto *Ty = cast<IntegerType>(RHS->getType());
+ if (KnownZero)
+ StableValue = ConstantInt::get(Ty, 0);
+ else if (KnownOne)
+ StableValue = ConstantInt::get(Ty, -1, true);
+ else
+ return getCouldNotCompute();
+
+ break;
+ }
+ case Instruction::LShr:
+ case Instruction::Shl:
+ // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
+ // stabilize to 0 in at most bitwidth(K) iterations.
+ StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
+ break;
+ }
+
+ auto *Result =
+ ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
+ assert(Result->getType()->isIntegerTy(1) &&
+ "Otherwise cannot be an operand to a branch instruction");
- // Evaluate the condition for this iteration.
- Result = ConstantExpr::getICmp(predicate, Result, RHS);
- if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure
- if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
-#if 0
- dbgs() << "\n***\n*** Computed loop count " << *ItCst
- << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
- << "***\n";
-#endif
- ++NumArrayLenItCounts;
- return getConstant(ItCst); // Found terminating iteration!
- }
+ if (Result->isZeroValue()) {
+ unsigned BitWidth = getTypeSizeInBits(RHS->getType());
+ const SCEV *UpperBound =
+ getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
+ return ExitLimit(getCouldNotCompute(), UpperBound);
}
+
return getCouldNotCompute();
}
-
/// CanConstantFold - Return true if we can constant fold an instruction of the
/// specified type, assuming that all operands were constants.
static bool CanConstantFold(const Instruction *I) {
// Otherwise, we can evaluate this instruction if all of its operands are
// constant or derived from a PHI node themselves.
PHINode *PHI = nullptr;
- for (Instruction::op_iterator OpI = UseInst->op_begin(),
- OpE = UseInst->op_end(); OpI != OpE; ++OpI) {
-
- if (isa<Constant>(*OpI)) continue;
+ for (Value *Op : UseInst->operands()) {
+ if (isa<Constant>(Op)) continue;
- Instruction *OpInst = dyn_cast<Instruction>(*OpI);
+ Instruction *OpInst = dyn_cast<Instruction>(Op);
if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
PHINode *P = dyn_cast<PHINode>(OpInst);
Instruction *I = dyn_cast<Instruction>(V);
if (!I || !canConstantEvolve(I, L)) return nullptr;
- if (PHINode *PN = dyn_cast<PHINode>(I)) {
+ if (PHINode *PN = dyn_cast<PHINode>(I))
return PN;
- }
// Record non-constant instructions contained by the loop.
DenseMap<Instruction *, PHINode *> PHIMap;
TLI);
}
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+ Constant *IncomingVal = nullptr;
+
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ if (PN->getIncomingBlock(i) == BB)
+ continue;
+
+ auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+ if (!CurrentVal)
+ return nullptr;
+
+ if (IncomingVal != CurrentVal) {
+ if (IncomingVal)
+ return nullptr;
+ IncomingVal = CurrentVal;
+ }
+ }
+
+ return IncomingVal;
+}
+
/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
const APInt &BEs,
const Loop *L) {
- DenseMap<PHINode*, Constant*>::const_iterator I =
- ConstantEvolutionLoopExitValue.find(PN);
+ auto I = ConstantEvolutionLoopExitValue.find(PN);
if (I != ConstantEvolutionLoopExitValue.end())
return I->second;
BasicBlock *Header = L->getHeader();
assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
- // Since the loop is canonicalized, the PHI node must have two entries. One
- // entry must be a constant (coming in from outside of the loop), and the
- // second must be derived from the same PHI.
- bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
- PHINode *PHI = nullptr;
- for (BasicBlock::iterator I = Header->begin();
- (PHI = dyn_cast<PHINode>(I)); ++I) {
- Constant *StartCST =
- dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
+ BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch)
+ return nullptr;
+
+ for (auto &I : *Header) {
+ PHINode *PHI = dyn_cast<PHINode>(&I);
+ if (!PHI) break;
+ auto *StartCST = getOtherIncomingValue(PHI, Latch);
if (!StartCST) continue;
CurrentIterVals[PHI] = StartCST;
}
if (!CurrentIterVals.count(PN))
return RetVal = nullptr;
- Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
+ Value *BEValue = PN->getIncomingValueForBlock(Latch);
// Execute the loop symbolically to determine the exit value.
if (BEs.getActiveBits() >= 32)
unsigned NumIterations = BEs.getZExtValue(); // must be in range
unsigned IterationNum = 0;
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
for (; ; ++IterationNum) {
if (IterationNum == NumIterations)
return RetVal = CurrentIterVals[PN]; // Got exit value!
// cease to be able to evaluate one of them or if they stop evolving,
// because that doesn't necessarily prevent us from computing PN.
SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
- for (DenseMap<Instruction *, Constant *>::const_iterator
- I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
- PHINode *PHI = dyn_cast<PHINode>(I->first);
+ for (const auto &I : CurrentIterVals) {
+ PHINode *PHI = dyn_cast<PHINode>(I.first);
if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
- PHIsToCompute.push_back(std::make_pair(PHI, I->second));
+ PHIsToCompute.emplace_back(PHI, I.second);
}
// We use two distinct loops because EvaluateExpression may invalidate any
// iterators into CurrentIterVals.
- for (SmallVectorImpl<std::pair<PHINode *, Constant*> >::const_iterator
- I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) {
- PHINode *PHI = I->first;
+ for (const auto &I : PHIsToCompute) {
+ PHINode *PHI = I.first;
Constant *&NextPHI = NextIterVals[PHI];
if (!NextPHI) { // Not already computed.
- Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
+ Value *BEValue = PHI->getIncomingValueForBlock(Latch);
NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
}
- if (NextPHI != I->second)
+ if (NextPHI != I.second)
StoppedEvolving = false;
}
}
}
-/// ComputeExitCountExhaustively - If the loop is known to execute a
-/// constant number of times (the condition evolves only from constants),
-/// try to evaluate a few iterations of the loop until we get the exit
-/// condition gets a value of ExitWhen (true or false). If we cannot
-/// evaluate the trip count of the loop, return getCouldNotCompute().
-const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
+const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
Value *Cond,
bool ExitWhen) {
PHINode *PN = getConstantEvolvingPHI(Cond, L);
BasicBlock *Header = L->getHeader();
assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
- // One entry must be a constant (coming in from outside of the loop), and the
- // second must be derived from the same PHI.
- bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
- PHINode *PHI = nullptr;
- for (BasicBlock::iterator I = Header->begin();
- (PHI = dyn_cast<PHINode>(I)); ++I) {
- Constant *StartCST =
- dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
+ BasicBlock *Latch = L->getLoopLatch();
+ assert(Latch && "Should follow from NumIncomingValues == 2!");
+
+ for (auto &I : *Header) {
+ PHINode *PHI = dyn_cast<PHINode>(&I);
+ if (!PHI)
+ break;
+ auto *StartCST = getOtherIncomingValue(PHI, Latch);
if (!StartCST) continue;
CurrentIterVals[PHI] = StartCST;
}
// the loop symbolically to determine when the condition gets a value of
// "ExitWhen".
unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
- ConstantInt *CondVal = dyn_cast_or_null<ConstantInt>(
+ auto *CondVal = dyn_cast_or_null<ConstantInt>(
EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
// Couldn't symbolically evaluate.
// calling EvaluateExpression on them because that may invalidate iterators
// into CurrentIterVals.
SmallVector<PHINode *, 8> PHIsToCompute;
- for (DenseMap<Instruction *, Constant *>::const_iterator
- I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
- PHINode *PHI = dyn_cast<PHINode>(I->first);
+ for (const auto &I : CurrentIterVals) {
+ PHINode *PHI = dyn_cast<PHINode>(I.first);
if (!PHI || PHI->getParent() != Header) continue;
PHIsToCompute.push_back(PHI);
}
- for (SmallVectorImpl<PHINode *>::const_iterator I = PHIsToCompute.begin(),
- E = PHIsToCompute.end(); I != E; ++I) {
- PHINode *PHI = *I;
+ for (PHINode *PHI : PHIsToCompute) {
Constant *&NextPHI = NextIterVals[PHI];
if (NextPHI) continue; // Already computed!
- Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
+ Value *BEValue = PHI->getIncomingValueForBlock(Latch);
NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
}
CurrentIterVals.swap(NextIterVals);
/// In the case that a relevant loop exit value cannot be computed, the
/// original value V is returned.
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+ SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+ ValuesAtScopes[V];
// Check to see if we've folded this expression at this loop before.
- SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
- for (unsigned u = 0; u < Values.size(); u++) {
- if (Values[u].first == L)
- return Values[u].second ? Values[u].second : V;
- }
- Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
+ for (auto &LS : Values)
+ if (LS.first == L)
+ return LS.second ? LS.second : V;
+
+ Values.emplace_back(L, nullptr);
+
// Otherwise compute it.
const SCEV *C = computeSCEVAtScope(V, L);
- SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
- for (unsigned u = Values2.size(); u > 0; u--) {
- if (Values2[u - 1].first == L) {
- Values2[u - 1].second = C;
+ for (auto &LS : reverse(ValuesAtScopes[V]))
+ if (LS.first == L) {
+ LS.second = C;
break;
}
- }
return C;
}
// Okay, we know how many times the containing loop executes. If
// this is a constant evolving PHI node, get the final value at
// the specified iteration number.
- Constant *RV = getConstantEvolutionLoopExitValue(PN,
- BTCC->getValue()->getValue(),
- LI);
+ Constant *RV =
+ getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI);
if (RV) return getSCEV(RV);
}
}
if (CanConstantFold(I)) {
SmallVector<Constant *, 4> Operands;
bool MadeImprovement = false;
- for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
- Value *Op = I->getOperand(i);
+ for (Value *Op : I->operands()) {
if (Constant *C = dyn_cast<Constant>(Op)) {
Operands.push_back(C);
continue;
// Check to see if getSCEVAtScope actually made an improvement.
if (MadeImprovement) {
Constant *C = nullptr;
- const DataLayout &DL = F.getParent()->getDataLayout();
+ const DataLayout &DL = getDataLayout();
if (const CmpInst *CI = dyn_cast<CmpInst>(I))
C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
Operands[1], DL, &TLI);
return std::make_pair(CNC, CNC);
}
- uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
- const APInt &L = LC->getValue()->getValue();
- const APInt &M = MC->getValue()->getValue();
- const APInt &N = NC->getValue()->getValue();
+ uint32_t BitWidth = LC->getAPInt().getBitWidth();
+ const APInt &L = LC->getAPInt();
+ const APInt &M = MC->getAPInt();
+ const APInt &N = NC->getAPInt();
APInt Two(BitWidth, 2);
APInt Four(BitWidth, 4);
const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
if (R1 && R2) {
-#if 0
- dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
- << " sol#2: " << *R2 << "\n";
-#endif
// Pick the smallest positive root value.
if (ConstantInt *CB =
dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
// For negative steps (counting down to zero):
// N = Start/-Step
// First compute the unsigned distance from zero in the direction of Step.
- bool CountDown = StepC->getValue()->getValue().isNegative();
+ bool CountDown = StepC->getAPInt().isNegative();
const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
// Handle unitary steps, which cannot wraparound.
// done by counting and comparing the number of trailing zeros of Step and
// Distance.
if (!CountDown) {
- const APInt &StepV = StepC->getValue()->getValue();
+ const APInt &StepV = StepC->getAPInt();
// StepV.isPowerOf2() returns true if StepV is an positive power of two. It
// also returns true if StepV is maximally negative (eg, INT_MIN), but that
// case is not handled as this code is guarded by !CountDown.
// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
- return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
- -StartC->getValue()->getValue(),
+ return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(),
*this);
return getCouldNotCompute();
}
// If there's a constant operand, canonicalize comparisons with boundary
// cases, and canonicalize *-or-equal comparisons to regular comparisons.
if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
- const APInt &RA = RC->getValue()->getValue();
+ const APInt &RA = RC->getAPInt();
switch (Pred) {
default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
case ICmpInst::ICMP_EQ:
Pred = ICmpInst::ICMP_ULT;
Changed = true;
} else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
- LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
- SCEV::FlagNUW);
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
Pred = ICmpInst::ICMP_ULT;
Changed = true;
}
break;
case ICmpInst::ICMP_UGE:
if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
- RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
- SCEV::FlagNUW);
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
Pred = ICmpInst::ICMP_UGT;
Changed = true;
} else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
if (LeftGuarded && RightGuarded)
return true;
+ if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
+ return true;
+
// Otherwise see what can be done with known constant ranges.
return isKnownPredicateWithRanges(Pred, LHS, RHS);
}
return false;
}
+bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+
+ // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
+ // Return Y via OutY.
+ auto MatchBinaryAddToConst =
+ [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+ SCEV::NoWrapFlags ExpectedFlags) {
+ const SCEV *NonConstOp, *ConstOp;
+ SCEV::NoWrapFlags FlagsPresent;
+
+ if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+ !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+ return false;
+
+ OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
+ return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+ };
+
+ APInt C;
+
+ switch (Pred) {
+ default:
+ break;
+
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ case ICmpInst::ICMP_SLE:
+ // X s<= (X + C)<nsw> if C >= 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+ return true;
+
+ // (X + C)<nsw> s<= X if C <= 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+ !C.isStrictlyPositive())
+ return true;
+ break;
+
+ case ICmpInst::ICMP_SGT:
+ std::swap(LHS, RHS);
+ case ICmpInst::ICMP_SLT:
+ // X s< (X + C)<nsw> if C > 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+ C.isStrictlyPositive())
+ return true;
+
+ // (X + C)<nsw> s< X if C < 0
+ if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+ return true;
+ break;
+ }
+
+ return false;
+}
+
+bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+ if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
+ return false;
+
+ // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
+ // the stack can result in exponential time complexity.
+ SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
+
+ // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
+ //
+ // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
+ // isKnownPredicate. isKnownPredicate is more powerful, but also more
+ // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
+ // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
+ // use isKnownPredicate later if needed.
+ return isKnownNonNegative(RHS) &&
+ isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+ isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
+}
+
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
/// protected by a conditional between LHS and RHS. This is used to
/// to eliminate casts.
return false;
}
+namespace {
/// RAII wrapper to prevent recursive application of isImpliedCond.
/// ScalarEvolution's PendingLoopPredicates set must be empty unless we are
/// currently evaluating isImpliedCond.
LoopPreds.erase(Cond);
}
};
+} // end anonymous namespace
/// isImpliedCond - Test whether the condition described by Pred, LHS,
/// and RHS is true whenever the given Cond value evaluates to true.
RHS, LHS, FoundLHS, FoundRHS);
}
+ // Unsigned comparison is the same as signed comparison when both the operands
+ // are non-negative.
+ if (CmpInst::isUnsigned(FoundPred) &&
+ CmpInst::getSignedPredicate(FoundPred) == Pred &&
+ isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
// Check if we can make progress by sharpening ranges.
if (FoundPred == ICmpInst::ICMP_NE &&
(isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
APInt Min = ICmpInst::isSigned(Pred) ?
getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
- if (Min == C->getValue()->getValue()) {
+ if (Min == C->getAPInt()) {
// Given (V >= Min && V != Min) we conclude V >= (Min + 1).
// This is true even if (Min + 1) wraps around -- in case of
// wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
return false;
}
-// Return true if More == (Less + C), where C is a constant.
-static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More,
- APInt &C) {
- // We avoid subtracting expressions here because this function is usually
- // fairly deep in the call stack (i.e. is called many times).
+bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
+ const SCEV *&L, const SCEV *&R,
+ SCEV::NoWrapFlags &Flags) {
+ const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
+ if (!AE || AE->getNumOperands() != 2)
+ return false;
- auto SplitBinaryAdd = [](const SCEV *Expr, const SCEV *&L, const SCEV *&R) {
- const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
- if (!AE || AE->getNumOperands() != 2)
- return false;
+ L = AE->getOperand(0);
+ R = AE->getOperand(1);
+ Flags = AE->getNoWrapFlags();
+ return true;
+}
- L = AE->getOperand(0);
- R = AE->getOperand(1);
- return true;
- };
+bool ScalarEvolution::computeConstantDifference(const SCEV *Less,
+ const SCEV *More,
+ APInt &C) {
+ // We avoid subtracting expressions here because this function is usually
+ // fairly deep in the call stack (i.e. is called many times).
if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
const auto *LAR = cast<SCEVAddRecExpr>(Less);
if (!LAR->isAffine() || !MAR->isAffine())
return false;
- if (LAR->getStepRecurrence(SE) != MAR->getStepRecurrence(SE))
+ if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
return false;
Less = LAR->getStart();
}
if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
- const auto &M = cast<SCEVConstant>(More)->getValue()->getValue();
- const auto &L = cast<SCEVConstant>(Less)->getValue()->getValue();
+ const auto &M = cast<SCEVConstant>(More)->getAPInt();
+ const auto &L = cast<SCEVConstant>(Less)->getAPInt();
C = M - L;
return true;
}
const SCEV *L, *R;
- if (SplitBinaryAdd(Less, L, R))
+ SCEV::NoWrapFlags Flags;
+ if (splitBinaryAdd(Less, L, R, Flags))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
if (R == More) {
- C = -(LC->getValue()->getValue());
+ C = -(LC->getAPInt());
return true;
}
- if (SplitBinaryAdd(More, L, R))
+ if (splitBinaryAdd(More, L, R, Flags))
if (const auto *LC = dyn_cast<SCEVConstant>(L))
if (R == Less) {
- C = LC->getValue()->getValue();
+ C = LC->getAPInt();
return true;
}
// C)".
APInt LDiff, RDiff;
- if (!IsConstDiff(*this, FoundLHS, LHS, LDiff) ||
- !IsConstDiff(*this, FoundRHS, RHS, RDiff) ||
+ if (!computeConstantDifference(FoundLHS, LHS, LDiff) ||
+ !computeConstantDifference(FoundRHS, RHS, RDiff) ||
LDiff != RDiff)
return false;
/// If Expr computes ~A, return A else return nullptr
static const SCEV *MatchNotExpr(const SCEV *Expr) {
const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
- if (!Add || Add->getNumOperands() != 2) return nullptr;
-
- const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
- if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
+ if (!Add || Add->getNumOperands() != 2 ||
+ !Add->getOperand(0)->isAllOnesValue())
return nullptr;
const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
- if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
-
- const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
- if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
+ if (!AddRHS || AddRHS->getNumOperands() != 2 ||
+ !AddRHS->getOperand(0)->isAllOnesValue())
return nullptr;
return AddRHS->getOperand(1);
const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
if (!MaxExpr) return false;
- auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
- return It != MaxExpr->op_end();
+ return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
}
auto IsKnownPredicateFull =
[this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
- IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
- IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS);
+ IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+ IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+ isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
};
switch (Pred) {
!isa<SCEVConstant>(AddLHS->getOperand(0)))
return false;
- APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getValue()->getValue();
+ APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
// `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
// antecedent "`FoundLHS` `Pred` `FoundRHS`".
// Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range
// for `LHS`:
- APInt Addend =
- cast<SCEVConstant>(AddLHS->getOperand(0))->getValue()->getValue();
+ APInt Addend = cast<SCEVConstant>(AddLHS->getOperand(0))->getAPInt();
ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend));
// We can also compute the range of values for `LHS` that satisfy the
// consequent, "`LHS` `Pred` `RHS`":
- APInt ConstRHS = cast<SCEVConstant>(RHS)->getValue()->getValue();
+ APInt ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
ConstantRange SatisfyingLHSRange =
ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS);
// overflow, in which case if RHS - Start is a constant, we don't need to
// do a max operation since we can just figure it out statically
if (NoWrap && isa<SCEVConstant>(Diff)) {
- APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
+ APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
if (D.isNegative())
End = Start;
} else
// overflow, in which case if RHS - Start is a constant, we don't need to
// do a max operation since we can just figure it out statically
if (NoWrap && isa<SCEVConstant>(Diff)) {
- APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
+ APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt();
if (!D.isNegative())
End = Start;
} else
Operands[0] = SE.getZero(SC->getType());
const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
getNoWrapFlags(FlagNW));
- if (const SCEVAddRecExpr *ShiftedAddRec =
- dyn_cast<SCEVAddRecExpr>(Shifted))
+ if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
return ShiftedAddRec->getNumIterationsInRange(
- Range.subtract(SC->getValue()->getValue()), SE);
+ Range.subtract(SC->getAPInt()), SE);
// This is strange and shouldn't happen.
return SE.getCouldNotCompute();
}
// The only time we can solve this is when we have all constant indices.
// Otherwise, we cannot determine the overflow conditions.
- for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
- if (!isa<SCEVConstant>(getOperand(i)))
- return SE.getCouldNotCompute();
-
+ if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
+ return SE.getCouldNotCompute();
// Okay at this point we know that all elements of the chrec are constants and
// that the start element is zero.
// If A is negative then the lower of the range is the last possible loop
// value. Also note that we already checked for a full range.
APInt One(BitWidth,1);
- APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
+ APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
// The exit value should be (End+A)/A.
FlagAnyWrap);
// Next, solve the constructed addrec
- std::pair<const SCEV *,const SCEV *> Roots =
- SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
+ auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
if (R1) {
// Pick the smallest positive root value.
- if (ConstantInt *CB =
- dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
- R1->getValue(), R2->getValue()))) {
+ if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+ ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
if (!CB->getZExtValue())
std::swap(R1, R2); // R1 is the minimum root now.
if (Range.contains(R1Val->getValue())) {
// The next iteration must be out of the range...
ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
+ ConstantInt::get(SE.getContext(), R1->getAPInt() + 1);
R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
if (!Range.contains(R1Val->getValue()))
// If R1 was not in the range, then it is a good return value. Make
// sure that R1-1 WAS in the range though, just in case.
ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
+ ConstantInt::get(SE.getContext(), R1->getAPInt() - 1);
R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
if (Range.contains(R1Val->getValue()))
return R1;
}
bool isDone() const { return false; }
};
+
+// Check if a SCEV contains an AddRecExpr.
+struct SCEVHasAddRec {
+ bool &ContainsAddRec;
+
+ SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
+ ContainsAddRec = false;
+ }
+
+ bool follow(const SCEV *S) {
+ if (isa<SCEVAddRecExpr>(S)) {
+ ContainsAddRec = true;
+
+ // Stop recursion: once we collected a term, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+ bool isDone() const { return false; }
+};
+
+// Find factors that are multiplied with an expression that (possibly as a
+// subexpression) contains an AddRecExpr. In the expression:
+//
+// 8 * (100 + %p * %q * (%a + {0, +, 1}_loop))
+//
+// "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
+// that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
+// parameters as they form a product with an induction variable.
+//
+// This collector expects all array size parameters to be in the same MulExpr.
+// It might be necessary to later add support for collecting parameters that are
+// spread over different nested MulExpr.
+struct SCEVCollectAddRecMultiplies {
+ SmallVectorImpl<const SCEV *> &Terms;
+ ScalarEvolution &SE;
+
+ SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T, ScalarEvolution &SE)
+ : Terms(T), SE(SE) {}
+
+ bool follow(const SCEV *S) {
+ if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
+ bool HasAddRec = false;
+ SmallVector<const SCEV *, 0> Operands;
+ for (auto Op : Mul->operands()) {
+ if (isa<SCEVUnknown>(Op)) {
+ Operands.push_back(Op);
+ } else {
+ bool ContainsAddRec;
+ SCEVHasAddRec ContiansAddRec(ContainsAddRec);
+ visitAll(Op, ContiansAddRec);
+ HasAddRec |= ContainsAddRec;
+ }
+ }
+ if (Operands.size() == 0)
+ return true;
+
+ if (!HasAddRec)
+ return false;
+
+ Terms.push_back(SE.getMulExpr(Operands));
+ // Stop recursion: once we collected a term, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+ bool isDone() const { return false; }
+};
}
-/// Find parametric terms in this SCEVAddRecExpr.
+/// Find parametric terms in this SCEVAddRecExpr. We first for parameters in
+/// two places:
+/// 1) The strides of AddRec expressions.
+/// 2) Unknowns that are multiplied with AddRec expressions.
void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
SmallVectorImpl<const SCEV *> &Terms) {
SmallVector<const SCEV *, 4> Strides;
for (const SCEV *T : Terms)
dbgs() << *T << "\n";
});
+
+ SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
+ visitAll(Expr, MulCollector);
}
static bool findArrayDimensionsRec(ScalarEvolution &SE,
return true;
}
-namespace {
-struct FindParameter {
- bool FoundParameter;
- FindParameter() : FoundParameter(false) {}
-
- bool follow(const SCEV *S) {
- if (isa<SCEVUnknown>(S)) {
- FoundParameter = true;
- // Stop recursion: we found a parameter.
- return false;
- }
- // Keep looking.
- return true;
- }
- bool isDone() const {
- // Stop recursion if we have found a parameter.
- return FoundParameter;
- }
-};
-}
-
// Returns true when S contains at least a SCEVUnknown parameter.
static inline bool
containsParameters(const SCEV *S) {
+ struct FindParameter {
+ bool FoundParameter;
+ FindParameter() : FoundParameter(false) {}
+
+ bool follow(const SCEV *S) {
+ if (isa<SCEVUnknown>(S)) {
+ FoundParameter = true;
+ // Stop recursion: we found a parameter.
+ return false;
+ }
+ // Keep looking.
+ return true;
+ }
+ bool isDone() const {
+ // Stop recursion if we have found a parameter.
+ return FoundParameter;
+ }
+ };
+
FindParameter F;
SCEVTraversal<FindParameter> ST(F);
ST.visitAll(S);
ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
- // Divide all terms by the element size.
+ // Try to divide all terms by the element size. If term is not divisible by
+ // element size, proceed with the original term.
for (const SCEV *&Term : Terms) {
const SCEV *Q, *R;
SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
- Term = Q;
+ if (!Q->isZero())
+ Term = Q;
}
SmallVector<const SCEV *, 4> NewTerms;
if (Sizes.empty())
return;
- if (auto AR = dyn_cast<SCEVAddRecExpr>(Expr))
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr))
if (!AR->isAffine())
return;
LoopInfo &LI)
: F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
CouldNotCompute(new SCEVCouldNotCompute()),
- WalkingBEDominatingConds(false), ValuesAtScopes(64), LoopDispositions(64),
- BlockDispositions(64), FirstUnknown(nullptr) {}
+ WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
+ ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
+ FirstUnknown(nullptr) {}
ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
: F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI),
CouldNotCompute(std::move(Arg.CouldNotCompute)),
ValueExprMap(std::move(Arg.ValueExprMap)),
- WalkingBEDominatingConds(false),
+ WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
ConstantEvolutionLoopExitValue(
std::move(Arg.ConstantEvolutionLoopExitValue)),
UnsignedRanges(std::move(Arg.UnsignedRanges)),
SignedRanges(std::move(Arg.SignedRanges)),
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+ UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
// Free any extra memory created for ExitNotTakenInfo in the unlikely event
// that a loop had multiple computable exits.
- for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
- BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end();
- I != E; ++I) {
- I->second.clear();
- }
+ for (auto &BTCI : BackedgeTakenCounts)
+ BTCI.second.clear();
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
+ assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
}
bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
OS << "Classifying expressions for: ";
F.printAsOperand(OS, /*PrintType=*/false);
OS << "\n";
- for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
- if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
- OS << *I << '\n';
+ for (Instruction &I : instructions(F))
+ if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
+ OS << I << '\n';
OS << " --> ";
- const SCEV *SV = SE.getSCEV(&*I);
+ const SCEV *SV = SE.getSCEV(&I);
SV->print(OS);
if (!isa<SCEVCouldNotCompute>(SV)) {
OS << " U: ";
SE.getSignedRange(SV).print(OS);
}
- const Loop *L = LI.getLoopFor((*I).getParent());
+ const Loop *L = LI.getLoopFor(I.getParent());
const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
if (AtUse != SV) {
// This recurrence is variant w.r.t. L if any of its operands
// are variant.
- for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
- I != E; ++I)
- if (!isLoopInvariant(*I, L))
+ for (auto *Op : AR->operands())
+ if (!isLoopInvariant(Op, L))
return LoopVariant;
// Otherwise it's loop-invariant.
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr: {
- const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
bool HasVarying = false;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- LoopDisposition D = getLoopDisposition(*I, L);
+ for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+ LoopDisposition D = getLoopDisposition(Op, L);
if (D == LoopVariant)
return LoopVariant;
if (D == LoopComputable)
// invariant if they are not contained in the specified loop.
// Instructions are never considered invariant in the function body
// (null loop) because they are defined within the "loop".
- if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+ if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
return LoopInvariant;
case scCouldNotCompute:
case scSMaxExpr: {
const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
bool Proper = true;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- BlockDisposition D = getBlockDisposition(*I, BB);
+ for (const SCEV *NAryOp : NAry->operands()) {
+ BlockDisposition D = getBlockDisposition(NAryOp, BB);
if (D == DoesNotDominateBlock)
return DoesNotDominateBlock;
if (D == DominatesBlock)
return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
}
-namespace {
-// Search for a SCEV expression node within an expression tree.
-// Implements SCEVTraversal::Visitor.
-struct SCEVSearch {
- const SCEV *Node;
- bool IsFound;
+bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
+ // Search for a SCEV expression node within an expression tree.
+ // Implements SCEVTraversal::Visitor.
+ struct SCEVSearch {
+ const SCEV *Node;
+ bool IsFound;
- SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
+ SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
- bool follow(const SCEV *S) {
- IsFound |= (S == Node);
- return !IsFound;
- }
- bool isDone() const { return IsFound; }
-};
-}
+ bool follow(const SCEV *S) {
+ IsFound |= (S == Node);
+ return !IsFound;
+ }
+ bool isDone() const { return IsFound; }
+ };
-bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
SCEVSearch Search(Op);
visitAll(S, Search);
return Search.IsFound;
/// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis.
static void
getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
- for (Loop::reverse_iterator I = L->rbegin(), E = L->rend(); I != E; ++I) {
- getLoopBackedgeTakenCounts(*I, Map, SE); // recurse.
-
- std::string &S = Map[L];
- if (S.empty()) {
- raw_string_ostream OS(S);
- SE.getBackedgeTakenCount(L)->print(OS);
+ std::string &S = Map[L];
+ if (S.empty()) {
+ raw_string_ostream OS(S);
+ SE.getBackedgeTakenCount(L)->print(OS);
- // false and 0 are semantically equivalent. This can happen in dead loops.
- replaceSubString(OS.str(), "false", "0");
- // Remove wrap flags, their use in SCEV is highly fragile.
- // FIXME: Remove this when SCEV gets smarter about them.
- replaceSubString(OS.str(), "<nw>", "");
- replaceSubString(OS.str(), "<nsw>", "");
- replaceSubString(OS.str(), "<nuw>", "");
- }
+ // false and 0 are semantically equivalent. This can happen in dead loops.
+ replaceSubString(OS.str(), "false", "0");
+ // Remove wrap flags, their use in SCEV is highly fragile.
+ // FIXME: Remove this when SCEV gets smarter about them.
+ replaceSubString(OS.str(), "<nw>", "");
+ replaceSubString(OS.str(), "<nsw>", "");
+ replaceSubString(OS.str(), "<nuw>", "");
}
+
+ for (auto *R : reverse(*L))
+ getLoopBackedgeTakenCounts(R, Map, SE); // recurse.
}
void ScalarEvolution::verify() const {
AU.addRequiredTransitive<DominatorTreeWrapperPass>();
AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+ const SCEVConstant *RHS) {
+ FoldingSetNodeID ID;
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Equal);
+ ID.AddPointer(LHS);
+ ID.AddPointer(RHS);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ SCEVEqualPredicate *Eq = new (SCEVAllocator)
+ SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+ UniquePreds.InsertNode(Eq, IP);
+ return Eq;
+}
+
+namespace {
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+ SCEVUnionPredicate &A) {
+ SCEVPredicateRewriter Rewriter(SE, A);
+ return Rewriter.visit(Scev);
+ }
+
+ SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+ : SCEVRewriteVisitor(SE), P(P) {}
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ auto ExprPreds = P.getPredicatesForExpr(Expr);
+ for (auto *Pred : ExprPreds)
+ if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+ if (IPred->getLHS() == Expr)
+ return IPred->getRHS();
+
+ return Expr;
+ }
+
+private:
+ SCEVUnionPredicate &P;
+};
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+ SCEVUnionPredicate &Preds) {
+ return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+ SCEVPredicateKind Kind)
+ : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVUnknown *LHS,
+ const SCEVConstant *RHS)
+ : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+
+ if (!Op)
+ return false;
+
+ return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+ return all_of(Preds,
+ [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+ auto I = SCEVToPreds.find(Expr);
+ if (I == SCEVToPreds.end())
+ return ArrayRef<const SCEVPredicate *>();
+ return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+ if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+ return all_of(Set->Preds,
+ [this](const SCEVPredicate *I) { return this->implies(I); });
+
+ auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+ if (ScevPredsIt == SCEVToPreds.end())
+ return false;
+ auto &SCEVPreds = ScevPredsIt->second;
+
+ return any_of(SCEVPreds,
+ [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ for (auto Pred : Preds)
+ Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+ if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+ for (auto Pred : Set->Preds)
+ add(Pred);
+ return;
+ }
+
+ if (implies(N))
+ return;
+
+ const SCEV *Key = N->getExpr();
+ assert(Key && "Only SCEVUnionPredicate doesn't have an "
+ " associated expression!");
+
+ SCEVToPreds[Key].push_back(N);
+ Preds.push_back(N);
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
+ : SE(SE), Generation(0) {}
+
+const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
+ const SCEV *Expr = SE.getSCEV(V);
+ RewriteEntry &Entry = RewriteMap[Expr];
+
+ // If we already have an entry and the version matches, return it.
+ if (Entry.second && Generation == Entry.first)
+ return Entry.second;
+
+ // We found an entry but it's stale. Rewrite the stale entry
+ // acording to the current predicate.
+ if (Entry.second)
+ Expr = Entry.second;
+
+ const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+ Entry = {Generation, NewSCEV};
+
+ return NewSCEV;
+}
+
+void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
+ if (Preds.implies(&Pred))
+ return;
+ Preds.add(&Pred);
+ updateGeneration();
+}
+
+const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
+ return Preds;
+}
+
+void PredicatedScalarEvolution::updateGeneration() {
+ // If the generation number wrapped recompute everything.
+ if (++Generation == 0) {
+ for (auto &II : RewriteMap) {
+ const SCEV *Rewritten = II.second.second;
+ II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+ }
+ }
+}