-//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
+//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// The LLVM Compiler Infrastructure
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Target/TargetLibraryInfo.h"
#include <algorithm>
using namespace llvm;
INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
"Scalar Evolution Analysis", false, true)
-INITIALIZE_PASS_DEPENDENCY(LoopInfo)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
"Scalar Evolution Analysis", false, true)
char ScalarEvolution::ID = 0;
}
}
+namespace {
+struct FindSCEVSize {
+ int Size;
+ FindSCEVSize() : Size(0) {}
+
+ bool follow(const SCEV *S) {
+ ++Size;
+ // Keep looking at all operands of S.
+ return true;
+ }
+ bool isDone() const {
+ return false;
+ }
+};
+}
+
+// Returns the size of the SCEV S.
+static inline int sizeOfSCEV(const SCEV *S) {
+ FindSCEVSize F;
+ SCEVTraversal<FindSCEVSize> ST(F);
+ ST.visitAll(S);
+ return F.Size;
+}
+
+namespace {
+
+struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
+public:
+ // Computes the Quotient and Remainder of the division of Numerator by
+ // Denominator.
+ static void divide(ScalarEvolution &SE, const SCEV *Numerator,
+ const SCEV *Denominator, const SCEV **Quotient,
+ const SCEV **Remainder) {
+ assert(Numerator && Denominator && "Uninitialized SCEV");
+
+ SCEVDivision D(SE, Numerator, Denominator);
+
+ // Check for the trivial case here to avoid having to check for it in the
+ // rest of the code.
+ if (Numerator == Denominator) {
+ *Quotient = D.One;
+ *Remainder = D.Zero;
+ return;
+ }
+
+ if (Numerator->isZero()) {
+ *Quotient = D.Zero;
+ *Remainder = D.Zero;
+ return;
+ }
+
+ // Split the Denominator when it is a product.
+ if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
+ const SCEV *Q, *R;
+ *Quotient = Numerator;
+ for (const SCEV *Op : T->operands()) {
+ divide(SE, *Quotient, Op, &Q, &R);
+ *Quotient = Q;
+
+ // Bail out when the Numerator is not divisible by one of the terms of
+ // the Denominator.
+ if (!R->isZero()) {
+ *Quotient = D.Zero;
+ *Remainder = Numerator;
+ return;
+ }
+ }
+ *Remainder = D.Zero;
+ return;
+ }
+
+ D.visit(Numerator);
+ *Quotient = D.Quotient;
+ *Remainder = D.Remainder;
+ }
+
+ // Except in the trivial case described above, we do not know how to divide
+ // Expr by Denominator for the following functions with empty implementation.
+ void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
+ void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
+ void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
+ void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
+ void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
+ void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
+ void visitUnknown(const SCEVUnknown *Numerator) {}
+ void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
+
+ void visitConstant(const SCEVConstant *Numerator) {
+ if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+ APInt NumeratorVal = Numerator->getValue()->getValue();
+ APInt DenominatorVal = D->getValue()->getValue();
+ uint32_t NumeratorBW = NumeratorVal.getBitWidth();
+ uint32_t DenominatorBW = DenominatorVal.getBitWidth();
+
+ if (NumeratorBW > DenominatorBW)
+ DenominatorVal = DenominatorVal.sext(NumeratorBW);
+ else if (NumeratorBW < DenominatorBW)
+ NumeratorVal = NumeratorVal.sext(DenominatorBW);
+
+ APInt QuotientVal(NumeratorVal.getBitWidth(), 0);
+ APInt RemainderVal(NumeratorVal.getBitWidth(), 0);
+ APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal);
+ Quotient = SE.getConstant(QuotientVal);
+ Remainder = SE.getConstant(RemainderVal);
+ return;
+ }
+ }
+
+ void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
+ const SCEV *StartQ, *StartR, *StepQ, *StepR;
+ assert(Numerator->isAffine() && "Numerator should be affine");
+ divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
+ divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
+ Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
+ Numerator->getNoWrapFlags());
+ Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
+ Numerator->getNoWrapFlags());
+ }
+
+ void visitAddExpr(const SCEVAddExpr *Numerator) {
+ SmallVector<const SCEV *, 2> Qs, Rs;
+ Type *Ty = Denominator->getType();
+
+ for (const SCEV *Op : Numerator->operands()) {
+ const SCEV *Q, *R;
+ divide(SE, Op, Denominator, &Q, &R);
+
+ // Bail out if types do not match.
+ if (Ty != Q->getType() || Ty != R->getType()) {
+ Quotient = Zero;
+ Remainder = Numerator;
+ return;
+ }
+
+ Qs.push_back(Q);
+ Rs.push_back(R);
+ }
+
+ if (Qs.size() == 1) {
+ Quotient = Qs[0];
+ Remainder = Rs[0];
+ return;
+ }
+ Quotient = SE.getAddExpr(Qs);
+ Remainder = SE.getAddExpr(Rs);
+ }
+
+ void visitMulExpr(const SCEVMulExpr *Numerator) {
+ SmallVector<const SCEV *, 2> Qs;
+ Type *Ty = Denominator->getType();
+
+ bool FoundDenominatorTerm = false;
+ for (const SCEV *Op : Numerator->operands()) {
+ // Bail out if types do not match.
+ if (Ty != Op->getType()) {
+ Quotient = Zero;
+ Remainder = Numerator;
+ return;
+ }
+
+ if (FoundDenominatorTerm) {
+ Qs.push_back(Op);
+ continue;
+ }
+
+ // Check whether Denominator divides one of the product operands.
+ const SCEV *Q, *R;
+ divide(SE, Op, Denominator, &Q, &R);
+ if (!R->isZero()) {
+ Qs.push_back(Op);
+ continue;
+ }
+
+ // Bail out if types do not match.
+ if (Ty != Q->getType()) {
+ Quotient = Zero;
+ Remainder = Numerator;
+ return;
+ }
+
+ FoundDenominatorTerm = true;
+ Qs.push_back(Q);
+ }
+
+ if (FoundDenominatorTerm) {
+ Remainder = Zero;
+ if (Qs.size() == 1)
+ Quotient = Qs[0];
+ else
+ Quotient = SE.getMulExpr(Qs);
+ return;
+ }
+
+ if (!isa<SCEVUnknown>(Denominator)) {
+ Quotient = Zero;
+ Remainder = Numerator;
+ return;
+ }
+
+ // The Remainder is obtained by replacing Denominator by 0 in Numerator.
+ ValueToValueMap RewriteMap;
+ RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+ cast<SCEVConstant>(Zero)->getValue();
+ Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+
+ if (Remainder->isZero()) {
+ // The Quotient is obtained by replacing Denominator by 1 in Numerator.
+ RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
+ cast<SCEVConstant>(One)->getValue();
+ Quotient =
+ SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
+ return;
+ }
+
+ // Quotient is (Numerator - Remainder) divided by Denominator.
+ const SCEV *Q, *R;
+ const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
+ if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) {
+ // This SCEV does not seem to simplify: fail the division here.
+ Quotient = Zero;
+ Remainder = Numerator;
+ return;
+ }
+ divide(SE, Diff, Denominator, &Q, &R);
+ assert(R == Zero &&
+ "(Numerator - Remainder) should evenly divide Denominator");
+ Quotient = Q;
+ }
+
+private:
+ SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
+ const SCEV *Denominator)
+ : SE(S), Denominator(Denominator) {
+ Zero = SE.getConstant(Denominator->getType(), 0);
+ One = SE.getConstant(Denominator->getType(), 1);
+
+ // By default, we don't know how to divide Expr by Denominator.
+ // Providing the default here simplifies the rest of the code.
+ Quotient = Zero;
+ Remainder = Numerator;
+ }
+
+ ScalarEvolution &SE;
+ const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
+};
+
+}
//===----------------------------------------------------------------------===//
// Simple SCEV method implementations
const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
- if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
+ // WARNING: FIXME: the optimization below assumes that a sign-overflowing nsw
+ // operation is undefined behavior. This is strictly more aggressive than the
+ // interpretation of nsw in other parts of LLVM (for instance, they may
+ // unconditionally hoist nsw arithmetic through control flow). This logic
+ // needs to be revisited once we have a consistent semantics for poison
+ // values.
+ //
+ // "{S,+,X} is <nsw>" and "{S,+,X} is evaluated at least once" implies "S+X
+ // does not sign-overflow" (we'd have undefined behavior if it did). If
+ // `L->getExitingBlock() == L->getLoopLatch()` then `PreAR` (= {S,+,X}<nsw>)
+ // is evaluated every-time `AR` (= {S+X,+,X}) is evaluated, and hence within
+ // `AR` we are safe to assume that "S+X" will not sign-overflow.
+ //
+
+ BasicBlock *ExitingBlock = L->getExitingBlock();
+ BasicBlock *LatchBlock = L->getLoopLatch();
+ if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW) &&
+ ExitingBlock != nullptr && ExitingBlock == LatchBlock)
return PreStart;
// 2. Direct overflow check on the step operation's expression.
getMulExpr(WideMaxBECount,
getZeroExtendExpr(Step, WideTy)));
if (SAdd == OperandExtendedAdd) {
- // Cache knowledge of AR NSW, which is propagated to this AddRec.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ // If AR wraps around then
+ //
+ // abs(Step) * MaxBECount > unsigned-max(AR->getType())
+ // => SAdd != OperandExtendedAdd
+ //
+ // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
+ // (SAdd == OperandExtendedAdd => AR is NW)
+
+ const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+
// Return the expression with the addrec on the outside.
return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
getZeroExtendExpr(Step, Ty),
};
}
+// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
+// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
+// can't-overflow flags for the operation if possible.
+static SCEV::NoWrapFlags
+StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
+ const SmallVectorImpl<const SCEV *> &Ops,
+ SCEV::NoWrapFlags OldFlags) {
+ using namespace std::placeholders;
+
+ bool CanAnalyze =
+ Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
+ (void)CanAnalyze;
+ assert(CanAnalyze && "don't call from other places!");
+
+ int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
+ SCEV::NoWrapFlags SignOrUnsignWrap =
+ ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
+
+ // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
+ auto IsKnownNonNegative =
+ std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+
+ if (SignOrUnsignWrap == SCEV::FlagNSW &&
+ std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
+ return ScalarEvolution::setFlags(OldFlags,
+ (SCEV::NoWrapFlags)SignOrUnsignMask);
+
+ return OldFlags;
+}
+
/// getAddExpr - Get a canonical add expression, or something simpler if
/// possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
"SCEVAddExpr operand types don't match!");
#endif
- // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
- // And vice-versa.
- int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
- SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
- if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
- bool All = true;
- for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
- E = Ops.end(); I != E; ++I)
- if (!isKnownNonNegative(*I)) {
- All = false;
- break;
- }
- if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
- }
+ Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, LI);
return r;
}
-/// getMulExpr - Get a canonical multiply expression, or something simpler if
-/// possible.
-const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
- SCEV::NoWrapFlags Flags) {
+/// Determine if any of the operands in this SCEV are a constant or if
+/// any of the add or multiply expressions in this SCEV contain a constant.
+static bool containsConstantSomewhere(const SCEV *StartExpr) {
+ SmallVector<const SCEV *, 4> Ops;
+ Ops.push_back(StartExpr);
+ while (!Ops.empty()) {
+ const SCEV *CurrentExpr = Ops.pop_back_val();
+ if (isa<SCEVConstant>(*CurrentExpr))
+ return true;
+
+ if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
+ const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
+ for (const SCEV *Operand : CurrentNAry->operands())
+ Ops.push_back(Operand);
+ }
+ }
+ return false;
+}
+
+/// getMulExpr - Get a canonical multiply expression, or something simpler if
+/// possible.
+const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
+ SCEV::NoWrapFlags Flags) {
assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
"only nuw or nsw allowed");
assert(!Ops.empty() && "Cannot get empty mul!");
"SCEVMulExpr operand types don't match!");
#endif
- // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
- // And vice-versa.
- int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
- SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
- if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
- bool All = true;
- for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
- E = Ops.end(); I != E; ++I)
- if (!isKnownNonNegative(*I)) {
- All = false;
- break;
- }
- if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
- }
+ Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, LI);
// C1*(C2+V) -> C1*C2 + C1*V
if (Ops.size() == 2)
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
- if (Add->getNumOperands() == 2 &&
- isa<SCEVConstant>(Add->getOperand(0)))
- return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
- getMulExpr(LHSC, Add->getOperand(1)));
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
+ // If any of Add's ops are Adds or Muls with a constant,
+ // apply this transformation as well.
+ if (Add->getNumOperands() == 2)
+ if (containsConstantSomewhere(Add))
+ return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
+ getMulExpr(LHSC, Add->getOperand(1)));
++Idx;
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
// Okay, if there weren't any loop invariants to be folded, check to see if
// there are multiple AddRec's with the same loop induction variable being
// multiplied together. If so, we can fold them.
+
+ // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
+ // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
+ // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
+ // ]]],+,...up to x=2n}.
+ // Note that the arguments to choose() are always integers with values
+ // known at compile time, never SCEV objects.
+ //
+ // The implementation avoids pointless extra computations when the two
+ // addrec's are of different length (mathematically, it's equivalent to
+ // an infinite stream of zeros on the right).
+ bool OpsModified = false;
for (unsigned OtherIdx = Idx+1;
- OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
+ OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx) {
- if (AddRecLoop != cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop())
+ const SCEVAddRecExpr *OtherAddRec =
+ dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
+ if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
continue;
- // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
- // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
- // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
- // ]]],+,...up to x=2n}.
- // Note that the arguments to choose() are always integers with values
- // known at compile time, never SCEV objects.
- //
- // The implementation avoids pointless extra computations when the two
- // addrec's are of different length (mathematically, it's equivalent to
- // an infinite stream of zeros on the right).
- bool OpsModified = false;
- for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
- ++OtherIdx) {
- const SCEVAddRecExpr *OtherAddRec =
- dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
- if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
- continue;
-
- bool Overflow = false;
- Type *Ty = AddRec->getType();
- bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
- SmallVector<const SCEV*, 7> AddRecOps;
- for (int x = 0, xe = AddRec->getNumOperands() +
- OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
- const SCEV *Term = getConstant(Ty, 0);
- for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
- uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
- for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
- ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
- z < ze && !Overflow; ++z) {
- uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
- uint64_t Coeff;
- if (LargerThan64Bits)
- Coeff = umul_ov(Coeff1, Coeff2, Overflow);
- else
- Coeff = Coeff1*Coeff2;
- const SCEV *CoeffTerm = getConstant(Ty, Coeff);
- const SCEV *Term1 = AddRec->getOperand(y-z);
- const SCEV *Term2 = OtherAddRec->getOperand(z);
- Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
- }
+ bool Overflow = false;
+ Type *Ty = AddRec->getType();
+ bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
+ SmallVector<const SCEV*, 7> AddRecOps;
+ for (int x = 0, xe = AddRec->getNumOperands() +
+ OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
+ const SCEV *Term = getConstant(Ty, 0);
+ for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
+ uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
+ for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
+ ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
+ z < ze && !Overflow; ++z) {
+ uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
+ uint64_t Coeff;
+ if (LargerThan64Bits)
+ Coeff = umul_ov(Coeff1, Coeff2, Overflow);
+ else
+ Coeff = Coeff1*Coeff2;
+ const SCEV *CoeffTerm = getConstant(Ty, Coeff);
+ const SCEV *Term1 = AddRec->getOperand(y-z);
+ const SCEV *Term2 = OtherAddRec->getOperand(z);
+ Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
}
- AddRecOps.push_back(Term);
- }
- if (!Overflow) {
- const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
- SCEV::FlagAnyWrap);
- if (Ops.size() == 2) return NewAddRec;
- Ops[Idx] = NewAddRec;
- Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
- OpsModified = true;
- AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
- if (!AddRec)
- break;
}
+ AddRecOps.push_back(Term);
+ }
+ if (!Overflow) {
+ const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
+ SCEV::FlagAnyWrap);
+ if (Ops.size() == 2) return NewAddRec;
+ Ops[Idx] = NewAddRec;
+ Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
+ OpsModified = true;
+ AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
+ if (!AddRec)
+ break;
}
- if (OpsModified)
- return getMulExpr(Ops);
}
+ if (OpsModified)
+ return getMulExpr(Ops);
// Otherwise couldn't fold anything into this recurrence. Move onto the
// next one.
// meaningful BE count at this point (and if we don't, we'd be stuck
// with a SCEVCouldNotCompute as the cached BE count).
- // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
- // And vice-versa.
- int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
- SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
- if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
- bool All = true;
- for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
- E = Operands.end(); I != E; ++I)
- if (!isKnownNonNegative(*I)) {
- All = false;
- break;
- }
- if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
- }
+ Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
// Canonicalize nested AddRecs in by nesting them in order of loop depth.
if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
if (LHS == RHS)
return getConstant(LHS->getType(), 0);
- // X - Y --> X + -Y
- return getAddExpr(LHS, getNegativeSCEV(RHS), Flags);
+ // X - Y --> X + -Y.
+ // X -(nsw || nuw) Y --> X + -Y.
+ return getAddExpr(LHS, getNegativeSCEV(RHS));
}
/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
Visited.insert(PN);
while (!Worklist.empty()) {
Instruction *I = Worklist.pop_back_val();
- if (!Visited.insert(I)) continue;
+ if (!Visited.insert(I).second)
+ continue;
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr)))
Flags = setFlags(Flags, SCEV::FlagNUW);
}
- } else if (const SubOperator *OBO =
- dyn_cast<SubOperator>(BEValueV)) {
- if (OBO->hasNoUnsignedWrap())
- Flags = setFlags(Flags, SCEV::FlagNUW);
- if (OBO->hasNoSignedWrap())
- Flags = setFlags(Flags, SCEV::FlagNSW);
+
+ // We cannot transfer nuw and nsw flags from subtraction
+ // operations -- sub nuw X, Y is not the same as add nuw X, -Y
+ // for instance.
}
const SCEV *StartVal = getSCEV(StartValueV);
// PHI's incoming blocks are in a different loop, in which case doing so
// risks breaking LCSSA form. Instcombine would normally zap these, but
// it doesn't have DominatorTree information, so it may miss cases.
- if (Value *V = SimplifyInstruction(PN, DL, TLI, DT))
+ if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC))
if (LI->replacementPreservesLCSSAForm(PN, V))
return getSCEV(V);
// For a SCEVUnknown, ask ValueTracking.
unsigned BitWidth = getTypeSizeInBits(U->getType());
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
- computeKnownBits(U->getValue(), Zeros, Ones);
+ computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
return Zeros.countTrailingOnes();
}
return 0;
}
+/// GetRangeFromMetadata - Helper method to assign a range to V from
+/// metadata present in the IR.
+static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
+ ConstantRange TotalRange(
+ cast<IntegerType>(I->getType())->getBitWidth(), false);
+
+ unsigned NumRanges = MD->getNumOperands() / 2;
+ assert(NumRanges >= 1);
+
+ for (unsigned i = 0; i < NumRanges; ++i) {
+ ConstantInt *Lower =
+ mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
+ ConstantInt *Upper =
+ mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
+ ConstantRange Range(Lower->getValue(), Upper->getValue());
+ TotalRange = TotalRange.unionWith(Range);
+ }
+
+ return TotalRange;
+ }
+ }
+
+ return None;
+}
+
/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
///
ConstantRange
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+ // Check if the IR explicitly contains !range metadata.
+ Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+ if (MDRange.hasValue())
+ ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
+
// For a SCEVUnknown, ask ValueTracking.
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
- computeKnownBits(U->getValue(), Zeros, Ones, DL);
+ computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
if (Ones == ~Zeros + 1)
return setUnsignedRange(U, ConservativeResult);
return setUnsignedRange(U,
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+ // Check if the IR explicitly contains !range metadata.
+ Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+ if (MDRange.hasValue())
+ ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
+
// For a SCEVUnknown, ask ValueTracking.
if (!U->getValue()->getType()->isIntegerTy() && !DL)
return setSignedRange(U, ConservativeResult);
- unsigned NS = ComputeNumSignBits(U->getValue(), DL);
+ unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT);
if (NS <= 1)
return setSignedRange(U, ConservativeResult);
return setSignedRange(U, ConservativeResult.intersectWith(
unsigned TZ = A.countTrailingZeros();
unsigned BitWidth = A.getBitWidth();
APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
- computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL);
+ computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC,
+ nullptr, DT);
APInt EffectiveMask =
APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
case ICmpInst::ICMP_SGE:
// a >s b ? a+x : b+x -> smax(a, b)+x
// a >s b ? b+x : a+x -> smin(a, b)+x
- if (LHS->getType() == U->getType()) {
- const SCEV *LS = getSCEV(LHS);
- const SCEV *RS = getSCEV(RHS);
+ if (getTypeSizeInBits(LHS->getType()) <=
+ getTypeSizeInBits(U->getType())) {
+ const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), U->getType());
+ const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), U->getType());
const SCEV *LA = getSCEV(U->getOperand(1));
const SCEV *RA = getSCEV(U->getOperand(2));
const SCEV *LDiff = getMinusSCEV(LA, LS);
case ICmpInst::ICMP_UGE:
// a >u b ? a+x : b+x -> umax(a, b)+x
// a >u b ? b+x : a+x -> umin(a, b)+x
- if (LHS->getType() == U->getType()) {
- const SCEV *LS = getSCEV(LHS);
- const SCEV *RS = getSCEV(RHS);
+ if (getTypeSizeInBits(LHS->getType()) <=
+ getTypeSizeInBits(U->getType())) {
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
+ const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), U->getType());
const SCEV *LA = getSCEV(U->getOperand(1));
const SCEV *RA = getSCEV(U->getOperand(2));
const SCEV *LDiff = getMinusSCEV(LA, LS);
break;
case ICmpInst::ICMP_NE:
// n != 0 ? n+x : 1+x -> umax(n, 1)+x
- if (LHS->getType() == U->getType() &&
- isa<ConstantInt>(RHS) &&
- cast<ConstantInt>(RHS)->isZero()) {
- const SCEV *One = getConstant(LHS->getType(), 1);
- const SCEV *LS = getSCEV(LHS);
+ if (getTypeSizeInBits(LHS->getType()) <=
+ getTypeSizeInBits(U->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getConstant(U->getType(), 1);
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
const SCEV *LA = getSCEV(U->getOperand(1));
const SCEV *RA = getSCEV(U->getOperand(2));
const SCEV *LDiff = getMinusSCEV(LA, LS);
break;
case ICmpInst::ICMP_EQ:
// n == 0 ? 1+x : n+x -> umax(n, 1)+x
- if (LHS->getType() == U->getType() &&
- isa<ConstantInt>(RHS) &&
- cast<ConstantInt>(RHS)->isZero()) {
- const SCEV *One = getConstant(LHS->getType(), 1);
- const SCEV *LS = getSCEV(LHS);
+ if (getTypeSizeInBits(LHS->getType()) <=
+ getTypeSizeInBits(U->getType()) &&
+ isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
+ const SCEV *One = getConstant(U->getType(), 1);
+ const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), U->getType());
const SCEV *LA = getSCEV(U->getOperand(1));
const SCEV *RA = getSCEV(U->getOperand(2));
const SCEV *LDiff = getMinusSCEV(LA, One);
// Iteration Count Computation Code
//
+unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
+ if (BasicBlock *ExitingBB = L->getExitingBlock())
+ return getSmallConstantTripCount(L, ExitingBB);
+
+ // No trip count information for multiple exits.
+ return 0;
+}
+
/// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
/// normal unsigned value. Returns 0 if the trip count is unknown or not
/// constant. Will also return 0 if the maximum trip count is very large (>=
/// before taking the branch. For loops with multiple exits, it may not be the
/// number times that the loop header executes because the loop may exit
/// prematurely via another branch.
-///
-/// FIXME: We conservatively call getBackedgeTakenCount(L) instead of
-/// getExitCount(L, ExitingBlock) to compute a safe trip count considering all
-/// loop exits. getExitCount() may return an exact count for this branch
-/// assuming no-signed-wrap. The number of well-defined iterations may actually
-/// be higher than this trip count if this exit test is skipped and the loop
-/// exits via a different branch. Ideally, getExitCount() would know whether it
-/// depends on a NSW assumption, and we would only fall back to a conservative
-/// trip count in that case.
-unsigned ScalarEvolution::
-getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) {
+unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
+ BasicBlock *ExitingBlock) {
+ assert(ExitingBlock && "Must pass a non-null exiting block!");
+ assert(L->isLoopExiting(ExitingBlock) &&
+ "Exiting block must actually branch out of the loop!");
const SCEVConstant *ExitCount =
- dyn_cast<SCEVConstant>(getBackedgeTakenCount(L));
+ dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
if (!ExitCount)
return 0;
return ((unsigned)ExitConst->getZExtValue()) + 1;
}
+unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) {
+ if (BasicBlock *ExitingBB = L->getExitingBlock())
+ return getSmallConstantTripMultiple(L, ExitingBB);
+
+ // No trip multiple information for multiple exits.
+ return 0;
+}
+
/// getSmallConstantTripMultiple - Returns the largest constant divisor of the
/// trip count of this loop as a normal unsigned value, if possible. This
/// means that the actual trip count is always a multiple of the returned
///
/// As explained in the comments for getSmallConstantTripCount, this assumes
/// that control exits the loop via ExitingBlock.
-unsigned ScalarEvolution::
-getSmallConstantTripMultiple(Loop *L, BasicBlock * /*ExitingBlock*/) {
- const SCEV *ExitCount = getBackedgeTakenCount(L);
+unsigned
+ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
+ BasicBlock *ExitingBlock) {
+ assert(ExitingBlock && "Must pass a non-null exiting block!");
+ assert(L->isLoopExiting(ExitingBlock) &&
+ "Exiting block must actually branch out of the loop!");
+ const SCEV *ExitCount = getExitCount(L, ExitingBlock);
if (ExitCount == getCouldNotCompute())
return 1;
SmallPtrSet<Instruction *, 8> Visited;
while (!Worklist.empty()) {
Instruction *I = Worklist.pop_back_val();
- if (!Visited.insert(I)) continue;
+ if (!Visited.insert(I).second)
+ continue;
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
SmallPtrSet<Instruction *, 8> Visited;
while (!Worklist.empty()) {
Instruction *I = Worklist.pop_back_val();
- if (!Visited.insert(I)) continue;
+ if (!Visited.insert(I).second)
+ continue;
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
SmallPtrSet<Instruction *, 8> Visited;
while (!Worklist.empty()) {
I = Worklist.pop_back_val();
- if (!Visited.insert(I)) continue;
+ if (!Visited.insert(I).second)
+ continue;
ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
// non-exiting iterations. Partition the loop exits into two kinds:
// LoopMustExits and LoopMayExits.
//
- // A LoopMustExit meets two requirements:
- //
- // (a) Its ExitLimit.MustExit flag must be set which indicates that the exit
- // test condition cannot be skipped (the tested variable has unit stride or
- // the test is less-than or greater-than, rather than a strict inequality).
- //
- // (b) It must dominate the loop latch, hence must be tested on every loop
- // iteration.
- //
- // If any computable LoopMustExit is found, then MaxBECount is the minimum
- // EL.Max of computable LoopMustExits. Otherwise, MaxBECount is
- // conservatively the maximum EL.Max, where CouldNotCompute is considered
- // greater than any computable EL.Max.
- if (EL.MustExit && EL.Max != getCouldNotCompute() && Latch &&
+ // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
+ // is a LoopMayExit. If any computable LoopMustExit is found, then
+ // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise,
+ // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is
+ // considered greater than any computable EL.Max.
+ if (EL.Max != getCouldNotCompute() && Latch &&
DT->dominates(ExitBB, Latch)) {
if (!MustExitMaxBECount)
MustExitMaxBECount = EL.Max;
return getCouldNotCompute();
}
+ bool IsOnlyExit = (L->getExitingBlock() != nullptr);
TerminatorInst *Term = ExitingBlock->getTerminator();
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
assert(BI->isConditional() && "If unconditional, it can't be in loop!");
// Proceed to the next level to examine the exit condition expression.
return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0),
BI->getSuccessor(1),
- /*IsSubExpr=*/false);
+ /*ControlsExit=*/IsOnlyExit);
}
if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit,
- /*IsSubExpr=*/false);
+ /*ControlsExit=*/IsOnlyExit);
return getCouldNotCompute();
}
/// backedge of the specified loop will execute if its exit condition
/// were a conditional branch of ExitCond, TBB, and FBB.
///
-/// @param IsSubExpr is true if ExitCond does not directly control the exit
-/// branch. In this case, we cannot assume that the loop only exits when the
-/// condition is true and cannot infer that failing to meet the condition prior
-/// to integer wraparound results in undefined behavior.
+/// @param ControlsExit is true if ExitCond directly controls the exit
+/// branch. In this case, we can assume that the loop exits only if the
+/// condition is true and can infer that failing to meet the condition prior to
+/// integer wraparound results in undefined behavior.
ScalarEvolution::ExitLimit
ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
Value *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
- bool IsSubExpr) {
+ bool ControlsExit) {
// Check if the controlling expression for this loop is an And or Or.
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
if (BO->getOpcode() == Instruction::And) {
// Recurse on the operands of the and.
bool EitherMayExit = L->contains(TBB);
ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
- IsSubExpr || EitherMayExit);
+ ControlsExit && !EitherMayExit);
ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
- IsSubExpr || EitherMayExit);
+ ControlsExit && !EitherMayExit);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
- bool MustExit = false;
if (EitherMayExit) {
// Both conditions must be true for the loop to continue executing.
// Choose the less conservative count.
MaxBECount = EL0.Max;
else
MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
- MustExit = EL0.MustExit || EL1.MustExit;
} else {
// Both conditions must be true at the same time for the loop to exit.
// For now, be conservative.
MaxBECount = EL0.Max;
if (EL0.Exact == EL1.Exact)
BECount = EL0.Exact;
- MustExit = EL0.MustExit && EL1.MustExit;
}
- return ExitLimit(BECount, MaxBECount, MustExit);
+ return ExitLimit(BECount, MaxBECount);
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
bool EitherMayExit = L->contains(FBB);
ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB,
- IsSubExpr || EitherMayExit);
+ ControlsExit && !EitherMayExit);
ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB,
- IsSubExpr || EitherMayExit);
+ ControlsExit && !EitherMayExit);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
- bool MustExit = false;
if (EitherMayExit) {
// Both conditions must be false for the loop to continue executing.
// Choose the less conservative count.
MaxBECount = EL0.Max;
else
MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
- MustExit = EL0.MustExit || EL1.MustExit;
} else {
// Both conditions must be false at the same time for the loop to exit.
// For now, be conservative.
MaxBECount = EL0.Max;
if (EL0.Exact == EL1.Exact)
BECount = EL0.Exact;
- MustExit = EL0.MustExit && EL1.MustExit;
}
- return ExitLimit(BECount, MaxBECount, MustExit);
+ return ExitLimit(BECount, MaxBECount);
}
}
// With an icmp, it may be feasible to compute an exact backedge-taken count.
// Proceed to the next level to examine the icmp.
if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
- return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, IsSubExpr);
+ return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
// Check for a constant condition. These are normally stripped out by
// SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
ICmpInst *ExitCond,
BasicBlock *TBB,
BasicBlock *FBB,
- bool IsSubExpr) {
+ bool ControlsExit) {
// If the condition was exit on true, convert the condition to exit on false
ICmpInst::Predicate Cond;
switch (Cond) {
case ICmpInst::ICMP_NE: { // while (X != Y)
// Convert to: while (X-Y != 0)
- ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr);
+ ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT: { // while (X < Y)
bool IsSigned = Cond == ICmpInst::ICMP_SLT;
- ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, IsSubExpr);
+ ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit);
if (EL.hasAnyInfo()) return EL;
break;
}
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT: { // while (X > Y)
bool IsSigned = Cond == ICmpInst::ICMP_SGT;
- ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, IsSubExpr);
+ ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit);
if (EL.hasAnyInfo()) return EL;
break;
}
ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L,
SwitchInst *Switch,
BasicBlock *ExitingBlock,
- bool IsSubExpr) {
+ bool ControlsExit) {
assert(!L->contains(ExitingBlock) && "Not an exiting block!");
// Give up if the exit is the default dest of a switch.
const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
// while (X != Y) --> while (X-Y != 0)
- ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr);
+ ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit);
if (EL.hasAnyInfo())
return EL;
/// effectively V != 0. We know and take advantage of the fact that this
/// expression only being used in a comparison by zero context.
ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) {
+ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
// If the value is a constant
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
// If the value is already zero, the branch will execute zero times.
else
MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
: -CR.getUnsignedMin());
- return ExitLimit(Distance, MaxBECount, /*MustExit=*/true);
- }
-
- // If the recurrence is known not to wraparound, unsigned divide computes the
- // back edge count. (Ideally we would have an "isexact" bit for udiv). We know
- // that the value will either become zero (and thus the loop terminates), that
- // the loop will terminate through some other exit condition first, or that
- // the loop has undefined behavior. This means we can't "miss" the exit
- // value, even with nonunit stride, and exit later via the same branch. Note
- // that we can skip this exit if loop later exits via a different
- // branch. Hence MustExit=false.
- //
- // This is only valid for expressions that directly compute the loop exit. It
- // is invalid for subexpressions in which the loop may exit through this
- // branch even if this subexpression is false. In that case, the trip count
- // computed by this udiv could be smaller than the number of well-defined
- // iterations.
- if (!IsSubExpr && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
+ return ExitLimit(Distance, MaxBECount);
+ }
+
+ // As a special case, handle the instance where Step is a positive power of
+ // two. In this case, determining whether Step divides Distance evenly can be
+ // done by counting and comparing the number of trailing zeros of Step and
+ // Distance.
+ if (!CountDown) {
+ const APInt &StepV = StepC->getValue()->getValue();
+ // StepV.isPowerOf2() returns true if StepV is an positive power of two. It
+ // also returns true if StepV is maximally negative (eg, INT_MIN), but that
+ // case is not handled as this code is guarded by !CountDown.
+ if (StepV.isPowerOf2() &&
+ GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros())
+ return getUDivExactExpr(Distance, Step);
+ }
+
+ // If the condition controls loop exit (the loop exits only if the expression
+ // is true) and the addition is no-wrap we can use unsigned divide to
+ // compute the backedge count. In this case, the step may not divide the
+ // distance, but we don't care because if the condition is "missed" the loop
+ // will have undefined behavior due to wrapping.
+ if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
const SCEV *Exact =
- getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
- return ExitLimit(Exact, Exact, /*MustExit=*/false);
+ getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
+ return ExitLimit(Exact, Exact);
}
- // If Step is a power of two that evenly divides Start we know that the loop
- // will always terminate. Start may not be a constant so we just have the
- // number of trailing zeros available. This is safe even in presence of
- // overflow as the recurrence will overflow to exactly 0.
- const APInt &StepV = StepC->getValue()->getValue();
- if (StepV.isPowerOf2() &&
- GetMinTrailingZeros(getNegativeSCEV(Start)) >= StepV.countTrailingZeros())
- return getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
-
// Then, try to solve the above equation provided that Start is constant.
if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
// (interprocedural conditions notwithstanding).
if (!L) return true;
+ if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+
BasicBlock *Latch = L->getLoopLatch();
if (!Latch)
return false;
BranchInst *LoopContinuePredicate =
dyn_cast<BranchInst>(Latch->getTerminator());
- if (!LoopContinuePredicate ||
- LoopContinuePredicate->isUnconditional())
- return false;
+ if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
+ isImpliedCond(Pred, LHS, RHS,
+ LoopContinuePredicate->getCondition(),
+ LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
+ return true;
+
+ // Check conditions due to any @llvm.assume intrinsics.
+ for (auto &AssumeVH : AC->assumptions()) {
+ if (!AssumeVH)
+ continue;
+ auto *CI = cast<CallInst>(AssumeVH);
+ if (!DT->dominates(CI, Latch->getTerminator()))
+ continue;
- return isImpliedCond(Pred, LHS, RHS,
- LoopContinuePredicate->getCondition(),
- LoopContinuePredicate->getSuccessor(0) != L->getHeader());
+ if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+ return true;
+ }
+
+ return false;
}
/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
// (interprocedural conditions notwithstanding).
if (!L) return false;
+ if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true;
+
// Starting at the loop predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
// leading to the original header.
return true;
}
+ // Check conditions due to any @llvm.assume intrinsics.
+ for (auto &AssumeVH : AC->assumptions()) {
+ if (!AssumeVH)
+ continue;
+ auto *CI = cast<CallInst>(AssumeVH);
+ if (!DT->dominates(CI, L->getHeader()))
+ continue;
+
+ if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+ return true;
+ }
+
return false;
}
RHS, LHS, FoundLHS, FoundRHS);
}
+ // Check if we can make progress by sharpening ranges.
+ if (FoundPred == ICmpInst::ICMP_NE &&
+ (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
+
+ const SCEVConstant *C = nullptr;
+ const SCEV *V = nullptr;
+
+ if (isa<SCEVConstant>(FoundLHS)) {
+ C = cast<SCEVConstant>(FoundLHS);
+ V = FoundRHS;
+ } else {
+ C = cast<SCEVConstant>(FoundRHS);
+ V = FoundLHS;
+ }
+
+ // The guarding predicate tells us that C != V. If the known range
+ // of V is [C, t), we can sharpen the range to [C + 1, t). The
+ // range we consider has to correspond to same signedness as the
+ // predicate we're interested in folding.
+
+ APInt Min = ICmpInst::isSigned(Pred) ?
+ getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
+
+ if (Min == C->getValue()->getValue()) {
+ // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
+ // This is true even if (Min + 1) wraps around -- in case of
+ // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
+
+ APInt SharperMin = Min + 1;
+
+ switch (Pred) {
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_UGE:
+ // We know V `Pred` SharperMin. If this implies LHS `Pred`
+ // RHS, we're done.
+ if (isImpliedCondOperands(Pred, LHS, RHS, V,
+ getConstant(SharperMin)))
+ return true;
+
+ case ICmpInst::ICMP_SGT:
+ case ICmpInst::ICMP_UGT:
+ // We know from the range information that (V `Pred` Min ||
+ // V == Min). We know from the guarding condition that !(V
+ // == Min). This gives us
+ //
+ // V `Pred` Min || V == Min && !(V == Min)
+ // => V `Pred` Min
+ //
+ // If V `Pred` Min implies LHS `Pred` RHS, we're done.
+
+ if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+ return true;
+
+ default:
+ // No change
+ break;
+ }
+ }
+ }
+
// Check whether the actual condition is beyond sufficient.
if (FoundPred == ICmpInst::ICMP_EQ)
if (ICmpInst::isTrueWhenEqual(Pred))
getNotSCEV(FoundLHS));
}
+
+/// If Expr computes ~A, return A else return nullptr
+static const SCEV *MatchNotExpr(const SCEV *Expr) {
+ const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
+ if (!Add || Add->getNumOperands() != 2) return nullptr;
+
+ const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
+ if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
+ return nullptr;
+
+ const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
+ if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
+
+ const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
+ if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
+ return nullptr;
+
+ return AddRHS->getOperand(1);
+}
+
+
+/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
+template<typename MaxExprType>
+static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
+ const SCEV *Candidate) {
+ const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
+ if (!MaxExpr) return false;
+
+ auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
+ return It != MaxExpr->op_end();
+}
+
+
+/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
+template<typename MaxExprType>
+static bool IsMinConsistingOf(ScalarEvolution &SE,
+ const SCEV *MaybeMinExpr,
+ const SCEV *Candidate) {
+ const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
+ if (!MaybeMaxExpr)
+ return false;
+
+ return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
+}
+
+
+/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
+/// expression?
+static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ switch (Pred) {
+ default:
+ return false;
+
+ case ICmpInst::ICMP_SGE:
+ std::swap(LHS, RHS);
+ // fall through
+ case ICmpInst::ICMP_SLE:
+ return
+ // min(A, ...) <= A
+ IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
+ // A <= max(A, ...)
+ IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
+
+ case ICmpInst::ICMP_UGE:
+ std::swap(LHS, RHS);
+ // fall through
+ case ICmpInst::ICMP_ULE:
+ return
+ // min(A, ...) <= A
+ IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
+ // A <= max(A, ...)
+ IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
+ }
+
+ llvm_unreachable("covered switch fell through?!");
+}
+
/// isImpliedCondOperandsHelper - Test whether the condition described by
/// Pred, LHS, and RHS is true whenever the condition described by Pred,
/// FoundLHS, and FoundRHS is true.
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
const SCEV *FoundRHS) {
+ auto IsKnownPredicateFull =
+ [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
+ return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
+ IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS);
+ };
+
switch (Pred) {
default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
case ICmpInst::ICMP_EQ:
break;
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
- if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
- isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+ if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+ IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
- if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
- isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+ if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+ IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
- if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
- isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+ if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+ IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
- if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
- isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+ if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+ IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
return true;
break;
}
return false;
}
-// Verify if an linear IV with positive stride can overflow when in a
-// less-than comparison, knowing the invariant term of the comparison, the
+// Verify if an linear IV with positive stride can overflow when in a
+// less-than comparison, knowing the invariant term of the comparison, the
// stride and the knowledge of NSW/NUW flags on the recurrence.
bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
bool IsSigned, bool NoWrap) {
return (MaxValue - MaxStrideMinusOne).ult(MaxRHS);
}
-// Verify if an linear IV with negative stride can overflow when in a
+// Verify if an linear IV with negative stride can overflow when in a
// greater-than comparison, knowing the invariant term of the comparison,
// the stride and the knowledge of NSW/NUW flags on the recurrence.
bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
// Compute the backedge taken count knowing the interval difference, the
// stride and presence of the equality in the comparison.
-const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
+const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
bool Equality) {
const SCEV *One = getConstant(Step->getType(), 1);
Delta = Equality ? getAddExpr(Delta, Step)
/// specified less-than comparison will execute. If not computable, return
/// CouldNotCompute.
///
-/// @param IsSubExpr is true when the LHS < RHS condition does not directly
-/// control the branch. In this case, we can only compute an iteration count for
-/// a subexpression that cannot overflow before evaluating true.
+/// @param ControlsExit is true when the LHS < RHS condition directly controls
+/// the branch (loops exits only if condition is true). In this case, we can use
+/// NoWrapFlags to skip overflow checks.
ScalarEvolution::ExitLimit
ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
- bool IsSubExpr) {
+ bool ControlsExit) {
// We handle only IV < Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
if (!IV || IV->getLoop() != L || !IV->isAffine())
return getCouldNotCompute();
- bool NoWrap = !IsSubExpr &&
+ bool NoWrap = ControlsExit &&
IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
const SCEV *Stride = IV->getStepRecurrence(*this);
// Avoid proven overflow cases: this will ensure that the backedge taken count
// will not generate any unsigned overflow. Relaxed no-overflow conditions
- // exploit NoWrapFlags, allowing to optimize in presence of undefined
+ // exploit NoWrapFlags, allowing to optimize in presence of undefined
// behaviors like the case of C language.
if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
return getCouldNotCompute();
: ICmpInst::ICMP_ULT;
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
- if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
- End = IsSigned ? getSMaxExpr(RHS, Start)
- : getUMaxExpr(RHS, Start);
+ if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) {
+ const SCEV *Diff = getMinusSCEV(RHS, Start);
+ // If we have NoWrap set, then we can assume that the increment won't
+ // overflow, in which case if RHS - Start is a constant, we don't need to
+ // do a max operation since we can just figure it out statically
+ if (NoWrap && isa<SCEVConstant>(Diff)) {
+ APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
+ if (D.isNegative())
+ End = Start;
+ } else
+ End = IsSigned ? getSMaxExpr(RHS, Start)
+ : getUMaxExpr(RHS, Start);
+ }
const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, /*MustExit=*/true);
+ return ExitLimit(BECount, MaxBECount);
}
ScalarEvolution::ExitLimit
ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const Loop *L, bool IsSigned,
- bool IsSubExpr) {
+ bool ControlsExit) {
// We handle only IV > Invariant
if (!isLoopInvariant(RHS, L))
return getCouldNotCompute();
if (!IV || IV->getLoop() != L || !IV->isAffine())
return getCouldNotCompute();
- bool NoWrap = !IsSubExpr &&
+ bool NoWrap = ControlsExit &&
IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW);
const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
// Avoid proven overflow cases: this will ensure that the backedge taken count
// will not generate any unsigned overflow. Relaxed no-overflow conditions
- // exploit NoWrapFlags, allowing to optimize in presence of undefined
+ // exploit NoWrapFlags, allowing to optimize in presence of undefined
// behaviors like the case of C language.
if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap))
return getCouldNotCompute();
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
- if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
- End = IsSigned ? getSMinExpr(RHS, Start)
- : getUMinExpr(RHS, Start);
+ if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
+ const SCEV *Diff = getMinusSCEV(RHS, Start);
+ // If we have NoWrap set, then we can assume that the increment won't
+ // overflow, in which case if RHS - Start is a constant, we don't need to
+ // do a max operation since we can just figure it out statically
+ if (NoWrap && isa<SCEVConstant>(Diff)) {
+ APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue();
+ if (!D.isNegative())
+ End = Start;
+ } else
+ End = IsSigned ? getSMinExpr(RHS, Start)
+ : getUMinExpr(RHS, Start);
+ }
const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
if (isa<SCEVConstant>(BECount))
MaxBECount = BECount;
else
- MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
+ MaxBECount = computeBECount(getConstant(MaxStart - MinEnd),
getConstant(MinStride), false);
if (isa<SCEVCouldNotCompute>(MaxBECount))
MaxBECount = BECount;
- return ExitLimit(BECount, MaxBECount, /*MustExit=*/true);
+ return ExitLimit(BECount, MaxBECount);
}
/// getNumIterationsInRange - Return the number of iterations of this loop that
});
}
-static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) {
- APInt A = C1->getValue()->getValue();
- APInt B = C2->getValue()->getValue();
- uint32_t ABW = A.getBitWidth();
- uint32_t BBW = B.getBitWidth();
-
- if (ABW > BBW)
- B = B.sext(ABW);
- else if (ABW < BBW)
- A = A.sext(BBW);
-
- return APIntOps::srem(A, B);
-}
-
-static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) {
- APInt A = C1->getValue()->getValue();
- APInt B = C2->getValue()->getValue();
- uint32_t ABW = A.getBitWidth();
- uint32_t BBW = B.getBitWidth();
-
- if (ABW > BBW)
- B = B.sext(ABW);
- else if (ABW < BBW)
- A = A.sext(BBW);
-
- return APIntOps::sdiv(A, B);
-}
-
-namespace {
-struct FindSCEVSize {
- int Size;
- FindSCEVSize() : Size(0) {}
-
- bool follow(const SCEV *S) {
- ++Size;
- // Keep looking at all operands of S.
- return true;
- }
- bool isDone() const {
- return false;
- }
-};
-}
-
-// Returns the size of the SCEV S.
-static inline int sizeOfSCEV(const SCEV *S) {
- FindSCEVSize F;
- SCEVTraversal<FindSCEVSize> ST(F);
- ST.visitAll(S);
- return F.Size;
-}
-
-namespace {
-
-struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
-public:
- // Computes the Quotient and Remainder of the division of Numerator by
- // Denominator.
- static void divide(ScalarEvolution &SE, const SCEV *Numerator,
- const SCEV *Denominator, const SCEV **Quotient,
- const SCEV **Remainder) {
- assert(Numerator && Denominator && "Uninitialized SCEV");
-
- SCEVDivision D(SE, Numerator, Denominator);
-
- // Check for the trivial case here to avoid having to check for it in the
- // rest of the code.
- if (Numerator == Denominator) {
- *Quotient = D.One;
- *Remainder = D.Zero;
- return;
- }
-
- if (Numerator->isZero()) {
- *Quotient = D.Zero;
- *Remainder = D.Zero;
- return;
- }
-
- // Split the Denominator when it is a product.
- if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) {
- const SCEV *Q, *R;
- *Quotient = Numerator;
- for (const SCEV *Op : T->operands()) {
- divide(SE, *Quotient, Op, &Q, &R);
- *Quotient = Q;
-
- // Bail out when the Numerator is not divisible by one of the terms of
- // the Denominator.
- if (!R->isZero()) {
- *Quotient = D.Zero;
- *Remainder = Numerator;
- return;
- }
- }
- *Remainder = D.Zero;
- return;
- }
-
- D.visit(Numerator);
- *Quotient = D.Quotient;
- *Remainder = D.Remainder;
- }
-
- SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator)
- : SE(S), Denominator(Denominator) {
- Zero = SE.getConstant(Denominator->getType(), 0);
- One = SE.getConstant(Denominator->getType(), 1);
-
- // By default, we don't know how to divide Expr by Denominator.
- // Providing the default here simplifies the rest of the code.
- Quotient = Zero;
- Remainder = Numerator;
- }
-
- // Except in the trivial case described above, we do not know how to divide
- // Expr by Denominator for the following functions with empty implementation.
- void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
- void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {}
- void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {}
- void visitUDivExpr(const SCEVUDivExpr *Numerator) {}
- void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {}
- void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {}
- void visitUnknown(const SCEVUnknown *Numerator) {}
- void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
-
- void visitConstant(const SCEVConstant *Numerator) {
- if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
- Quotient = SE.getConstant(sdiv(Numerator, D));
- Remainder = SE.getConstant(srem(Numerator, D));
- return;
- }
- }
-
- void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
- const SCEV *StartQ, *StartR, *StepQ, *StepR;
- assert(Numerator->isAffine() && "Numerator should be affine");
- divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR);
- divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR);
- Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(),
- Numerator->getNoWrapFlags());
- Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(),
- Numerator->getNoWrapFlags());
- }
-
- void visitAddExpr(const SCEVAddExpr *Numerator) {
- SmallVector<const SCEV *, 2> Qs, Rs;
- Type *Ty = Denominator->getType();
-
- for (const SCEV *Op : Numerator->operands()) {
- const SCEV *Q, *R;
- divide(SE, Op, Denominator, &Q, &R);
-
- // Bail out if types do not match.
- if (Ty != Q->getType() || Ty != R->getType()) {
- Quotient = Zero;
- Remainder = Numerator;
- return;
- }
-
- Qs.push_back(Q);
- Rs.push_back(R);
- }
-
- if (Qs.size() == 1) {
- Quotient = Qs[0];
- Remainder = Rs[0];
- return;
- }
-
- Quotient = SE.getAddExpr(Qs);
- Remainder = SE.getAddExpr(Rs);
- }
-
- void visitMulExpr(const SCEVMulExpr *Numerator) {
- SmallVector<const SCEV *, 2> Qs;
- Type *Ty = Denominator->getType();
-
- bool FoundDenominatorTerm = false;
- for (const SCEV *Op : Numerator->operands()) {
- // Bail out if types do not match.
- if (Ty != Op->getType()) {
- Quotient = Zero;
- Remainder = Numerator;
- return;
- }
-
- if (FoundDenominatorTerm) {
- Qs.push_back(Op);
- continue;
- }
-
- // Check whether Denominator divides one of the product operands.
- const SCEV *Q, *R;
- divide(SE, Op, Denominator, &Q, &R);
- if (!R->isZero()) {
- Qs.push_back(Op);
- continue;
- }
-
- // Bail out if types do not match.
- if (Ty != Q->getType()) {
- Quotient = Zero;
- Remainder = Numerator;
- return;
- }
-
- FoundDenominatorTerm = true;
- Qs.push_back(Q);
- }
-
- if (FoundDenominatorTerm) {
- Remainder = Zero;
- if (Qs.size() == 1)
- Quotient = Qs[0];
- else
- Quotient = SE.getMulExpr(Qs);
- return;
- }
-
- if (!isa<SCEVUnknown>(Denominator)) {
- Quotient = Zero;
- Remainder = Numerator;
- return;
- }
-
- // The Remainder is obtained by replacing Denominator by 0 in Numerator.
- ValueToValueMap RewriteMap;
- RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
- cast<SCEVConstant>(Zero)->getValue();
- Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
-
- if (Remainder->isZero()) {
- // The Quotient is obtained by replacing Denominator by 1 in Numerator.
- RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] =
- cast<SCEVConstant>(One)->getValue();
- Quotient =
- SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true);
- return;
- }
-
- // Quotient is (Numerator - Remainder) divided by Denominator.
- const SCEV *Q, *R;
- const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder);
- if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) {
- // This SCEV does not seem to simplify: fail the division here.
- Quotient = Zero;
- Remainder = Numerator;
- return;
- }
- divide(SE, Diff, Denominator, &Q, &R);
- assert(R == Zero &&
- "(Numerator - Remainder) should evenly divide Denominator");
- Quotient = Q;
- }
-
-private:
- ScalarEvolution &SE;
- const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
-};
-}
-
static bool findArrayDimensionsRec(ScalarEvolution &SE,
SmallVectorImpl<const SCEV *> &Terms,
SmallVectorImpl<const SCEV *> &Sizes) {
// that until everything else is done.
if (U == Old)
continue;
- if (!Visited.insert(U))
+ if (!Visited.insert(U).second)
continue;
if (PHINode *PN = dyn_cast<PHINode>(U))
SE->ConstantEvolutionLoopExitValue.erase(PN);
bool ScalarEvolution::runOnFunction(Function &F) {
this->F = &F;
- LI = &getAnalysis<LoopInfo>();
+ AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
DL = DLP ? &DLP->getDataLayout() : nullptr;
- TLI = &getAnalysis<TargetLibraryInfo>();
+ TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
return false;
}
void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
- AU.addRequiredTransitive<LoopInfo>();
+ AU.addRequired<AssumptionCacheTracker>();
+ AU.addRequiredTransitive<LoopInfoWrapperPass>();
AU.addRequiredTransitive<DominatorTreeWrapperPass>();
- AU.addRequired<TargetLibraryInfo>();
+ AU.addRequired<TargetLibraryInfoWrapperPass>();
}
bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
ScalarEvolution::LoopDisposition
ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
- SmallVector<std::pair<const Loop *, LoopDisposition>, 2> &Values = LoopDispositions[S];
- for (unsigned u = 0; u < Values.size(); u++) {
- if (Values[u].first == L)
- return Values[u].second;
+ auto &Values = LoopDispositions[S];
+ for (auto &V : Values) {
+ if (V.getPointer() == L)
+ return V.getInt();
}
- Values.push_back(std::make_pair(L, LoopVariant));
+ Values.emplace_back(L, LoopVariant);
LoopDisposition D = computeLoopDisposition(S, L);
- SmallVector<std::pair<const Loop *, LoopDisposition>, 2> &Values2 = LoopDispositions[S];
- for (unsigned u = Values2.size(); u > 0; u--) {
- if (Values2[u - 1].first == L) {
- Values2[u - 1].second = D;
+ auto &Values2 = LoopDispositions[S];
+ for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
+ if (V.getPointer() == L) {
+ V.setInt(D);
break;
}
}
ScalarEvolution::BlockDisposition
ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
- SmallVector<std::pair<const BasicBlock *, BlockDisposition>, 2> &Values = BlockDispositions[S];
- for (unsigned u = 0; u < Values.size(); u++) {
- if (Values[u].first == BB)
- return Values[u].second;
+ auto &Values = BlockDispositions[S];
+ for (auto &V : Values) {
+ if (V.getPointer() == BB)
+ return V.getInt();
}
- Values.push_back(std::make_pair(BB, DoesNotDominateBlock));
+ Values.emplace_back(BB, DoesNotDominateBlock);
BlockDisposition D = computeBlockDisposition(S, BB);
- SmallVector<std::pair<const BasicBlock *, BlockDisposition>, 2> &Values2 = BlockDispositions[S];
- for (unsigned u = Values2.size(); u > 0; u--) {
- if (Values2[u - 1].first == BB) {
- Values2[u - 1].second = D;
+ auto &Values2 = BlockDispositions[S];
+ for (auto &V : make_range(Values2.rbegin(), Values2.rend())) {
+ if (V.getPointer() == BB) {
+ V.setInt(D);
break;
}
}