}
}
+namespace {
class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
public:
static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
const Loop *L;
bool Valid;
};
+} // end anonymous namespace
const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
const Loop *L = LI.getLoopFor(PN->getParent());
if (PN->getNumIncomingValues() == 2) {
const Loop *L = LI.getLoopFor(PN->getParent());
+ // We don't want to break LCSSA, even in a SCEV expression tree.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
+ return nullptr;
+
// Try to match
//
// br %cond, label %left, label %right
TLI);
}
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+ Constant *IncomingVal = nullptr;
+
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+ if (PN->getIncomingBlock(i) == BB)
+ continue;
+
+ auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+ if (!CurrentVal)
+ return nullptr;
+
+ if (IncomingVal != CurrentVal) {
+ if (IncomingVal)
+ return nullptr;
+ IncomingVal = CurrentVal;
+ }
+ }
+
+ return IncomingVal;
+}
+
/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
if (!Latch)
return nullptr;
- // Since the loop has one latch, the PHI node must have two entries. One
- // entry must be a constant (coming in from outside of the loop), and the
- // second must be derived from the same PHI.
-
- BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
- ? PN->getIncomingBlock(1)
- : PN->getIncomingBlock(0);
-
- assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!");
-
- // Note: not all PHI nodes in the same block have to have their incoming
- // values in the same order, so we use the basic block to look up the incoming
- // value, not an index.
-
for (auto &I : *Header) {
PHINode *PHI = dyn_cast<PHINode>(&I);
if (!PHI) break;
- auto *StartCST =
- dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+ auto *StartCST = getOtherIncomingValue(PHI, Latch);
if (!StartCST) continue;
CurrentIterVals[PHI] = StartCST;
}
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Should follow from NumIncomingValues == 2!");
- // NonLatch is the preheader, or something equivalent.
- BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
- ? PN->getIncomingBlock(1)
- : PN->getIncomingBlock(0);
-
- // Note: not all PHI nodes in the same block have to have their incoming
- // values in the same order, so we use the basic block to look up the incoming
- // value, not an index.
-
for (auto &I : *Header) {
PHINode *PHI = dyn_cast<PHINode>(&I);
if (!PHI)
break;
- auto *StartCST =
- dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+ auto *StartCST = getOtherIncomingValue(PHI, Latch);
if (!StartCST) continue;
CurrentIterVals[PHI] = StartCST;
}
/// In the case that a relevant loop exit value cannot be computed, the
/// original value V is returned.
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
+ SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
+ ValuesAtScopes[V];
// Check to see if we've folded this expression at this loop before.
- SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V];
- for (unsigned u = 0; u < Values.size(); u++) {
- if (Values[u].first == L)
- return Values[u].second ? Values[u].second : V;
- }
- Values.push_back(std::make_pair(L, static_cast<const SCEV *>(nullptr)));
+ for (auto &LS : Values)
+ if (LS.first == L)
+ return LS.second ? LS.second : V;
+
+ Values.emplace_back(L, nullptr);
+
// Otherwise compute it.
const SCEV *C = computeSCEVAtScope(V, L);
- SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values2 = ValuesAtScopes[V];
- for (unsigned u = Values2.size(); u > 0; u--) {
- if (Values2[u - 1].first == L) {
- Values2[u - 1].second = C;
+ for (auto &LS : reverse(ValuesAtScopes[V]))
+ if (LS.first == L) {
+ LS.second = C;
break;
}
- }
return C;
}
Pred = ICmpInst::ICMP_ULT;
Changed = true;
} else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
- LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
- SCEV::FlagNUW);
+ LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
Pred = ICmpInst::ICMP_ULT;
Changed = true;
}
break;
case ICmpInst::ICMP_UGE:
if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
- RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
- SCEV::FlagNUW);
+ RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
Pred = ICmpInst::ICMP_UGT;
Changed = true;
} else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
!C.isStrictlyPositive())
return true;
+ break;
case ICmpInst::ICMP_SGT:
std::swap(LHS, RHS);
// (X + C)<nsw> s< X if C < 0
if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
return true;
+ break;
}
return false;
// expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
// interesting cases seen in practice. We can consider "upgrading" L >= 0 to
// use isKnownPredicate later if needed.
- if (isKnownNonNegative(RHS) &&
- isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
- isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS))
- return true;
-
- return false;
+ return isKnownNonNegative(RHS) &&
+ isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+ isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
FlagAnyWrap);
// Next, solve the constructed addrec
- std::pair<const SCEV *,const SCEV *> Roots =
- SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
+ auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
if (R1) {
// Pick the smallest positive root value.
- if (ConstantInt *CB =
- dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
- R1->getValue(), R2->getValue()))) {
+ if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+ ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
if (!CB->getZExtValue())
std::swap(R1, R2); // R1 is the minimum root now.
UnsignedRanges(std::move(Arg.UnsignedRanges)),
SignedRanges(std::move(Arg.SignedRanges)),
UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+ UniquePreds(std::move(Arg.UniquePreds)),
SCEVAllocator(std::move(Arg.SCEVAllocator)),
FirstUnknown(Arg.FirstUnknown) {
Arg.FirstUnknown = nullptr;
// This recurrence is variant w.r.t. L if any of its operands
// are variant.
- for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
- I != E; ++I)
- if (!isLoopInvariant(*I, L))
+ for (auto *Op : AR->operands())
+ if (!isLoopInvariant(Op, L))
return LoopVariant;
// Otherwise it's loop-invariant.
case scMulExpr:
case scUMaxExpr:
case scSMaxExpr: {
- const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
bool HasVarying = false;
- for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
- I != E; ++I) {
- LoopDisposition D = getLoopDisposition(*I, L);
+ for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+ LoopDisposition D = getLoopDisposition(Op, L);
if (D == LoopVariant)
return LoopVariant;
if (D == LoopComputable)
// invariant if they are not contained in the specified loop.
// Instructions are never considered invariant in the function body
// (null loop) because they are defined within the "loop".
- if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+ if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
return LoopInvariant;
case scCouldNotCompute:
AU.addRequiredTransitive<DominatorTreeWrapperPass>();
AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
}
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+ const SCEVConstant *RHS) {
+ FoldingSetNodeID ID;
+ // Unique this node based on the arguments
+ ID.AddInteger(SCEVPredicate::P_Equal);
+ ID.AddPointer(LHS);
+ ID.AddPointer(RHS);
+ void *IP = nullptr;
+ if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+ return S;
+ SCEVEqualPredicate *Eq = new (SCEVAllocator)
+ SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+ UniquePreds.InsertNode(Eq, IP);
+ return Eq;
+}
+
+namespace {
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+ SCEVUnionPredicate &A) {
+ SCEVPredicateRewriter Rewriter(SE, A);
+ return Rewriter.visit(Scev);
+ }
+
+ SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+ : SCEVRewriteVisitor(SE), P(P) {}
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ auto ExprPreds = P.getPredicatesForExpr(Expr);
+ for (auto *Pred : ExprPreds)
+ if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+ if (IPred->getLHS() == Expr)
+ return IPred->getRHS();
+
+ return Expr;
+ }
+
+private:
+ SCEVUnionPredicate &P;
+};
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+ SCEVUnionPredicate &Preds) {
+ return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+ SCEVPredicateKind Kind)
+ : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+ const SCEVUnknown *LHS,
+ const SCEVConstant *RHS)
+ : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+ const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+
+ if (!Op)
+ return false;
+
+ return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+ : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+ return std::all_of(Preds.begin(), Preds.end(),
+ [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+ auto I = SCEVToPreds.find(Expr);
+ if (I == SCEVToPreds.end())
+ return ArrayRef<const SCEVPredicate *>();
+ return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+ if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+ return std::all_of(
+ Set->Preds.begin(), Set->Preds.end(),
+ [this](const SCEVPredicate *I) { return this->implies(I); });
+
+ auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+ if (ScevPredsIt == SCEVToPreds.end())
+ return false;
+ auto &SCEVPreds = ScevPredsIt->second;
+
+ return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
+ [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+ for (auto Pred : Preds)
+ Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+ if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+ for (auto Pred : Set->Preds)
+ add(Pred);
+ return;
+ }
+
+ if (implies(N))
+ return;
+
+ const SCEV *Key = N->getExpr();
+ assert(Key && "Only SCEVUnionPredicate doesn't have an "
+ " associated expression!");
+
+ SCEVToPreds[Key].push_back(N);
+ Preds.push_back(N);
+}