X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=ef1bb3a36c8d4205652560cfa9c04bf112c9e540;hp=4181ee6446adace97116e1fdfd996e2e0a602360;hb=44fb5881d8edf448d6231a5b8df583aecd6bcd42;hpb=8b6e5f368ca02dda35a1b0e6b84f39a0890214b1 diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 4181ee6446a..ef1bb3a36c8 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -83,6 +83,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -123,12 +124,11 @@ VerifySCEV("verify-scev", // Implementation of the SCEV class. // -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEV::dump() const { print(dbgs()); dbgs() << '\n'; } -#endif void SCEV::print(raw_ostream &OS) const { switch (static_cast(getSCEVType())) { @@ -294,7 +294,7 @@ bool SCEV::isNonConstantNegative() const { if (!SC) return false; // Return true if the value is negative, this matches things like (-42 * V). - return SC->getValue()->getValue().isNegative(); + return SC->getAPInt().isNegative(); } SCEVCouldNotCompute::SCEVCouldNotCompute() : @@ -446,179 +446,179 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { //===----------------------------------------------------------------------===// namespace { - /// SCEVComplexityCompare - Return true if the complexity of the LHS is less - /// than the complexity of the RHS. This comparator is used to canonicalize - /// expressions. - class SCEVComplexityCompare { - const LoopInfo *const LI; - public: - explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} - - // Return true or false if LHS is less than, or at least RHS, respectively. - bool operator()(const SCEV *LHS, const SCEV *RHS) const { - return compare(LHS, RHS) < 0; - } - - // Return negative, zero, or positive, if LHS is less than, equal to, or - // greater than RHS, respectively. A three-way result allows recursive - // comparisons to be more efficient. - int compare(const SCEV *LHS, const SCEV *RHS) const { - // Fast-path: SCEVs are uniqued so we can do a quick equality check. - if (LHS == RHS) - return 0; - - // Primarily, sort the SCEVs by their getSCEVType(). - unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); - if (LType != RType) - return (int)LType - (int)RType; - - // Aside from the getSCEVType() ordering, the particular ordering - // isn't very important except that it's beneficial to be consistent, - // so that (a + b) and (b + a) don't end up as different expressions. - switch (static_cast(LType)) { - case scUnknown: { - const SCEVUnknown *LU = cast(LHS); - const SCEVUnknown *RU = cast(RHS); - - // Sort SCEVUnknown values with some loose heuristics. TODO: This is - // not as complete as it could be. - const Value *LV = LU->getValue(), *RV = RU->getValue(); - - // Order pointer values after integer values. This helps SCEVExpander - // form GEPs. - bool LIsPointer = LV->getType()->isPointerTy(), - RIsPointer = RV->getType()->isPointerTy(); - if (LIsPointer != RIsPointer) - return (int)LIsPointer - (int)RIsPointer; - - // Compare getValueID values. - unsigned LID = LV->getValueID(), - RID = RV->getValueID(); - if (LID != RID) - return (int)LID - (int)RID; - - // Sort arguments by their position. 
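// Aside: a hedged, self-contained sketch (plain C++; the Arg struct and
// helper names are hypothetical, not LLVM API) of the three-way
// "(int)L - (int)R" idiom this comparator uses: each tie-breaker yields a
// signed difference, and the first nonzero one decides the order, which
// gives a consistent canonical sort.
#include <algorithm>
#include <vector>

struct Arg { unsigned ArgNo; bool IsPointer; };

static int compareArgs(const Arg &L, const Arg &R) {
  // Order pointer values after integer values, then break ties by
  // argument position, mirroring the ordering rules above.
  if (L.IsPointer != R.IsPointer)
    return (int)L.IsPointer - (int)R.IsPointer;
  return (int)L.ArgNo - (int)R.ArgNo;
}

static void canonicalize(std::vector<Arg> &Args) {
  // A consistent order guarantees (a + b) and (b + a) group identically.
  std::sort(Args.begin(), Args.end(),
            [](const Arg &L, const Arg &R) { return compareArgs(L, R) < 0; });
}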
- if (const Argument *LA = dyn_cast(LV)) { - const Argument *RA = cast(RV); - unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); - return (int)LArgNo - (int)RArgNo; - } - - // For instructions, compare their loop depth, and their operand - // count. This is pretty loose. - if (const Instruction *LInst = dyn_cast(LV)) { - const Instruction *RInst = cast(RV); - - // Compare loop depths. - const BasicBlock *LParent = LInst->getParent(), - *RParent = RInst->getParent(); - if (LParent != RParent) { - unsigned LDepth = LI->getLoopDepth(LParent), - RDepth = LI->getLoopDepth(RParent); - if (LDepth != RDepth) - return (int)LDepth - (int)RDepth; - } - - // Compare the number of operands. - unsigned LNumOps = LInst->getNumOperands(), - RNumOps = RInst->getNumOperands(); - return (int)LNumOps - (int)RNumOps; - } +/// SCEVComplexityCompare - Return true if the complexity of the LHS is less +/// than the complexity of the RHS. This comparator is used to canonicalize +/// expressions. +class SCEVComplexityCompare { + const LoopInfo *const LI; +public: + explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} - return 0; - } + // Return true or false if LHS is less than, or at least RHS, respectively. + bool operator()(const SCEV *LHS, const SCEV *RHS) const { + return compare(LHS, RHS) < 0; + } - case scConstant: { - const SCEVConstant *LC = cast(LHS); - const SCEVConstant *RC = cast(RHS); - - // Compare constant values. - const APInt &LA = LC->getValue()->getValue(); - const APInt &RA = RC->getValue()->getValue(); - unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); - if (LBitWidth != RBitWidth) - return (int)LBitWidth - (int)RBitWidth; - return LA.ult(RA) ? -1 : 1; + // Return negative, zero, or positive, if LHS is less than, equal to, or + // greater than RHS, respectively. A three-way result allows recursive + // comparisons to be more efficient. + int compare(const SCEV *LHS, const SCEV *RHS) const { + // Fast-path: SCEVs are uniqued so we can do a quick equality check. + if (LHS == RHS) + return 0; + + // Primarily, sort the SCEVs by their getSCEVType(). + unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); + if (LType != RType) + return (int)LType - (int)RType; + + // Aside from the getSCEVType() ordering, the particular ordering + // isn't very important except that it's beneficial to be consistent, + // so that (a + b) and (b + a) don't end up as different expressions. + switch (static_cast(LType)) { + case scUnknown: { + const SCEVUnknown *LU = cast(LHS); + const SCEVUnknown *RU = cast(RHS); + + // Sort SCEVUnknown values with some loose heuristics. TODO: This is + // not as complete as it could be. + const Value *LV = LU->getValue(), *RV = RU->getValue(); + + // Order pointer values after integer values. This helps SCEVExpander + // form GEPs. + bool LIsPointer = LV->getType()->isPointerTy(), + RIsPointer = RV->getType()->isPointerTy(); + if (LIsPointer != RIsPointer) + return (int)LIsPointer - (int)RIsPointer; + + // Compare getValueID values. + unsigned LID = LV->getValueID(), + RID = RV->getValueID(); + if (LID != RID) + return (int)LID - (int)RID; + + // Sort arguments by their position. + if (const Argument *LA = dyn_cast(LV)) { + const Argument *RA = cast(RV); + unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); + return (int)LArgNo - (int)RArgNo; } - case scAddRecExpr: { - const SCEVAddRecExpr *LA = cast(LHS); - const SCEVAddRecExpr *RA = cast(RHS); - - // Compare addrec loop depths. 
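// Aside: a hedged sketch of the tie-breaking the addrec case below
// performs, restated over a hypothetical Expr node (not the real SCEV
// classes): compare loop depths first, then operand counts, then operands
// lexicographically, recursing until a difference appears.
#include <cstddef>
#include <vector>

struct Expr { unsigned LoopDepth; std::vector<const Expr *> Ops; };

static int compareExpr(const Expr *L, const Expr *R) {
  if (L == R)
    return 0;                                     // uniqued: fast path
  if (L->LoopDepth != R->LoopDepth)
    return (int)L->LoopDepth - (int)R->LoopDepth; // shallower loops first
  if (L->Ops.size() != R->Ops.size())
    return (int)L->Ops.size() - (int)R->Ops.size();
  for (std::size_t I = 0; I != L->Ops.size(); ++I)
    if (int X = compareExpr(L->Ops[I], R->Ops[I]))
      return X;                                   // first difference wins
  return 0;
}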
- const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); - if (LLoop != RLoop) { - unsigned LDepth = LLoop->getLoopDepth(), - RDepth = RLoop->getLoopDepth(); + // For instructions, compare their loop depth, and their operand + // count. This is pretty loose. + if (const Instruction *LInst = dyn_cast(LV)) { + const Instruction *RInst = cast(RV); + + // Compare loop depths. + const BasicBlock *LParent = LInst->getParent(), + *RParent = RInst->getParent(); + if (LParent != RParent) { + unsigned LDepth = LI->getLoopDepth(LParent), + RDepth = LI->getLoopDepth(RParent); if (LDepth != RDepth) return (int)LDepth - (int)RDepth; } - // Addrec complexity grows with operand count. - unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); - if (LNumOps != RNumOps) - return (int)LNumOps - (int)RNumOps; + // Compare the number of operands. + unsigned LNumOps = LInst->getNumOperands(), + RNumOps = RInst->getNumOperands(); + return (int)LNumOps - (int)RNumOps; + } + + return 0; + } - // Lexicographically compare. - for (unsigned i = 0; i != LNumOps; ++i) { - long X = compare(LA->getOperand(i), RA->getOperand(i)); - if (X != 0) - return X; - } + case scConstant: { + const SCEVConstant *LC = cast(LHS); + const SCEVConstant *RC = cast(RHS); - return 0; + // Compare constant values. + const APInt &LA = LC->getAPInt(); + const APInt &RA = RC->getAPInt(); + unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); + if (LBitWidth != RBitWidth) + return (int)LBitWidth - (int)RBitWidth; + return LA.ult(RA) ? -1 : 1; + } + + case scAddRecExpr: { + const SCEVAddRecExpr *LA = cast(LHS); + const SCEVAddRecExpr *RA = cast(RHS); + + // Compare addrec loop depths. + const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); + if (LLoop != RLoop) { + unsigned LDepth = LLoop->getLoopDepth(), + RDepth = RLoop->getLoopDepth(); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; } - case scAddExpr: - case scMulExpr: - case scSMaxExpr: - case scUMaxExpr: { - const SCEVNAryExpr *LC = cast(LHS); - const SCEVNAryExpr *RC = cast(RHS); - - // Lexicographically compare n-ary expressions. - unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); - if (LNumOps != RNumOps) - return (int)LNumOps - (int)RNumOps; - - for (unsigned i = 0; i != LNumOps; ++i) { - if (i >= RNumOps) - return 1; - long X = compare(LC->getOperand(i), RC->getOperand(i)); - if (X != 0) - return X; - } + // Addrec complexity grows with operand count. + unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); + if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; + + // Lexicographically compare. + for (unsigned i = 0; i != LNumOps; ++i) { + long X = compare(LA->getOperand(i), RA->getOperand(i)); + if (X != 0) + return X; } - case scUDivExpr: { - const SCEVUDivExpr *LC = cast(LHS); - const SCEVUDivExpr *RC = cast(RHS); + return 0; + } + + case scAddExpr: + case scMulExpr: + case scSMaxExpr: + case scUMaxExpr: { + const SCEVNAryExpr *LC = cast(LHS); + const SCEVNAryExpr *RC = cast(RHS); + + // Lexicographically compare n-ary expressions. + unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; - // Lexicographically compare udiv expressions. 
- long X = compare(LC->getLHS(), RC->getLHS()); + for (unsigned i = 0; i != LNumOps; ++i) { + if (i >= RNumOps) + return 1; + long X = compare(LC->getOperand(i), RC->getOperand(i)); if (X != 0) return X; - return compare(LC->getRHS(), RC->getRHS()); } + return (int)LNumOps - (int)RNumOps; + } - case scTruncate: - case scZeroExtend: - case scSignExtend: { - const SCEVCastExpr *LC = cast(LHS); - const SCEVCastExpr *RC = cast(RHS); + case scUDivExpr: { + const SCEVUDivExpr *LC = cast(LHS); + const SCEVUDivExpr *RC = cast(RHS); - // Compare cast expressions by operand. - return compare(LC->getOperand(), RC->getOperand()); - } + // Lexicographically compare udiv expressions. + long X = compare(LC->getLHS(), RC->getLHS()); + if (X != 0) + return X; + return compare(LC->getRHS(), RC->getRHS()); + } - case scCouldNotCompute: - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - } - llvm_unreachable("Unknown SCEV kind!"); + case scTruncate: + case scZeroExtend: + case scSignExtend: { + const SCEVCastExpr *LC = cast(LHS); + const SCEVCastExpr *RC = cast(RHS); + + // Compare cast expressions by operand. + return compare(LC->getOperand(), RC->getOperand()); } - }; -} + + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); + } +}; +} // end anonymous namespace /// GroupByComplexity - Given a list of SCEV objects, order them by their /// complexity, and group objects of the same complexity together by value. @@ -666,24 +666,22 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } -namespace { -struct FindSCEVSize { - int Size; - FindSCEVSize() : Size(0) {} - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - bool isDone() const { - return false; - } -}; -} - // Returns the size of the SCEV S. static inline int sizeOfSCEV(const SCEV *S) { + struct FindSCEVSize { + int Size; + FindSCEVSize() : Size(0) {} + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. + return true; + } + bool isDone() const { + return false; + } + }; + FindSCEVSize F; SCEVTraversal ST(F); ST.visitAll(S); @@ -762,8 +760,8 @@ public: void visitConstant(const SCEVConstant *Numerator) { if (const SCEVConstant *D = dyn_cast(Denominator)) { - APInt NumeratorVal = Numerator->getValue()->getValue(); - APInt DenominatorVal = D->getValue()->getValue(); + APInt NumeratorVal = Numerator->getAPInt(); + APInt DenominatorVal = D->getAPInt(); uint32_t NumeratorBW = NumeratorVal.getBitWidth(); uint32_t DenominatorBW = DenominatorVal.getBitWidth(); @@ -1132,8 +1130,8 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, // If the input value is a chrec scev, truncate the chrec's operands. if (const SCEVAddRecExpr *AddRec = dyn_cast(Op)) { SmallVector Operands; - for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) - Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); + for (const SCEV *Op : AddRec->operands()) + Operands.push_back(getTruncateExpr(Op, Ty)); return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } @@ -1268,7 +1266,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // `Step`: // 1. NSW/NUW flags on the step increment. 
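// Aside: the fix below keeps only FlagNUW from SA's flags when forming
// PreStart. A minimal sketch of the masking idiom, using a hypothetical
// flag enum laid out like SCEV::NoWrapFlags (an assumption, not copied
// from the header):
enum NoWrapFlags { FlagAnyWrap = 0, FlagNW = 1, FlagNUW = 2, FlagNSW = 4 };

static NoWrapFlags maskFlags(NoWrapFlags Flags, int Mask) {
  return (NoWrapFlags)(Flags & Mask);   // keep only the bits in Mask
}

// Example: maskFlags((NoWrapFlags)(FlagNUW | FlagNSW), FlagNUW) yields
// FlagNUW alone -- NSW is deliberately dropped, since removing an operand
// from a nuw+nsw add need not preserve signed no-wrap.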
- const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags()); + auto PreStartFlags = + ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); + const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); const SCEVAddRecExpr *PreAR = dyn_cast( SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); @@ -1303,9 +1303,9 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ExtendOpTraits::getOverflowLimitForStep(Step, &Pred, SE); if (OverflowLimit && - SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) { + SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) return PreStart; - } + return nullptr; } @@ -1371,24 +1371,22 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, if (!StartC) return false; - APInt StartAI = StartC->getValue()->getValue(); + APInt StartAI = StartC->getAPInt(); for (unsigned Delta : {-2, -1, 1, 2}) { const SCEV *PreStart = getConstant(StartAI - Delta); + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + ID.AddPointer(PreStart); + ID.AddPointer(Step); + ID.AddPointer(L); + void *IP = nullptr; + const auto *PreAR = + static_cast(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + // Give up if we don't already have the add recurrence we need because // actually constructing an add recurrence is relatively expensive. - const SCEVAddRecExpr *PreAR = [&]() { - FoldingSetNodeID ID; - ID.AddInteger(scAddRecExpr); - ID.AddPointer(PreStart); - ID.AddPointer(Step); - ID.AddPointer(L); - void *IP = nullptr; - return static_cast( - this->UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - }(); - if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2) const SCEV *DeltaS = getConstant(StartC->getType(), Delta); ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; @@ -1559,6 +1557,18 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, } } + if (auto *SA = dyn_cast(Op)) { + // zext((A + B + ...)) --> (zext(A) + zext(B) + ...) + if (SA->getNoWrapFlags(SCEV::FlagNUW)) { + // If the addition does not unsign overflow then we can, by definition, + // commute the zero extension with the addition operation. + SmallVector Ops; + for (const auto *Op : SA->operands()) + Ops.push_back(getZeroExtendExpr(Op, Ty)); + return getAddExpr(Ops, SCEV::FlagNUW); + } + } + // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; @@ -1622,8 +1632,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, auto *SMul = dyn_cast(SA->getOperand(1)); if (SMul && SC1) { if (auto *SC2 = dyn_cast(SMul->getOperand(0))) { - const APInt &C1 = SC1->getValue()->getValue(); - const APInt &C2 = SC2->getValue()->getValue(); + const APInt &C1 = SC1->getAPInt(); + const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) return getAddExpr(getSignExtendExpr(SC1, Ty), @@ -1631,6 +1641,16 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } } } + + // sext((A + B + ...)) --> (sext(A) + sext(B) + ...) + if (SA->getNoWrapFlags(SCEV::FlagNSW)) { + // If the addition does not sign overflow then we can, by definition, + // commute the sign extension with the addition operation. 
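// Aside: why nsw licenses this fold. When a narrow signed addition cannot
// overflow, sign-extending the sum equals adding the sign extensions. A
// self-contained 8-bit check of that identity (plain C++, illustrative
// only, not LLVM API):
#include <cassert>
#include <cstdint>

static void sextOverNSWAdd(int8_t A, int8_t B) {
  int Wide = (int)A + (int)B;
  bool NoSignedWrap = Wide >= INT8_MIN && Wide <= INT8_MAX;
  if (NoSignedWrap)
    // sext(A + B) == sext(A) + sext(B) at the wider width.
    assert((int16_t)(int8_t)(A + B) == (int16_t)A + (int16_t)B);
}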
+ SmallVector Ops; + for (const auto *Op : SA->operands()) + Ops.push_back(getSignExtendExpr(Op, Ty)); + return getAddExpr(Ops, SCEV::FlagNSW); + } } // If the input value is a chrec scev, and we can prove that the value // did not overflow the old, smaller, value, we can sign extend all of the @@ -1738,8 +1758,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, auto *SC1 = dyn_cast(Start); auto *SC2 = dyn_cast(Step); if (SC1 && SC2) { - const APInt &C1 = SC1->getValue()->getValue(); - const APInt &C2 = SC2->getValue()->getValue(); + const APInt &C1 = SC1->getAPInt(); + const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { Start = getSignExtendExpr(Start, Ty); @@ -1779,7 +1799,7 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, // Sign-extend negative constants. if (const SCEVConstant *SC = dyn_cast(Op)) - if (SC->getValue()->getValue().isNegative()) + if (SC->getAPInt().isNegative()) return getSignExtendExpr(Op, Ty); // Peel off a truncate cast. @@ -1857,7 +1877,7 @@ CollectAddOperandsWithScales(DenseMap &M, // Pull a buried constant out to the outside. if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) Interesting = true; - AccumulatedConstant += Scale * C->getValue()->getValue(); + AccumulatedConstant += Scale * C->getAPInt(); } // Next comes everything else. We're especially interested in multiplies @@ -1866,7 +1886,7 @@ CollectAddOperandsWithScales(DenseMap &M, const SCEVMulExpr *Mul = dyn_cast(Ops[i]); if (Mul && isa(Mul->getOperand(0))) { APInt NewScale = - Scale * cast(Mul->getOperand(0))->getValue()->getValue(); + Scale * cast(Mul->getOperand(0))->getAPInt(); if (Mul->getNumOperands() == 2 && isa(Mul->getOperand(1))) { // A multiplication of a constant with another add; recurse. const SCEVAddExpr *Add = cast(Mul->getOperand(1)); @@ -1907,22 +1927,15 @@ CollectAddOperandsWithScales(DenseMap &M, return Interesting; } -namespace { - struct APIntCompare { - bool operator()(const APInt &LHS, const APInt &RHS) const { - return LHS.ult(RHS); - } - }; -} - // We're trying to construct a SCEV of type `Type' with `Ops' as operands and // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of // can't-overflow flags for the operation if possible. static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const SmallVectorImpl &Ops, - SCEV::NoWrapFlags OldFlags) { + SCEV::NoWrapFlags Flags) { using namespace std::placeholders; + typedef OverflowingBinaryOperator OBO; bool CanAnalyze = Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; @@ -1931,18 +1944,42 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; SCEV::NoWrapFlags SignOrUnsignWrap = - ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask); + ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. 
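// Aside: the inference stated above, checked exhaustively at 8 bits
// (plain C++, illustrative only). For non-negative operands, "no signed
// wrap" forces the true sum into [0, INT8_MAX], which lies inside the
// unsigned range as well, so "no unsigned wrap" follows.
#include <cassert>
#include <cstdint>

static void nswPlusNonNegativeImpliesNUW() {
  for (int A = 0; A <= INT8_MAX; ++A)
    for (int B = 0; B <= INT8_MAX; ++B) {
      bool NSW = A + B <= INT8_MAX;    // signed 8-bit add does not wrap
      bool NUW = A + B <= UINT8_MAX;   // unsigned 8-bit add does not wrap
      if (NSW)
        assert(NUW);                   // nsw + non-negative ==> nuw
    }
}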
- auto IsKnownNonNegative = - std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1); + auto IsKnownNonNegative = [&](const SCEV *S) { + return SE->isKnownNonNegative(S); + }; + + if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) + Flags = + ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); + + SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + + if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr && + Ops.size() == 2 && isa(Ops[0])) { - if (SignOrUnsignWrap == SCEV::FlagNSW && - std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative)) - return ScalarEvolution::setFlags(OldFlags, - (SCEV::NoWrapFlags)SignOrUnsignMask); + // (A + C) --> (A + C) if the addition does not sign overflow + // (A + C) --> (A + C) if the addition does not unsign overflow - return OldFlags; + const APInt &C = cast(Ops[0])->getAPInt(); + if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { + auto NSWRegion = + ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap); + if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + } + if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { + auto NUWRegion = + ConstantRange::makeNoWrapRegion(Instruction::Add, C, + OBO::NoUnsignedWrap); + if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) + Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + } + } + + return Flags; } /// getAddExpr - Get a canonical add expression, or something simpler if @@ -1960,11 +1997,11 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, "SCEVAddExpr operand types don't match!"); #endif - Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); - // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI); + Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { @@ -1972,8 +2009,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Ops[0] = getConstant(LHSC->getValue()->getValue() + - RHSC->getValue()->getValue()); + Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt()); if (Ops.size() == 2) return Ops[0]; Ops.erase(Ops.begin()+1); // Erase the folded element LHSC = cast(Ops[0]); @@ -2043,8 +2079,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, break; } LargeMulOps.push_back(T->getOperand()); - } else if (const SCEVConstant *C = - dyn_cast(M->getOperand(j))) { + } else if (const auto *C = dyn_cast(M->getOperand(j))) { LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); } else { Ok = false; @@ -2103,22 +2138,26 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Ops.data(), Ops.size(), APInt(BitWidth, 1), *this)) { + struct APIntCompare { + bool operator()(const APInt &LHS, const APInt &RHS) const { + return LHS.ult(RHS); + } + }; + // Some interesting folding opportunity is present, so its worthwhile to // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. 
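// Aside: the shape of the grouping built below, sketched with uint64_t
// standing in for APInt and strings for SCEV operands (hypothetical types,
// illustrative only): bucket each operand under its scale so that
// S*a + S*b is later emitted as one multiply, getMulExpr(S, getAddExpr(a, b)).
// APInt needs the custom APIntCompare ordering because it defines no
// operator< suitable for std::map.
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

static std::map<uint64_t, std::vector<std::string>>
groupByScale(const std::vector<std::pair<std::string, uint64_t>> &Ops) {
  std::map<uint64_t, std::vector<std::string>> MulOpLists;
  for (const auto &Op : Ops)
    MulOpLists[Op.second].push_back(Op.first);  // scale -> operand list
  return MulOpLists;
}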
std::map, APIntCompare> MulOpLists; - for (SmallVectorImpl::const_iterator I = NewOps.begin(), - E = NewOps.end(); I != E; ++I) - MulOpLists[M.find(*I)->second].push_back(*I); + for (const SCEV *NewOp : NewOps) + MulOpLists[M.find(NewOp)->second].push_back(NewOp); // Re-generate the operands list. Ops.clear(); if (AccumulatedConstant != 0) Ops.push_back(getConstant(AccumulatedConstant)); - for (std::map, APIntCompare>::iterator - I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) - if (I->first != 0) - Ops.push_back(getMulExpr(getConstant(I->first), - getAddExpr(I->second))); + for (auto &MulOp : MulOpLists) + if (MulOp.first != 0) + Ops.push_back(getMulExpr(getConstant(MulOp.first), + getAddExpr(MulOp.second))); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) @@ -2259,8 +2298,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, AddRec->op_end()); for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) - if (const SCEVAddRecExpr *OtherAddRec = - dyn_cast(Ops[OtherIdx])) + if (const auto *OtherAddRec = dyn_cast(Ops[OtherIdx])) if (OtherAddRec->getLoop() == AddRecLoop) { for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { @@ -2368,11 +2406,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, "SCEVMulExpr operand types don't match!"); #endif - Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); - // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI); + Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); + // If there are any constants, fold them together. unsigned Idx = 0; if (const SCEVConstant *LHSC = dyn_cast(Ops[0])) { @@ -2390,9 +2428,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(getContext(), - LHSC->getValue()->getValue() * - RHSC->getValue()->getValue()); + ConstantInt *Fold = + ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt()); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; @@ -2413,23 +2450,19 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) { SmallVector NewOps; bool AnyFolded = false; - for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), - E = Add->op_end(); I != E; ++I) { - const SCEV *Mul = getMulExpr(Ops[0], *I); + for (const SCEV *AddOp : Add->operands()) { + const SCEV *Mul = getMulExpr(Ops[0], AddOp); if (!isa(Mul)) AnyFolded = true; NewOps.push_back(Mul); } if (AnyFolded) return getAddExpr(NewOps); - } - else if (const SCEVAddRecExpr * - AddRec = dyn_cast(Ops[1])) { + } else if (const auto *AddRec = dyn_cast(Ops[1])) { // Negation preserves a recurrence's no self-wrap property. SmallVector Operands; - for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(), - E = AddRec->op_end(); I != E; ++I) { - Operands.push_back(getMulExpr(Ops[0], *I)); - } + for (const SCEV *AddRecOp : AddRec->operands()) + Operands.push_back(getMulExpr(Ops[0], AddRecOp)); + return getAddRecExpr(Operands, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); } @@ -2618,11 +2651,11 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // its operands. // TODO: Generalize this to non-constants by using known-bits information. 
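// Aside: what the arithmetic below computes. Dividing by a constant C
// shifts values down by at most ceil(log2(C)) bits, so the quotient still
// fits once the operands are zero-extended by that many bits. A sketch
// with a 64-bit stand-in for APInt (C++20 <bit>, illustrative only):
#include <bit>
#include <cassert>
#include <cstdint>

static unsigned maxShiftForDivisor(uint64_t C) {
  assert(C != 0 && "divisor must be nonzero");
  unsigned LZ = std::countl_zero(C);    // countLeadingZeros analogue
  unsigned MaxShiftAmt = 64 - LZ - 1;   // floor(log2(C))
  if (!std::has_single_bit(C))          // non-power-of-two: effectively
    ++MaxShiftAmt;                      //   round C up to the next power
  return MaxShiftAmt;                   // == ceil(log2(C))
}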
Type *Ty = LHS->getType(); - unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros(); + unsigned LZ = RHSC->getAPInt().countLeadingZeros(); unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; // For non-power-of-two values, effectively round the value up to the // nearest power of two. - if (!RHSC->getValue()->getValue().isPowerOf2()) + if (!RHSC->getAPInt().isPowerOf2()) ++MaxShiftAmt; IntegerType *ExtTy = IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); @@ -2630,18 +2663,17 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, if (const SCEVConstant *Step = dyn_cast(AR->getStepRecurrence(*this))) { // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. - const APInt &StepInt = Step->getValue()->getValue(); - const APInt &DivInt = RHSC->getValue()->getValue(); + const APInt &StepInt = Step->getAPInt(); + const APInt &DivInt = RHSC->getAPInt(); if (!StepInt.urem(DivInt) && getZeroExtendExpr(AR, ExtTy) == getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { SmallVector Operands; - for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i) - Operands.push_back(getUDivExpr(AR->getOperand(i), RHS)); - return getAddRecExpr(Operands, AR->getLoop(), - SCEV::FlagNW); + for (const SCEV *Op : AR->operands()) + Operands.push_back(getUDivExpr(Op, RHS)); + return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); } /// Get a canonical UDivExpr for a recurrence. /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0. @@ -2652,7 +2684,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { - const APInt &StartInt = StartC->getValue()->getValue(); + const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); if (StartRem != 0) LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, @@ -2662,8 +2694,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // (A*B)/C --> A*(B/C) if safe and B/C can be folded. if (const SCEVMulExpr *M = dyn_cast(LHS)) { SmallVector Operands; - for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) - Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy)); + for (const SCEV *Op : M->operands()) + Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) // Find an operand that's safely divisible. for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { @@ -2680,8 +2712,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. 
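// Aside: why the fold below needs its "safe" checks. udiv only
// distributes over an addition that does not wrap, and one summand must
// divide exactly. An 8-bit illustration (plain C++): 200 + 100 wraps to
// 44, and 44/8 != 200/8 + 100/8.
#include <cassert>
#include <cstdint>

static void udivOverAdd(uint8_t A, uint8_t B, uint8_t C) {
  uint8_t Sum = A + B;                          // may wrap at 8 bits
  bool NoWrap = (unsigned)A + B <= UINT8_MAX;   // zext-compare analogue
  if (C != 0 && NoWrap && A % C == 0)           // A/C divides exactly
    assert(Sum / C == A / C + B / C);           // (A+B)/C == A/C + B/C
}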
if (const SCEVAddExpr *A = dyn_cast(LHS)) { SmallVector Operands; - for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) - Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy)); + for (const SCEV *Op : A->operands()) + Operands.push_back(getZeroExtendExpr(Op, ExtTy)); if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { Operands.clear(); for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { @@ -2719,8 +2751,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue().abs(); - APInt B = C2->getValue()->getValue().abs(); + APInt A = C1->getAPInt().abs(); + APInt B = C2->getAPInt().abs(); uint32_t ABW = A.getBitWidth(); uint32_t BBW = B.getBitWidth(); @@ -2749,8 +2781,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, if (const SCEVConstant *RHSCst = dyn_cast(RHS)) { // If the mulexpr multiplies by a constant, then that constant must be the // first element of the mulexpr. - if (const SCEVConstant *LHSCst = - dyn_cast(Mul->getOperand(0))) { + if (const auto *LHSCst = dyn_cast(Mul->getOperand(0))) { if (LHSCst == RHSCst) { SmallVector Operands; Operands.append(Mul->op_begin() + 1, Mul->op_end()); @@ -2762,10 +2793,10 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // check. APInt Factor = gcd(LHSCst, RHSCst); if (!Factor.isIntN(1)) { - LHSCst = cast( - getConstant(LHSCst->getValue()->getValue().udiv(Factor))); - RHSCst = cast( - getConstant(RHSCst->getValue()->getValue().udiv(Factor))); + LHSCst = + cast(getConstant(LHSCst->getAPInt().udiv(Factor))); + RHSCst = + cast(getConstant(RHSCst->getAPInt().udiv(Factor))); SmallVector Operands; Operands.push_back(LHSCst); Operands.append(Mul->op_begin() + 1, Mul->op_end()); @@ -2849,12 +2880,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this // requirement. - bool AllInvariant = true; - for (unsigned i = 0, e = Operands.size(); i != e; ++i) - if (!isLoopInvariant(Operands[i], L)) { - AllInvariant = false; - break; - } + bool AllInvariant = all_of( + Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); + if (AllInvariant) { // Create a recurrence for the outer loop with the same step size. // @@ -2864,12 +2892,10 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); - AllInvariant = true; - for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i) - if (!isLoopInvariant(NestedOperands[i], NestedLoop)) { - AllInvariant = false; - break; - } + AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { + return isLoopInvariant(Op, NestedLoop); + }); + if (AllInvariant) { // Ok, both add recurrences are valid after the transformation. // @@ -2986,9 +3012,8 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! 
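// Aside: the constant-fold performed below, sketched with int64_t in
// place of APInt (illustrative only). Complexity sorting has already
// moved constants to the front, so adjacent constants collapse pairwise
// into a single smax-ed constant:
#include <algorithm>
#include <cstdint>
#include <vector>

static void foldLeadingConstants(std::vector<int64_t> &Ops) {
  // In this toy every operand is a constant, so the list folds to one.
  while (Ops.size() >= 2) {
    Ops[0] = std::max(Ops[0], Ops[1]);   // smax(C1, C2) is a constant
    Ops.erase(Ops.begin() + 1);          // erase the folded element
  }
}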
- ConstantInt *Fold = ConstantInt::get(getContext(), - APIntOps::smax(LHSC->getValue()->getValue(), - RHSC->getValue()->getValue())); + ConstantInt *Fold = ConstantInt::get( + getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt())); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; @@ -3090,9 +3115,8 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(getContext(), - APIntOps::umax(LHSC->getValue()->getValue(), - RHSC->getValue()->getValue())); + ConstantInt *Fold = ConstantInt::get( + getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt())); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; @@ -3181,8 +3205,7 @@ const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - return getConstant(IntTy, - F.getParent()->getDataLayout().getTypeAllocSize(AllocTy)); + return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, @@ -3192,9 +3215,7 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. return getConstant( - IntTy, - F.getParent()->getDataLayout().getStructLayout(STy)->getElementOffset( - FieldNo)); + IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { @@ -3236,7 +3257,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const { /// for which isSCEVable must return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - return F.getParent()->getDataLayout().getTypeSizeInBits(Ty); + return getDataLayout().getTypeSizeInBits(Ty); } /// getEffectiveSCEVType - Return a type with the same bitwidth as @@ -3246,20 +3267,20 @@ uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - if (Ty->isIntegerTy()) { + if (Ty->isIntegerTy()) return Ty; - } // The only other support type is pointer. assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); - return F.getParent()->getDataLayout().getIntPtrType(Ty); + return getDataLayout().getIntPtrType(Ty); } const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } -namespace { + +bool ScalarEvolution::checkValidity(const SCEV *S) const { // Helper class working with SCEVTraversal to figure out if a SCEV contains // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne // is set iff if find such SCEVUnknown. 
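// Aside: the follow()/isDone() protocol that SCEVTraversal drives,
// restated over a hypothetical binary Node type (not LLVM API): follow()
// inspects each node and isDone() lets the walk stop early the moment an
// invalid node has been seen.
struct Node { const Node *L = nullptr; const Node *R = nullptr; bool Valid = true; };

struct FindInvalid {
  bool FindOne = false;
  bool follow(const Node *N) {
    if (!N->Valid)
      FindOne = true;
    return !FindOne;            // don't descend further once found
  }
  bool isDone() const { return FindOne; }
};

template <typename Visitor> static void visitAll(const Node *N, Visitor &V) {
  if (!N || V.isDone() || !V.follow(N))
    return;
  visitAll(N->L, V);
  visitAll(N->R, V);
}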
@@ -3281,9 +3302,7 @@ namespace { } bool isDone() const { return FindOne; } }; -} -bool ScalarEvolution::checkValidity(const SCEV *S) const { FindInvalidSCEVUnknown F; SCEVTraversal ST(F); ST.visitAll(S); @@ -3523,16 +3542,14 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { if (const SCEVCastExpr *Cast = dyn_cast(V)) { return getPointerBase(Cast->getOperand()); - } - else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { + } else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { const SCEV *PtrOp = nullptr; - for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); - I != E; ++I) { - if ((*I)->getType()->isPointerTy()) { + for (const SCEV *NAryOp : NAry->operands()) { + if (NAryOp->getType()->isPointerTy()) { // Cannot find the base of an expression with multiple pointer operands. if (PtrOp) return V; - PtrOp = *I; + PtrOp = NAryOp; } } if (!PtrOp) @@ -3568,8 +3585,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { if (!Visited.insert(I).second) continue; - ValueExprMapType::iterator It = - ValueExprMap.find_as(static_cast(I)); + auto It = ValueExprMap.find_as(static_cast(I)); if (It != ValueExprMap.end()) { const SCEV *Old = It->second; @@ -3597,6 +3613,73 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { } } +namespace { +class SCEVInitRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVInitRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + // Only allow AddRecExprs for this loop. + if (Expr->getLoop() == L) + return Expr->getStart(); + Valid = false; + return Expr; + } + + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; + +class SCEVShiftRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVShiftRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // Only allow AddRecExprs for this loop. + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + if (Expr->getLoop() == L && Expr->isAffine()) + return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); + Valid = false; + return Expr; + } + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; +} // end anonymous namespace + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -3709,31 +3792,28 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { return PHISCEV; } } - } else if (const SCEVAddRecExpr *AddRec = - dyn_cast(BEValue)) { + } else { // Otherwise, this could be a loop like this: // i = 0; for (j = 1; ..; ++j) { .... 
i = j; } // In this case, j = {1,+,1} and BEValue is j. // Because the other in-value of i (0) fits the evolution of BEValue // i really is an addrec evolution. - if (AddRec->getLoop() == L && AddRec->isAffine()) { + // + // We can generalize this saying that i is the shifted value of BEValue + // by one iteration: + // PHI(f(0), f({1,+,1})) --> f({0,+,1}) + const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); + if (Shifted != getCouldNotCompute() && + Start != getCouldNotCompute()) { const SCEV *StartVal = getSCEV(StartValueV); - - // If StartVal = j.start - j.stride, we can use StartVal as the - // initial step of the addrec evolution. - if (StartVal == - getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) { - // FIXME: For constant StartVal, we should be able to infer - // no-wrap flags. - const SCEV *PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1), - L, SCEV::FlagAnyWrap); - + if (Start == StartVal) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. ForgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; + ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; + return Shifted; } } } @@ -3767,8 +3847,8 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, switch (S->getSCEVType()) { case scConstant: case scTruncate: case scZeroExtend: case scSignExtend: case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr: - // These expressions are available if their operand(s) is/are. - return true; + // These expressions are available if their operand(s) is/are. + return true; case scAddRecExpr: { // We allow add recurrences that are on the loop BB is in, or some @@ -3799,8 +3879,8 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, case scUDivExpr: case scCouldNotCompute: - // We do not try to smart about these at all. - return setUnavailable(); + // We do not try to smart about these at all. + return setUnavailable(); } llvm_unreachable("switch should be fully covered!"); } @@ -3852,6 +3932,11 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { if (PN->getNumIncomingValues() == 2) { const Loop *L = LI.getLoopFor(PN->getParent()); + // We don't want to break LCSSA, even in a SCEV expression tree. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) + return nullptr; + // Try to match // // br %cond, label %left, label %right @@ -3891,8 +3976,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. 
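// Aside: a numeric illustration of the generalization adopted earlier in
// this hunk, PHI(f(0), f({1,+,1})) --> f({0,+,1}). Simulate
// i = phi [ f(0), f(j) ] against j = {1,+,1} and check that, on entry to
// each iteration, i equals f applied to the shifted recurrence {0,+,1}
// (plain C++ with a sample f, illustrative only).
#include <cassert>

static void checkShiftedRecurrence() {
  auto F = [](int X) { return 3 * X + 7; };  // a sample expression in j
  int I = F(0);                              // PHI's start value
  for (int J = 1; J <= 10; ++J) {            // J follows {1,+,1}
    assert(I == F(J - 1));                   // i == f({0,+,1})
    I = F(J);                                // backedge value f(j)
  }
}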
- if (Value *V = SimplifyInstruction(PN, F.getParent()->getDataLayout(), &TLI, - &DT, &AC)) + if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) if (LI.replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -4022,7 +4106,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVConstant *C = dyn_cast(S)) - return C->getValue()->getValue().countTrailingZeros(); + return C->getAPInt().countTrailingZeros(); if (const SCEVTruncateExpr *T = dyn_cast(S)) return std::min(GetMinTrailingZeros(T->getOperand()), @@ -4087,8 +4171,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, F.getParent()->getDataLayout(), - 0, &AC, nullptr, &DT); + computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, + nullptr, &DT); return Zeros.countTrailingOnes(); } @@ -4099,26 +4183,9 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { /// GetRangeFromMetadata - Helper method to assign a range to V from /// metadata present in the IR. static Optional GetRangeFromMetadata(Value *V) { - if (Instruction *I = dyn_cast(V)) { - if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) { - ConstantRange TotalRange( - cast(I->getType())->getBitWidth(), false); - - unsigned NumRanges = MD->getNumOperands() / 2; - assert(NumRanges >= 1); - - for (unsigned i = 0; i < NumRanges; ++i) { - ConstantInt *Lower = - mdconst::extract(MD->getOperand(2 * i + 0)); - ConstantInt *Upper = - mdconst::extract(MD->getOperand(2 * i + 1)); - ConstantRange Range(Lower->getValue(), Upper->getValue()); - TotalRange = TotalRange.unionWith(Range); - } - - return TotalRange; - } - } + if (Instruction *I = dyn_cast(V)) + if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) + return getConstantRangeFromMetadata(*MD); return None; } @@ -4140,7 +4207,7 @@ ScalarEvolution::getRange(const SCEV *S, return I->second; if (const SCEVConstant *C = dyn_cast(S)) - return setRange(C, SignHint, ConstantRange(C->getValue()->getValue())); + return setRange(C, SignHint, ConstantRange(C->getAPInt())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); @@ -4218,9 +4285,8 @@ ScalarEvolution::getRange(const SCEV *S, if (AddRec->getNoWrapFlags(SCEV::FlagNUW)) if (const SCEVConstant *C = dyn_cast(AddRec->getStart())) if (!C->getValue()->isZero()) - ConservativeResult = - ConservativeResult.intersectWith( - ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0))); + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(C->getAPInt(), APInt(BitWidth, 0))); // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. @@ -4318,7 +4384,7 @@ ScalarEvolution::getRange(const SCEV *S, // Split here to avoid paying the compile-time cost of calling both // computeKnownBits and ComputeNumSignBits. This restriction can be lifted // if needed. - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { // For a SCEVUnknown, ask ValueTracking. 
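// Aside: how known bits turn into a conservative unsigned range. Every
// known-one bit must be set and every known-zero bit must be clear, so
// the value lies in [Ones, ~Zeros]. Sketch with uint64_t standing in for
// APInt (illustrative only):
#include <cstdint>
#include <utility>

static std::pair<uint64_t, uint64_t> rangeFromKnownBits(uint64_t Zeros,
                                                        uint64_t Ones) {
  uint64_t Min = Ones;     // known-one bits set, all unknown bits clear
  uint64_t Max = ~Zeros;   // known-zero bits clear, all unknown bits set
  return {Min, Max};       // inclusive unsigned bounds
}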
APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); @@ -4526,8 +4592,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, - F.getParent()->getDataLayout(), 0, &AC, nullptr, &DT); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(), + 0, &AC, nullptr, &DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -5302,6 +5368,14 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } + // There are cases (e.g. PR26207) where computeExitLimitFromCond is able + // to be more aggressive when computing BECount than when computing + // MaxBECount. In these cases it is possible for EL0.Exact and EL1.Exact + // to match, but for EL0.Max and EL1.Max to not. + if (isa(MaxBECount) && + !isa(BECount)) + MaxBECount = BECount; + return ExitLimit(BECount, MaxBECount); } if (BO->getOpcode() == Instruction::Or) { @@ -5386,6 +5460,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return ItCnt; } + ExitLimit ShiftEL = computeShiftCompareExitLimit( + ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); + if (ShiftEL.hasAnyInfo()) + return ShiftEL; + const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); @@ -5411,7 +5490,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (AddRec->getLoop() == L) { // Form the constant range. ConstantRange CompRange( - ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue())); + ICmpInst::makeConstantRange(Cond, RHSC->getAPInt())); const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; @@ -5445,14 +5524,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, break; } default: -#if 0 - dbgs() << "computeBackedgeTakenCount "; - if (ExitCond->getOperand(0)->getType()->isUnsigned()) - dbgs() << "[unsigned] "; - dbgs() << *LHS << " " - << Instruction::getOpcodeName(Instruction::ICmp) - << " " << *RHS << "\n"; -#endif break; } return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); @@ -5565,11 +5636,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit( Result = ConstantExpr::getICmp(predicate, Result, RHS); if (!isa(Result)) break; // Couldn't decide for sure if (cast(Result)->getValue().isMinValue()) { -#if 0 - dbgs() << "\n***\n*** Computed loop count " << *ItCst - << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() - << "***\n"; -#endif ++NumArrayLenItCounts; return getConstant(ItCst); // Found terminating iteration! } @@ -5577,6 +5643,149 @@ ScalarEvolution::computeLoadConstantCompareExitLimit( return getCouldNotCompute(); } +ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( + Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) { + ConstantInt *RHS = dyn_cast(RHSV); + if (!RHS) + return getCouldNotCompute(); + + const BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return getCouldNotCompute(); + + const BasicBlock *Predecessor = L->getLoopPredecessor(); + if (!Predecessor) + return getCouldNotCompute(); + + // Return true if V is of the form "LHS `shift_op` ". + // Return LHS in OutLHS and shift_opt in OutOpCode. 
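// Aside: the shape of the matcher defined below, restated over a tiny
// expression struct instead of llvm::PatternMatch (hypothetical types,
// illustrative only): accept lshr/ashr/shl by a known, strictly positive
// constant, and hand back the shifted operand plus the opcode.
#include <cstdint>

enum class ShiftOp { None, LShr, AShr, Shl };
struct ShiftExpr { ShiftOp Op; const ShiftExpr *LHS; int64_t Amt; };

static bool matchPositiveShift(const ShiftExpr &E, const ShiftExpr *&OutLHS,
                               ShiftOp &OutOpCode) {
  if (E.Op == ShiftOp::None)
    return false;
  OutLHS = E.LHS;
  OutOpCode = E.Op;
  return E.Amt > 0;   // only strictly positive shift amounts qualify
}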
+ auto MatchPositiveShift = + [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) { + + using namespace PatternMatch; + + ConstantInt *ShiftAmt; + if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::LShr; + else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::AShr; + else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::Shl; + else + return false; + + return ShiftAmt->getValue().isStrictlyPositive(); + }; + + // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in + // + // loop: + // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ] + // %iv.shifted = lshr i32 %iv, + // + // Return true on a succesful match. Return the corresponding PHI node (%iv + // above) in PNOut and the opcode of the shift operation in OpCodeOut. + auto MatchShiftRecurrence = + [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { + Optional PostShiftOpCode; + + { + Instruction::BinaryOps OpC; + Value *V; + + // If we encounter a shift instruction, "peel off" the shift operation, + // and remember that we did so. Later when we inspect %iv's backedge + // value, we will make sure that the backedge value uses the same + // operation. + // + // Note: the peeled shift operation does not have to be the same + // instruction as the one feeding into the PHI's backedge value. We only + // really care about it being the same *kind* of shift instruction -- + // that's all that is required for our later inferences to hold. + if (MatchPositiveShift(LHS, V, OpC)) { + PostShiftOpCode = OpC; + LHS = V; + } + } + + PNOut = dyn_cast(LHS); + if (!PNOut || PNOut->getParent() != L->getHeader()) + return false; + + Value *BEValue = PNOut->getIncomingValueForBlock(Latch); + Value *OpLHS; + + return + // The backedge value for the PHI node must be a shift by a positive + // amount + MatchPositiveShift(BEValue, OpLHS, OpCodeOut) && + + // of the PHI node itself + OpLHS == PNOut && + + // and the kind of shift should be match the kind of shift we peeled + // off, if any. + (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); + }; + + PHINode *PN; + Instruction::BinaryOps OpCode; + if (!MatchShiftRecurrence(LHS, PN, OpCode)) + return getCouldNotCompute(); + + const DataLayout &DL = getDataLayout(); + + // The key rationale for this optimization is that for some kinds of shift + // recurrences, the value of the recurrence "stabilizes" to either 0 or -1 + // within a finite number of iterations. If the condition guarding the + // backedge (in the sense that the backedge is taken if the condition is true) + // is false for the value the shift recurrence stabilizes to, then we know + // that the backedge is taken only a finite number of times. + + ConstantInt *StableValue = nullptr; + switch (OpCode) { + default: + llvm_unreachable("Impossible case!"); + + case Instruction::AShr: { + // {K,ashr,} stabilizes to signum(K) in at most + // bitwidth(K) iterations. 
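// Aside: the stabilization fact relied on above, checked directly (plain
// C++; assumes arithmetic behavior of >> on signed operands, which C++20
// guarantees): {K,ashr,s} with s > 0 reaches 0 for K >= 0 and -1 for
// K < 0 within bitwidth(K) iterations, and then never changes.
#include <cassert>
#include <cstdint>

static void checkAShrStabilizes(int32_t K, unsigned Shift) {
  assert(Shift > 0 && Shift < 32);
  int32_t V = K;
  for (unsigned I = 0; I != 32; ++I)  // at most bitwidth(K) steps
    V >>= Shift;                      // one ashr recurrence step
  assert(V == (K < 0 ? -1 : 0));      // stabilized to signum-like value
}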
+ Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); + bool KnownZero, KnownOne; + ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, + Predecessor->getTerminator(), &DT); + auto *Ty = cast(RHS->getType()); + if (KnownZero) + StableValue = ConstantInt::get(Ty, 0); + else if (KnownOne) + StableValue = ConstantInt::get(Ty, -1, true); + else + return getCouldNotCompute(); + + break; + } + case Instruction::LShr: + case Instruction::Shl: + // Both {K,lshr,} and {K,shl,} + // stabilize to 0 in at most bitwidth(K) iterations. + StableValue = ConstantInt::get(cast(RHS->getType()), 0); + break; + } + + auto *Result = + ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); + assert(Result->getType()->isIntegerTy(1) && + "Otherwise cannot be an operand to a branch instruction"); + + if (Result->isZeroValue()) { + unsigned BitWidth = getTypeSizeInBits(RHS->getType()); + const SCEV *UpperBound = + getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); + return ExitLimit(getCouldNotCompute(), UpperBound); + } + + return getCouldNotCompute(); +} /// CanConstantFold - Return true if we can constant fold an instruction of the /// specified type, assuming that all operands were constants. @@ -5618,12 +5827,10 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. PHINode *PHI = nullptr; - for (Instruction::op_iterator OpI = UseInst->op_begin(), - OpE = UseInst->op_end(); OpI != OpE; ++OpI) { - - if (isa(*OpI)) continue; + for (Value *Op : UseInst->operands()) { + if (isa(Op)) continue; - Instruction *OpInst = dyn_cast(*OpI); + Instruction *OpInst = dyn_cast(Op); if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr; PHINode *P = dyn_cast(OpInst); @@ -5657,9 +5864,8 @@ static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { Instruction *I = dyn_cast(V); if (!I || !canConstantEvolve(I, L)) return nullptr; - if (PHINode *PN = dyn_cast(I)) { + if (PHINode *PN = dyn_cast(I)) return PN; - } // Record non-constant instructions contained by the loop. DenseMap PHIMap; @@ -5716,6 +5922,30 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, TLI); } + +// If every incoming value to PN except the one for BB is a specific Constant, +// return that, else return nullptr. +static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { + Constant *IncomingVal = nullptr; + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (PN->getIncomingBlock(i) == BB) + continue; + + auto *CurrentVal = dyn_cast(PN->getIncomingValue(i)); + if (!CurrentVal) + return nullptr; + + if (IncomingVal != CurrentVal) { + if (IncomingVal) + return nullptr; + IncomingVal = CurrentVal; + } + } + + return IncomingVal; +} + /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence @@ -5741,25 +5971,10 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, if (!Latch) return nullptr; - // Since the loop has one latch, the PHI node must have two entries. One - // entry must be a constant (coming in from outside of the loop), and the - // second must be derived from the same PHI. - - BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0) - ? 
PN->getIncomingBlock(1) - : PN->getIncomingBlock(0); - - assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!"); - - // Note: not all PHI nodes in the same block have to have their incoming - // values in the same order, so we use the basic block to look up the incoming - // value, not an index. - for (auto &I : *Header) { PHINode *PHI = dyn_cast(&I); if (!PHI) break; - auto *StartCST = - dyn_cast(PHI->getIncomingValueForBlock(NonLatch)); + auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } @@ -5774,7 +5989,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); for (; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = CurrentIterVals[PN]; // Got exit value! @@ -5838,21 +6053,11 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Should follow from NumIncomingValues == 2!"); - // NonLatch is the preheader, or something equivalent. - BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0) - ? PN->getIncomingBlock(1) - : PN->getIncomingBlock(0); - - // Note: not all PHI nodes in the same block have to have their incoming - // values in the same order, so we use the basic block to look up the incoming - // value, not an index. - for (auto &I : *Header) { PHINode *PHI = dyn_cast(&I); if (!PHI) break; - auto *StartCST = - dyn_cast(PHI->getIncomingValueForBlock(NonLatch)); + auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } @@ -5863,7 +6068,7 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, // the loop symbolically to determine when the condition gets a value of // "ExitWhen". unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ auto *CondVal = dyn_cast_or_null( EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI)); @@ -5913,22 +6118,22 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { + SmallVector, 2> &Values = + ValuesAtScopes[V]; // Check to see if we've folded this expression at this loop before. - SmallVector, 2> &Values = ValuesAtScopes[V]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == L) - return Values[u].second ? Values[u].second : V; - } - Values.push_back(std::make_pair(L, static_cast(nullptr))); + for (auto &LS : Values) + if (LS.first == L) + return LS.second ? LS.second : V; + + Values.emplace_back(L, nullptr); + // Otherwise compute it. 
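// Aside: the caching pattern used here, sketched standalone (hypothetical
// Key type and compute function, illustrative only): record a placeholder
// before computing, then re-scan the cache to patch in the result. The
// re-scan matters because the recursive computation may itself append
// entries and reallocate the vector, invalidating any reference held
// across the call.
#include <utility>
#include <vector>

struct Key {
  int Id;
  bool operator==(const Key &K) const { return Id == K.Id; }
};

static int computeExpensive(Key K) { return 2 * K.Id; } // stand-in

static int getMemoized(Key K, std::vector<std::pair<Key, int>> &Cache) {
  for (auto &KV : Cache)
    if (KV.first == K)
      return KV.second;          // hit (placeholder or final value)
  Cache.emplace_back(K, 0);      // placeholder before recursing
  int Result = computeExpensive(K);
  for (auto &KV : Cache)
    if (KV.first == K) {
      KV.second = Result;        // patch the placeholder
      break;
    }
  return Result;
}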
const SCEV *C = computeSCEVAtScope(V, L); - SmallVector, 2> &Values2 = ValuesAtScopes[V]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == L) { - Values2[u - 1].second = C; + for (auto &LS : reverse(ValuesAtScopes[V])) + if (LS.first == L) { + LS.second = C; break; } - } return C; } @@ -6052,9 +6257,8 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Okay, we know how many times the containing loop executes. If // this is a constant evolving PHI node, get the final value at // the specified iteration number. - Constant *RV = getConstantEvolutionLoopExitValue(PN, - BTCC->getValue()->getValue(), - LI); + Constant *RV = + getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI); if (RV) return getSCEV(RV); } } @@ -6066,8 +6270,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (CanConstantFold(I)) { SmallVector Operands; bool MadeImprovement = false; - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - Value *Op = I->getOperand(i); + for (Value *Op : I->operands()) { if (Constant *C = dyn_cast(Op)) { Operands.push_back(C); continue; @@ -6096,7 +6299,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Check to see if getSCEVAtScope actually made an improvement. if (MadeImprovement) { Constant *C = nullptr; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); if (const CmpInst *CI = dyn_cast(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], Operands[1], DL, &TLI); @@ -6296,10 +6499,10 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { return std::make_pair(CNC, CNC); } - uint32_t BitWidth = LC->getValue()->getValue().getBitWidth(); - const APInt &L = LC->getValue()->getValue(); - const APInt &M = MC->getValue()->getValue(); - const APInt &N = NC->getValue()->getValue(); + uint32_t BitWidth = LC->getAPInt().getBitWidth(); + const APInt &L = LC->getAPInt(); + const APInt &M = MC->getAPInt(); + const APInt &N = NC->getAPInt(); APInt Two(BitWidth, 2); APInt Four(BitWidth, 4); @@ -6378,10 +6581,6 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1 && R2) { -#if 0 - dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1 - << " sol#2: " << *R2 << "\n"; -#endif // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp(CmpInst::ICMP_ULT, @@ -6435,7 +6634,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // For negative steps (counting down to zero): // N = Start/-Step // First compute the unsigned distance from zero in the direction of Step. - bool CountDown = StepC->getValue()->getValue().isNegative(); + bool CountDown = StepC->getAPInt().isNegative(); const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start); // Handle unitary steps, which cannot wraparound. @@ -6460,7 +6659,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // done by counting and comparing the number of trailing zeros of Step and // Distance. if (!CountDown) { - const APInt &StepV = StepC->getValue()->getValue(); + const APInt &StepV = StepC->getAPInt(); // StepV.isPowerOf2() returns true if StepV is an positive power of two. 
It // also returns true if StepV is maximally negative (eg, INT_MIN), but that // case is not handled as this code is guarded by !CountDown. @@ -6522,8 +6721,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast(Start)) - return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), - -StartC->getValue()->getValue(), + return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(), *this); return getCouldNotCompute(); } @@ -6646,7 +6844,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // If there's a constant operand, canonicalize comparisons with boundary // cases, and canonicalize *-or-equal comparisons to regular comparisons. if (const SCEVConstant *RC = dyn_cast(RHS)) { - const APInt &RA = RC->getValue()->getValue(); + const APInt &RA = RC->getAPInt(); switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -6837,16 +7035,14 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, Pred = ICmpInst::ICMP_ULT; Changed = true; } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { - LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, - SCEV::FlagNUW); + LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); Pred = ICmpInst::ICMP_ULT; Changed = true; } break; case ICmpInst::ICMP_UGE: if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { - RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, - SCEV::FlagNUW); + RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); Pred = ICmpInst::ICMP_UGT; Changed = true; } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { @@ -7142,6 +7338,62 @@ ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred, return false; } +bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS) { + + // Match Result to (X + Y) where Y is a constant integer. + // Return Y via OutY. 
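// ===== Editor's aside: illustrative sketch, not part of this patch. =====
// The helper defined next matches Result against (X + C) with NSW; the
// switch after it then reasons purely from the sign of C. A self-contained
// check of those two signed facts, with plain 64-bit ints standing in for
// SCEVs (toy values chosen so the additions cannot overflow, which is
// exactly what the NSW flag guarantees in the real code):

#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  const int64_t X = 100;
  for (int64_t C : {-3LL, 0LL, 7LL}) {
    assert((C >= 0) == (X <= X + C));   // X s<= (X + C)  iff  C >= 0
    assert((C > 0)  == (X <  X + C));   // X s<  (X + C)  iff  C >  0
  }
  return 0;
}
// ===== End of editor's aside. =====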
+ auto MatchBinaryAddToConst = + [this](const SCEV *Result, const SCEV *X, APInt &OutY, + SCEV::NoWrapFlags ExpectedFlags) { + const SCEV *NonConstOp, *ConstOp; + SCEV::NoWrapFlags FlagsPresent; + + if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) || + !isa(ConstOp) || NonConstOp != X) + return false; + + OutY = cast(ConstOp)->getAPInt(); + return (FlagsPresent & ExpectedFlags) == ExpectedFlags; + }; + + APInt C; + + switch (Pred) { + default: + break; + + case ICmpInst::ICMP_SGE: + std::swap(LHS, RHS); + case ICmpInst::ICMP_SLE: + // X s<= (X + C) if C >= 0 + if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative()) + return true; + + // (X + C) s<= X if C <= 0 + if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && + !C.isStrictlyPositive()) + return true; + break; + + case ICmpInst::ICMP_SGT: + std::swap(LHS, RHS); + case ICmpInst::ICMP_SLT: + // X s< (X + C) if C > 0 + if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && + C.isStrictlyPositive()) + return true; + + // (X + C) s< X if C < 0 + if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative()) + return true; + break; + } + + return false; +} + bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { @@ -7159,12 +7411,9 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the // interesting cases seen in practice. We can consider "upgrading" L >= 0 to // use isKnownPredicate later if needed. - if (isKnownNonNegative(RHS) && - isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && - isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS)) - return true; - - return false; + return isKnownNonNegative(RHS) && + isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && + isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS); } /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is @@ -7317,6 +7566,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return false; } +namespace { /// RAII wrapper to prevent recursive application of isImpliedCond. /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are /// currently evaluating isImpliedCond. @@ -7334,6 +7584,7 @@ struct MarkPendingLoopPredicate { LoopPreds.erase(Cond); } }; +} // end anonymous namespace /// isImpliedCond - Test whether the condition described by Pred, LHS, /// and RHS is true whenever the given Cond value evaluates to true. @@ -7435,6 +7686,13 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, RHS, LHS, FoundLHS, FoundRHS); } + // Unsigned comparison is the same as signed comparison when both the operands + // are non-negative. + if (CmpInst::isUnsigned(FoundPred) && + CmpInst::getSignedPredicate(FoundPred) == Pred && + isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + // Check if we can make progress by sharpening ranges. if (FoundPred == ICmpInst::ICMP_NE && (isa(FoundLHS) || isa(FoundRHS))) { @@ -7458,7 +7716,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, APInt Min = ICmpInst::isSigned(Pred) ? getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); - if (Min == C->getValue()->getValue()) { + if (Min == C->getAPInt()) { // Given (V >= Min && V != Min) we conclude V >= (Min + 1). 
// This is true even if (Min + 1) wraps around -- in case of // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). @@ -7509,21 +7767,24 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, return false; } -// Return true if More == (Less + C), where C is a constant. -static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More, - APInt &C) { - // We avoid subtracting expressions here because this function is usually - // fairly deep in the call stack (i.e. is called many times). +bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, + const SCEV *&L, const SCEV *&R, + SCEV::NoWrapFlags &Flags) { + const auto *AE = dyn_cast(Expr); + if (!AE || AE->getNumOperands() != 2) + return false; - auto SplitBinaryAdd = [](const SCEV *Expr, const SCEV *&L, const SCEV *&R) { - const auto *AE = dyn_cast(Expr); - if (!AE || AE->getNumOperands() != 2) - return false; + L = AE->getOperand(0); + R = AE->getOperand(1); + Flags = AE->getNoWrapFlags(); + return true; +} - L = AE->getOperand(0); - R = AE->getOperand(1); - return true; - }; +bool ScalarEvolution::computeConstantDifference(const SCEV *Less, + const SCEV *More, + APInt &C) { + // We avoid subtracting expressions here because this function is usually + // fairly deep in the call stack (i.e. is called many times). if (isa(Less) && isa(More)) { const auto *LAR = cast(Less); @@ -7537,7 +7798,7 @@ static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More, if (!LAR->isAffine() || !MAR->isAffine()) return false; - if (LAR->getStepRecurrence(SE) != MAR->getStepRecurrence(SE)) + if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) return false; Less = LAR->getStart(); @@ -7547,24 +7808,25 @@ static bool IsConstDiff(ScalarEvolution &SE, const SCEV *Less, const SCEV *More, } if (isa(Less) && isa(More)) { - const auto &M = cast(More)->getValue()->getValue(); - const auto &L = cast(Less)->getValue()->getValue(); + const auto &M = cast(More)->getAPInt(); + const auto &L = cast(Less)->getAPInt(); C = M - L; return true; } const SCEV *L, *R; - if (SplitBinaryAdd(Less, L, R)) + SCEV::NoWrapFlags Flags; + if (splitBinaryAdd(Less, L, R, Flags)) if (const auto *LC = dyn_cast(L)) if (R == More) { - C = -(LC->getValue()->getValue()); + C = -(LC->getAPInt()); return true; } - if (SplitBinaryAdd(More, L, R)) + if (splitBinaryAdd(More, L, R, Flags)) if (const auto *LC = dyn_cast(L)) if (R == Less) { - C = LC->getValue()->getValue(); + C = LC->getAPInt(); return true; } @@ -7626,8 +7888,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( // C)". 
APInt LDiff, RDiff; - if (!IsConstDiff(*this, FoundLHS, LHS, LDiff) || - !IsConstDiff(*this, FoundRHS, RHS, RDiff) || + if (!computeConstantDifference(FoundLHS, LHS, LDiff) || + !computeConstantDifference(FoundRHS, RHS, RDiff) || LDiff != RDiff) return false; @@ -7673,17 +7935,13 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { const SCEVAddExpr *Add = dyn_cast(Expr); - if (!Add || Add->getNumOperands() != 2) return nullptr; - - const SCEVConstant *AddLHS = dyn_cast(Add->getOperand(0)); - if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue())) + if (!Add || Add->getNumOperands() != 2 || + !Add->getOperand(0)->isAllOnesValue()) return nullptr; const SCEVMulExpr *AddRHS = dyn_cast(Add->getOperand(1)); - if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr; - - const SCEVConstant *MulLHS = dyn_cast(AddRHS->getOperand(0)); - if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue())) + if (!AddRHS || AddRHS->getNumOperands() != 2 || + !AddRHS->getOperand(0)->isAllOnesValue()) return nullptr; return AddRHS->getOperand(1); @@ -7697,8 +7955,7 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, const MaxExprType *MaxExpr = dyn_cast(MaybeMaxExpr); if (!MaxExpr) return false; - auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate); - return It != MaxExpr->op_end(); + return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end(); } @@ -7791,8 +8048,9 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, auto IsKnownPredicateFull = [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { return isKnownPredicateWithRanges(Pred, LHS, RHS) || - IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || - IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS); + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || + IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || + isKnownPredicateViaNoOverflow(Pred, LHS, RHS); }; switch (Pred) { @@ -7848,7 +8106,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, !isa(AddLHS->getOperand(0))) return false; - APInt ConstFoundRHS = cast(FoundRHS)->getValue()->getValue(); + APInt ConstFoundRHS = cast(FoundRHS)->getAPInt(); // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the // antecedent "`FoundLHS` `Pred` `FoundRHS`". 
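// ===== Editor's aside: illustrative sketch, not part of this patch. =====
// isImpliedCondOperandsViaRanges (continued below) turns the antecedent
// into a range for FoundLHS, shifts it by the constant addend to get a
// range for LHS, and checks containment in the region satisfying the
// consequent. The same argument with toy half-open intervals over int64_t:

#include <cassert>
#include <cstdint>

struct Interval { int64_t Lo, Hi; };                 // values in [Lo, Hi)

static Interval shift(Interval R, int64_t C) { return {R.Lo + C, R.Hi + C}; }
static bool contains(Interval Outer, Interval Inner) {
  return Outer.Lo <= Inner.Lo && Inner.Hi <= Outer.Hi;
}

int main() {
  Interval FoundLHSRange{0, 8};                // antecedent: 0 <= FoundLHS < 8
  Interval LHSRange = shift(FoundLHSRange, 5); // LHS = FoundLHS + 5, in [5, 13)
  Interval Satisfying{INT64_MIN, 20};          // consequent: LHS < 20
  assert(contains(Satisfying, LHSRange));      // implication holds
}
// ===== End of editor's aside. =====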
@@ -7857,13 +8115,12 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range // for `LHS`: - APInt Addend = - cast(AddLHS->getOperand(0))->getValue()->getValue(); + APInt Addend = cast(AddLHS->getOperand(0))->getAPInt(); ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": - APInt ConstRHS = cast(RHS)->getValue()->getValue(); + APInt ConstRHS = cast(RHS)->getAPInt(); ConstantRange SatisfyingLHSRange = ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); @@ -7987,7 +8244,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // overflow, in which case if RHS - Start is a constant, we don't need to // do a max operation since we can just figure it out statically if (NoWrap && isa(Diff)) { - APInt D = dyn_cast(Diff)->getValue()->getValue(); + APInt D = dyn_cast(Diff)->getAPInt(); if (D.isNegative()) End = Start; } else @@ -8068,7 +8325,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // overflow, in which case if RHS - Start is a constant, we don't need to // do a max operation since we can just figure it out statically if (NoWrap && isa(Diff)) { - APInt D = dyn_cast(Diff)->getValue()->getValue(); + APInt D = dyn_cast(Diff)->getAPInt(); if (!D.isNegative()) End = Start; } else @@ -8126,20 +8383,17 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, Operands[0] = SE.getZero(SC->getType()); const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW)); - if (const SCEVAddRecExpr *ShiftedAddRec = - dyn_cast(Shifted)) + if (const auto *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( - Range.subtract(SC->getValue()->getValue()), SE); + Range.subtract(SC->getAPInt()), SE); // This is strange and shouldn't happen. return SE.getCouldNotCompute(); } // The only time we can solve this is when we have all constant indices. // Otherwise, we cannot determine the overflow conditions. - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) - if (!isa(getOperand(i))) - return SE.getCouldNotCompute(); - + if (any_of(operands(), [](const SCEV *Op) { return !isa(Op); })) + return SE.getCouldNotCompute(); // Okay at this point we know that all elements of the chrec are constants and // that the start element is zero. @@ -8159,7 +8413,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // If A is negative then the lower of the range is the last possible loop // value. Also note that we already checked for a full range. APInt One(BitWidth,1); - APInt A = cast(getOperand(1))->getValue()->getValue(); + APInt A = cast(getOperand(1))->getAPInt(); APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); // The exit value should be (End+A)/A. @@ -8191,15 +8445,13 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, FlagAnyWrap); // Next, solve the constructed addrec - std::pair Roots = - SolveQuadraticEquation(cast(NewAddRec), SE); + auto Roots = SolveQuadraticEquation(cast(NewAddRec), SE); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { // Pick the smallest positive root value. 
- if (ConstantInt *CB =
- dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
- R1->getValue(), R2->getValue()))) {
+ if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
+ ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
 if (!CB->getZExtValue())
 std::swap(R1, R2); // R1 is the minimum root now.
@@ -8212,7 +8464,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
 if (Range.contains(R1Val->getValue())) {
 // The next iteration must be out of the range...
 ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
+ ConstantInt::get(SE.getContext(), R1->getAPInt() + 1);
 R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
 if (!Range.contains(R1Val->getValue()))
@@ -8223,7 +8475,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
 // If R1 was not in the range, then it is a good return value. Make
 // sure that R1-1 WAS in the range though, just in case.
 ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
+ ConstantInt::get(SE.getContext(), R1->getAPInt() - 1);
 R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
 if (Range.contains(R1Val->getValue()))
 return R1;
@@ -8307,9 +8559,84 @@ struct SCEVCollectTerms {
 }
 bool isDone() const { return false; }
};
+
+// Check if a SCEV contains an AddRecExpr.
+struct SCEVHasAddRec {
+ bool &ContainsAddRec;
+
+ SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
+ ContainsAddRec = false;
+ }
+
+ bool follow(const SCEV *S) {
+ if (isa<SCEVAddRecExpr>(S)) {
+ ContainsAddRec = true;
+
+ // Stop recursion: once we collected a term, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+ bool isDone() const { return false; }
+};
+
+// Find factors that are multiplied with an expression that (possibly as a
+// subexpression) contains an AddRecExpr. In the expression:
+//
+// 8 * (100 + %p * %q * (%a + {0, +, 1}_loop))
+//
+// "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)"
+// that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size
+// parameters as they form a product with an induction variable.
+//
+// This collector expects all array size parameters to be in the same MulExpr.
+// It might be necessary to later add support for collecting parameters that
+// are spread over different nested MulExpr.
+struct SCEVCollectAddRecMultiplies {
+ SmallVectorImpl<const SCEV *> &Terms;
+ ScalarEvolution &SE;
+
+ SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T, ScalarEvolution &SE)
+ : Terms(T), SE(SE) {}
+
+ bool follow(const SCEV *S) {
+ if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) {
+ bool HasAddRec = false;
+ SmallVector<const SCEV *, 2> Operands;
+ for (auto Op : Mul->operands()) {
+ if (isa<SCEVUnknown>(Op)) {
+ Operands.push_back(Op);
+ } else {
+ bool ContainsAddRec;
+ SCEVHasAddRec HasAddRecVisitor(ContainsAddRec);
+ visitAll(Op, HasAddRecVisitor);
+ HasAddRec |= ContainsAddRec;
+ }
+ }
+ if (Operands.size() == 0)
+ return true;
+
+ if (!HasAddRec)
+ return false;
+
+ Terms.push_back(SE.getMulExpr(Operands));
+ // Stop recursion: once we collected a term, do not walk its operands.
+ return false;
+ }
+
+ // Keep looking.
+ return true;
+ }
+ bool isDone() const { return false; }
+};
}

-/// Find parametric terms in this SCEVAddRecExpr.
+/// Find parametric terms in this SCEVAddRecExpr. We first look for parameters
+/// in two places:
+/// 1) The strides of AddRec expressions.
+/// 2) Unknowns that are multiplied with AddRec expressions.
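// ===== Editor's aside: illustrative sketch, not part of this patch. =====
// The two collectors above in miniature: walk a toy expression tree and,
// for each product, keep the unknown factors when a sibling factor
// (transitively) contains a recurrence. Node and collectMultiplies are
// hypothetical stand-ins for the SCEV classes and traversal visitors.

#include <iostream>
#include <string>
#include <vector>

struct Node {
  enum Kind { Unknown, AddRec, Mul, Add } K;
  std::string Name;                 // only used for Unknown
  std::vector<const Node *> Ops;    // only used for Mul/Add
};

// Does any node below N represent a recurrence?
static bool containsAddRec(const Node *N) {
  if (N->K == Node::AddRec)
    return true;
  for (const Node *Op : N->Ops)
    if (containsAddRec(Op))
      return true;
  return false;
}

// Record the unknown factors of a product when some other factor contains a
// recurrence. (The real collector pushes one re-formed product via
// SE.getMulExpr(Operands); this sketch records the factors directly.)
static void collectMultiplies(const Node *N, std::vector<const Node *> &Terms) {
  if (N->K == Node::Mul) {
    std::vector<const Node *> Factors;
    bool HasAddRec = false;
    for (const Node *Op : N->Ops) {
      if (Op->K == Node::Unknown)
        Factors.push_back(Op);
      else
        HasAddRec |= containsAddRec(Op);
    }
    if (!Factors.empty()) {
      if (HasAddRec)
        Terms.insert(Terms.end(), Factors.begin(), Factors.end());
      return;                       // stop walking below a handled product
    }
    // No unknown factors: keep looking inside the product's operands.
  }
  for (const Node *Op : N->Ops)
    collectMultiplies(Op, Terms);
}

int main() {
  // %p * %q * {0,+,1}_loop : %p and %q form a product with the recurrence.
  Node P{Node::Unknown, "%p", {}}, Q{Node::Unknown, "%q", {}};
  Node Rec{Node::AddRec, "", {}};
  Node M{Node::Mul, "", {&P, &Q, &Rec}};
  std::vector<const Node *> Terms;
  collectMultiplies(&M, Terms);
  for (const Node *T : Terms)
    std::cout << T->Name << "\n";   // prints "%p" then "%q"
}
// ===== End of editor's aside. =====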
void ScalarEvolution::collectParametricTerms(const SCEV *Expr, SmallVectorImpl &Terms) { SmallVector Strides; @@ -8332,6 +8659,9 @@ void ScalarEvolution::collectParametricTerms(const SCEV *Expr, for (const SCEV *T : Terms) dbgs() << *T << "\n"; }); + + SCEVCollectAddRecMultiplies MulCollector(Terms, *this); + visitAll(Expr, MulCollector); } static bool findArrayDimensionsRec(ScalarEvolution &SE, @@ -8381,30 +8711,28 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, return true; } -namespace { -struct FindParameter { - bool FoundParameter; - FindParameter() : FoundParameter(false) {} - - bool follow(const SCEV *S) { - if (isa(S)) { - FoundParameter = true; - // Stop recursion: we found a parameter. - return false; - } - // Keep looking. - return true; - } - bool isDone() const { - // Stop recursion if we have found a parameter. - return FoundParameter; - } -}; -} - // Returns true when S contains at least a SCEVUnknown parameter. static inline bool containsParameters(const SCEV *S) { + struct FindParameter { + bool FoundParameter; + FindParameter() : FoundParameter(false) {} + + bool follow(const SCEV *S) { + if (isa(S)) { + FoundParameter = true; + // Stop recursion: we found a parameter. + return false; + } + // Keep looking. + return true; + } + bool isDone() const { + // Stop recursion if we have found a parameter. + return FoundParameter; + } + }; + FindParameter F; SCEVTraversal ST(F); ST.visitAll(S); @@ -8492,11 +8820,13 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl &Terms, ScalarEvolution &SE = *const_cast(this); - // Divide all terms by the element size. + // Try to divide all terms by the element size. If term is not divisible by + // element size, proceed with the original term. for (const SCEV *&Term : Terms) { const SCEV *Q, *R; SCEVDivision::divide(SE, Term, ElementSize, &Q, &R); - Term = Q; + if (!Q->isZero()) + Term = Q; } SmallVector NewTerms; @@ -8745,6 +9075,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) UnsignedRanges(std::move(Arg.UnsignedRanges)), SignedRanges(std::move(Arg.SignedRanges)), UniqueSCEVs(std::move(Arg.UniqueSCEVs)), + UniquePreds(std::move(Arg.UniquePreds)), SCEVAllocator(std::move(Arg.SCEVAllocator)), FirstUnknown(Arg.FirstUnknown) { Arg.FirstUnknown = nullptr; @@ -8764,11 +9095,8 @@ ScalarEvolution::~ScalarEvolution() { // Free any extra memory created for ExitNotTakenInfo in the unlikely event // that a loop had multiple computable exits. 
- for (DenseMap::iterator I = - BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); - I != E; ++I) { - I->second.clear(); - } + for (auto &BTCI : BackedgeTakenCounts) + BTCI.second.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); @@ -8826,11 +9154,11 @@ void ScalarEvolution::print(raw_ostream &OS) const { OS << "Classifying expressions for: "; F.printAsOperand(OS, /*PrintType=*/false); OS << "\n"; - for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) - if (isSCEVable(I->getType()) && !isa(*I)) { - OS << *I << '\n'; + for (Instruction &I : instructions(F)) + if (isSCEVable(I.getType()) && !isa(I)) { + OS << I << '\n'; OS << " --> "; - const SCEV *SV = SE.getSCEV(&*I); + const SCEV *SV = SE.getSCEV(&I); SV->print(OS); if (!isa(SV)) { OS << " U: "; @@ -8839,7 +9167,7 @@ void ScalarEvolution::print(raw_ostream &OS) const { SE.getSignedRange(SV).print(OS); } - const Loop *L = LI.getLoopFor((*I).getParent()); + const Loop *L = LI.getLoopFor(I.getParent()); const SCEV *AtUse = SE.getSCEVAtScope(SV, L); if (AtUse != SV) { @@ -8922,9 +9250,8 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { // This recurrence is variant w.r.t. L if any of its operands // are variant. - for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end(); - I != E; ++I) - if (!isLoopInvariant(*I, L)) + for (auto *Op : AR->operands()) + if (!isLoopInvariant(Op, L)) return LoopVariant; // Otherwise it's loop-invariant. @@ -8934,11 +9261,9 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { case scMulExpr: case scUMaxExpr: case scSMaxExpr: { - const SCEVNAryExpr *NAry = cast(S); bool HasVarying = false; - for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); - I != E; ++I) { - LoopDisposition D = getLoopDisposition(*I, L); + for (auto *Op : cast(S)->operands()) { + LoopDisposition D = getLoopDisposition(Op, L); if (D == LoopVariant) return LoopVariant; if (D == LoopComputable) @@ -8962,7 +9287,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { // invariant if they are not contained in the specified loop. // Instructions are never considered invariant in the function body // (null loop) because they are defined within the "loop". - if (Instruction *I = dyn_cast(cast(S)->getValue())) + if (auto *I = dyn_cast(cast(S)->getValue())) return (L && !L->contains(I)) ? LoopInvariant : LoopVariant; return LoopInvariant; case scCouldNotCompute: @@ -9023,9 +9348,8 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { case scSMaxExpr: { const SCEVNAryExpr *NAry = cast(S); bool Proper = true; - for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); - I != E; ++I) { - BlockDisposition D = getBlockDisposition(*I, BB); + for (const SCEV *NAryOp : NAry->operands()) { + BlockDisposition D = getBlockDisposition(NAryOp, BB); if (D == DoesNotDominateBlock) return DoesNotDominateBlock; if (D == DominatesBlock) @@ -9069,24 +9393,22 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { return getBlockDisposition(S, BB) == ProperlyDominatesBlock; } -namespace { -// Search for a SCEV expression node within an expression tree. -// Implements SCEVTraversal::Visitor. -struct SCEVSearch { - const SCEV *Node; - bool IsFound; +bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { + // Search for a SCEV expression node within an expression tree. 
+ // Implements SCEVTraversal::Visitor. + struct SCEVSearch { + const SCEV *Node; + bool IsFound; - SCEVSearch(const SCEV *N): Node(N), IsFound(false) {} + SCEVSearch(const SCEV *N): Node(N), IsFound(false) {} - bool follow(const SCEV *S) { - IsFound |= (S == Node); - return !IsFound; - } - bool isDone() const { return IsFound; } -}; -} + bool follow(const SCEV *S) { + IsFound |= (S == Node); + return !IsFound; + } + bool isDone() const { return IsFound; } + }; -bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { SCEVSearch Search(Op); visitAll(S, Search); return Search.IsFound; @@ -9125,23 +9447,22 @@ static void replaceSubString(std::string &Str, StringRef From, StringRef To) { /// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis. static void getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) { - for (Loop::reverse_iterator I = L->rbegin(), E = L->rend(); I != E; ++I) { - getLoopBackedgeTakenCounts(*I, Map, SE); // recurse. - - std::string &S = Map[L]; - if (S.empty()) { - raw_string_ostream OS(S); - SE.getBackedgeTakenCount(L)->print(OS); + std::string &S = Map[L]; + if (S.empty()) { + raw_string_ostream OS(S); + SE.getBackedgeTakenCount(L)->print(OS); - // false and 0 are semantically equivalent. This can happen in dead loops. - replaceSubString(OS.str(), "false", "0"); - // Remove wrap flags, their use in SCEV is highly fragile. - // FIXME: Remove this when SCEV gets smarter about them. - replaceSubString(OS.str(), "", ""); - replaceSubString(OS.str(), "", ""); - replaceSubString(OS.str(), "", ""); - } + // false and 0 are semantically equivalent. This can happen in dead loops. + replaceSubString(OS.str(), "false", "0"); + // Remove wrap flags, their use in SCEV is highly fragile. + // FIXME: Remove this when SCEV gets smarter about them. + replaceSubString(OS.str(), "", ""); + replaceSubString(OS.str(), "", ""); + replaceSubString(OS.str(), "", ""); } + + for (auto *R : reverse(*L)) + getLoopBackedgeTakenCounts(R, Map, SE); // recurse. 
} void ScalarEvolution::verify() const { @@ -9251,3 +9572,178 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequiredTransitive(); AU.addRequiredTransitive(); } + +const SCEVPredicate * +ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, + const SCEVConstant *RHS) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Equal); + ID.AddPointer(LHS); + ID.AddPointer(RHS); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + SCEVEqualPredicate *Eq = new (SCEVAllocator) + SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + UniquePreds.InsertNode(Eq, IP); + return Eq; +} + +namespace { +class SCEVPredicateRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, + SCEVUnionPredicate &A) { + SCEVPredicateRewriter Rewriter(SE, A); + return Rewriter.visit(Scev); + } + + SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) + : SCEVRewriteVisitor(SE), P(P) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + auto ExprPreds = P.getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast(Pred)) + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + + return Expr; + } + +private: + SCEVUnionPredicate &P; +}; +} // end anonymous namespace + +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev, + SCEVUnionPredicate &Preds) { + return SCEVPredicateRewriter::rewrite(Scev, *this, Preds); +} + +/// SCEV predicates +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, + SCEVPredicateKind Kind) + : FastID(ID), Kind(Kind) {} + +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, + const SCEVUnknown *LHS, + const SCEVConstant *RHS) + : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {} + +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast(N); + + if (!Op) + return false; + + return Op->LHS == LHS && Op->RHS == RHS; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } + +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; +} + +/// Union predicates don't get cached so create a dummy set ID for it. 
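// ===== Editor's aside: illustrative sketch, not part of this patch. =====
// getEqualPredicate above uniques predicates: identical (kind, LHS, RHS)
// triples must yield the same node so that pointer equality works. The
// same discipline with a std::map standing in for llvm::FoldingSet, and
// int keys standing in for the SCEV operands:

#include <cassert>
#include <map>
#include <tuple>

struct EqualPred { int LHS, RHS; };       // stands in for SCEVEqualPredicate

// (kind, LHS, RHS) plays the role of the FoldingSetNodeID built above.
using PredID = std::tuple<int, int, int>;
static std::map<PredID, EqualPred *> UniquePreds;

static EqualPred *getEqualPred(int LHS, int RHS) {
  PredID ID{/*P_Equal*/ 0, LHS, RHS};
  auto It = UniquePreds.find(ID);         // FindNodeOrInsertPos analogue
  if (It != UniquePreds.end())
    return It->second;                    // same arguments -> same node
  auto *Eq = new EqualPred{LHS, RHS};     // lives for the analysis' lifetime,
  UniquePreds.emplace(ID, Eq);            // as with SCEVAllocator
  return Eq;
}

int main() {
  assert(getEqualPred(1, 2) == getEqualPred(1, 2)); // uniqued: pointer-equal
  assert(getEqualPred(1, 2) != getEqualPred(1, 3));
}
// ===== End of editor's aside. =====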
+SCEVUnionPredicate::SCEVUnionPredicate() + : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} + +bool SCEVUnionPredicate::isAlwaysTrue() const { + return all_of(Preds, + [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); +} + +ArrayRef +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { + auto I = SCEVToPreds.find(Expr); + if (I == SCEVToPreds.end()) + return ArrayRef(); + return I->second; +} + +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { + if (const auto *Set = dyn_cast(N)) + return all_of(Set->Preds, + [this](const SCEVPredicate *I) { return this->implies(I); }); + + auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); + if (ScevPredsIt == SCEVToPreds.end()) + return false; + auto &SCEVPreds = ScevPredsIt->second; + + return any_of(SCEVPreds, + [N](const SCEVPredicate *I) { return I->implies(N); }); +} + +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } + +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { + for (auto Pred : Preds) + Pred->print(OS, Depth); +} + +void SCEVUnionPredicate::add(const SCEVPredicate *N) { + if (const auto *Set = dyn_cast(N)) { + for (auto Pred : Set->Preds) + add(Pred); + return; + } + + if (implies(N)) + return; + + const SCEV *Key = N->getExpr(); + assert(Key && "Only SCEVUnionPredicate doesn't have an " + " associated expression!"); + + SCEVToPreds[Key].push_back(N); + Preds.push_back(N); +} + +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE) + : SE(SE), Generation(0) {} + +const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { + const SCEV *Expr = SE.getSCEV(V); + RewriteEntry &Entry = RewriteMap[Expr]; + + // If we already have an entry and the version matches, return it. + if (Entry.second && Generation == Entry.first) + return Entry.second; + + // We found an entry but it's stale. Rewrite the stale entry + // acording to the current predicate. + if (Entry.second) + Expr = Entry.second; + + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds); + Entry = {Generation, NewSCEV}; + + return NewSCEV; +} + +void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { + if (Preds.implies(&Pred)) + return; + Preds.add(&Pred); + updateGeneration(); +} + +const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { + return Preds; +} + +void PredicatedScalarEvolution::updateGeneration() { + // If the generation number wrapped recompute everything. + if (++Generation == 0) { + for (auto &II : RewriteMap) { + const SCEV *Rewritten = II.second.second; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)}; + } + } +}
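// ===== Editor's aside: illustrative sketch, not part of this patch. =====
// PredicatedScalarEvolution's cache, reduced to ints: addPredicate bumps
// Generation, getSCEV reuses an entry only when its generation matches,
// and otherwise rewrites the stale value one step further. Names are toy
// stand-ins for the RewriteMap/Generation members above.

#include <cassert>
#include <map>
#include <utility>

struct PredicatedCache {
  unsigned Generation = 0;
  // Expr -> (generation it was rewritten at, rewritten value); 0 = "empty".
  std::map<int, std::pair<unsigned, int>> RewriteMap;

  int get(int Expr, int (*Rewrite)(int)) {
    auto &Entry = RewriteMap[Expr];
    if (Entry.second && Generation == Entry.first)
      return Entry.second;              // fresh entry: reuse it
    if (Entry.second)
      Expr = Entry.second;              // stale entry: rewrite it further
    int NewVal = Rewrite(Expr);
    Entry = {Generation, NewVal};
    return NewVal;
  }

  // The real code also rewrites every entry eagerly if ++Generation wraps.
  void addPredicate() { ++Generation; }
};

static int rewriteOnce(int E) { return E + 1; }   // assumed rewrite step

int main() {
  PredicatedCache C;
  assert(C.get(10, rewriteOnce) == 11);
  assert(C.get(10, rewriteOnce) == 11);  // cached: same generation
  C.addPredicate();
  assert(C.get(10, rewriteOnce) == 12);  // stale value rewritten one step
}
// ===== End of editor's aside. =====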