X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FAnalysis%2FScalarEvolution.cpp;h=34074efd1cebcce7fe41395dd8d3e98c4bed374f;hb=810605370d53b5ded5243df2ca8bcdbb3ed04047;hp=38f57fc1a0d1414ecde6298d5ae910539087d1d0;hpb=33624191afee63657cfbc68be041299060ee59d2;p=oota-llvm.git diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 38f57fc1a0d..34074efd1ce 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -83,6 +83,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -123,6 +124,12 @@ VerifySCEV("verify-scev", // Implementation of the SCEV class. // +LLVM_DUMP_METHOD +void SCEV::dump() const { + print(dbgs()); + dbgs() << '\n'; +} + void SCEV::print(raw_ostream &OS) const { switch (static_cast(getSCEVType())) { case scConstant: @@ -287,7 +294,7 @@ bool SCEV::isNonConstantNegative() const { if (!SC) return false; // Return true if the value is negative, this matches things like (-42 * V). - return SC->getValue()->getValue().isNegative(); + return SC->getAPInt().isNegative(); } SCEVCouldNotCompute::SCEVCouldNotCompute() : @@ -439,179 +446,179 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { //===----------------------------------------------------------------------===// namespace { - /// SCEVComplexityCompare - Return true if the complexity of the LHS is less - /// than the complexity of the RHS. This comparator is used to canonicalize - /// expressions. - class SCEVComplexityCompare { - const LoopInfo *const LI; - public: - explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} - - // Return true or false if LHS is less than, or at least RHS, respectively. - bool operator()(const SCEV *LHS, const SCEV *RHS) const { - return compare(LHS, RHS) < 0; - } - - // Return negative, zero, or positive, if LHS is less than, equal to, or - // greater than RHS, respectively. A three-way result allows recursive - // comparisons to be more efficient. - int compare(const SCEV *LHS, const SCEV *RHS) const { - // Fast-path: SCEVs are uniqued so we can do a quick equality check. - if (LHS == RHS) - return 0; - - // Primarily, sort the SCEVs by their getSCEVType(). - unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); - if (LType != RType) - return (int)LType - (int)RType; - - // Aside from the getSCEVType() ordering, the particular ordering - // isn't very important except that it's beneficial to be consistent, - // so that (a + b) and (b + a) don't end up as different expressions. - switch (static_cast(LType)) { - case scUnknown: { - const SCEVUnknown *LU = cast(LHS); - const SCEVUnknown *RU = cast(RHS); - - // Sort SCEVUnknown values with some loose heuristics. TODO: This is - // not as complete as it could be. - const Value *LV = LU->getValue(), *RV = RU->getValue(); - - // Order pointer values after integer values. This helps SCEVExpander - // form GEPs. - bool LIsPointer = LV->getType()->isPointerTy(), - RIsPointer = RV->getType()->isPointerTy(); - if (LIsPointer != RIsPointer) - return (int)LIsPointer - (int)RIsPointer; - - // Compare getValueID values. - unsigned LID = LV->getValueID(), - RID = RV->getValueID(); - if (LID != RID) - return (int)LID - (int)RID; - - // Sort arguments by their position. 
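// For instance, in a function f(i32 %a, i32 %b), the SCEVUnknown for %a
// (argument #0) orders before the one for %b (argument #1).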
- if (const Argument *LA = dyn_cast(LV)) { - const Argument *RA = cast(RV); - unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); - return (int)LArgNo - (int)RArgNo; - } - - // For instructions, compare their loop depth, and their operand - // count. This is pretty loose. - if (const Instruction *LInst = dyn_cast(LV)) { - const Instruction *RInst = cast(RV); - - // Compare loop depths. - const BasicBlock *LParent = LInst->getParent(), - *RParent = RInst->getParent(); - if (LParent != RParent) { - unsigned LDepth = LI->getLoopDepth(LParent), - RDepth = LI->getLoopDepth(RParent); - if (LDepth != RDepth) - return (int)LDepth - (int)RDepth; - } - - // Compare the number of operands. - unsigned LNumOps = LInst->getNumOperands(), - RNumOps = RInst->getNumOperands(); - return (int)LNumOps - (int)RNumOps; - } +/// SCEVComplexityCompare - Return true if the complexity of the LHS is less +/// than the complexity of the RHS. This comparator is used to canonicalize +/// expressions. +class SCEVComplexityCompare { + const LoopInfo *const LI; +public: + explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} - return 0; - } + // Return true or false if LHS is less than, or at least RHS, respectively. + bool operator()(const SCEV *LHS, const SCEV *RHS) const { + return compare(LHS, RHS) < 0; + } - case scConstant: { - const SCEVConstant *LC = cast(LHS); - const SCEVConstant *RC = cast(RHS); - - // Compare constant values. - const APInt &LA = LC->getValue()->getValue(); - const APInt &RA = RC->getValue()->getValue(); - unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); - if (LBitWidth != RBitWidth) - return (int)LBitWidth - (int)RBitWidth; - return LA.ult(RA) ? -1 : 1; + // Return negative, zero, or positive, if LHS is less than, equal to, or + // greater than RHS, respectively. A three-way result allows recursive + // comparisons to be more efficient. + int compare(const SCEV *LHS, const SCEV *RHS) const { + // Fast-path: SCEVs are uniqued so we can do a quick equality check. + if (LHS == RHS) + return 0; + + // Primarily, sort the SCEVs by their getSCEVType(). + unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); + if (LType != RType) + return (int)LType - (int)RType; + + // Aside from the getSCEVType() ordering, the particular ordering + // isn't very important except that it's beneficial to be consistent, + // so that (a + b) and (b + a) don't end up as different expressions. + switch (static_cast(LType)) { + case scUnknown: { + const SCEVUnknown *LU = cast(LHS); + const SCEVUnknown *RU = cast(RHS); + + // Sort SCEVUnknown values with some loose heuristics. TODO: This is + // not as complete as it could be. + const Value *LV = LU->getValue(), *RV = RU->getValue(); + + // Order pointer values after integer values. This helps SCEVExpander + // form GEPs. + bool LIsPointer = LV->getType()->isPointerTy(), + RIsPointer = RV->getType()->isPointerTy(); + if (LIsPointer != RIsPointer) + return (int)LIsPointer - (int)RIsPointer; + + // Compare getValueID values. + unsigned LID = LV->getValueID(), + RID = RV->getValueID(); + if (LID != RID) + return (int)LID - (int)RID; + + // Sort arguments by their position. + if (const Argument *LA = dyn_cast(LV)) { + const Argument *RA = cast(RV); + unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); + return (int)LArgNo - (int)RArgNo; } - case scAddRecExpr: { - const SCEVAddRecExpr *LA = cast(LHS); - const SCEVAddRecExpr *RA = cast(RHS); - - // Compare addrec loop depths. 
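// An addrec belonging to a more deeply nested loop compares greater, so
// recurrences are consistently grouped by nesting level.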
- const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); - if (LLoop != RLoop) { - unsigned LDepth = LLoop->getLoopDepth(), - RDepth = RLoop->getLoopDepth(); + // For instructions, compare their loop depth, and their operand + // count. This is pretty loose. + if (const Instruction *LInst = dyn_cast(LV)) { + const Instruction *RInst = cast(RV); + + // Compare loop depths. + const BasicBlock *LParent = LInst->getParent(), + *RParent = RInst->getParent(); + if (LParent != RParent) { + unsigned LDepth = LI->getLoopDepth(LParent), + RDepth = LI->getLoopDepth(RParent); if (LDepth != RDepth) return (int)LDepth - (int)RDepth; } - // Addrec complexity grows with operand count. - unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); - if (LNumOps != RNumOps) - return (int)LNumOps - (int)RNumOps; + // Compare the number of operands. + unsigned LNumOps = LInst->getNumOperands(), + RNumOps = RInst->getNumOperands(); + return (int)LNumOps - (int)RNumOps; + } + + return 0; + } - // Lexicographically compare. - for (unsigned i = 0; i != LNumOps; ++i) { - long X = compare(LA->getOperand(i), RA->getOperand(i)); - if (X != 0) - return X; - } + case scConstant: { + const SCEVConstant *LC = cast(LHS); + const SCEVConstant *RC = cast(RHS); + + // Compare constant values. + const APInt &LA = LC->getAPInt(); + const APInt &RA = RC->getAPInt(); + unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); + if (LBitWidth != RBitWidth) + return (int)LBitWidth - (int)RBitWidth; + return LA.ult(RA) ? -1 : 1; + } - return 0; + case scAddRecExpr: { + const SCEVAddRecExpr *LA = cast(LHS); + const SCEVAddRecExpr *RA = cast(RHS); + + // Compare addrec loop depths. + const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); + if (LLoop != RLoop) { + unsigned LDepth = LLoop->getLoopDepth(), + RDepth = RLoop->getLoopDepth(); + if (LDepth != RDepth) + return (int)LDepth - (int)RDepth; } - case scAddExpr: - case scMulExpr: - case scSMaxExpr: - case scUMaxExpr: { - const SCEVNAryExpr *LC = cast(LHS); - const SCEVNAryExpr *RC = cast(RHS); - - // Lexicographically compare n-ary expressions. - unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); - if (LNumOps != RNumOps) - return (int)LNumOps - (int)RNumOps; - - for (unsigned i = 0; i != LNumOps; ++i) { - if (i >= RNumOps) - return 1; - long X = compare(LC->getOperand(i), RC->getOperand(i)); - if (X != 0) - return X; - } + // Addrec complexity grows with operand count. + unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); + if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; + + // Lexicographically compare. + for (unsigned i = 0; i != LNumOps; ++i) { + long X = compare(LA->getOperand(i), RA->getOperand(i)); + if (X != 0) + return X; } - case scUDivExpr: { - const SCEVUDivExpr *LC = cast(LHS); - const SCEVUDivExpr *RC = cast(RHS); + return 0; + } + + case scAddExpr: + case scMulExpr: + case scSMaxExpr: + case scUMaxExpr: { + const SCEVNAryExpr *LC = cast(LHS); + const SCEVNAryExpr *RC = cast(RHS); - // Lexicographically compare udiv expressions. - long X = compare(LC->getLHS(), RC->getLHS()); + // Lexicographically compare n-ary expressions. 
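// Operand counts are compared first; equal counts are decided by the first
// pair of operands that compare unequal, e.g. (a + b) vs. (a + c) comes
// down to comparing b with c.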
+ unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); + if (LNumOps != RNumOps) + return (int)LNumOps - (int)RNumOps; + + for (unsigned i = 0; i != LNumOps; ++i) { + if (i >= RNumOps) + return 1; + long X = compare(LC->getOperand(i), RC->getOperand(i)); if (X != 0) return X; - return compare(LC->getRHS(), RC->getRHS()); } + return (int)LNumOps - (int)RNumOps; + } - case scTruncate: - case scZeroExtend: - case scSignExtend: { - const SCEVCastExpr *LC = cast(LHS); - const SCEVCastExpr *RC = cast(RHS); + case scUDivExpr: { + const SCEVUDivExpr *LC = cast(LHS); + const SCEVUDivExpr *RC = cast(RHS); - // Compare cast expressions by operand. - return compare(LC->getOperand(), RC->getOperand()); - } + // Lexicographically compare udiv expressions. + long X = compare(LC->getLHS(), RC->getLHS()); + if (X != 0) + return X; + return compare(LC->getRHS(), RC->getRHS()); + } - case scCouldNotCompute: - llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); - } - llvm_unreachable("Unknown SCEV kind!"); + case scTruncate: + case scZeroExtend: + case scSignExtend: { + const SCEVCastExpr *LC = cast(LHS); + const SCEVCastExpr *RC = cast(RHS); + + // Compare cast expressions by operand. + return compare(LC->getOperand(), RC->getOperand()); } - }; -} + + case scCouldNotCompute: + llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); + } + llvm_unreachable("Unknown SCEV kind!"); + } +}; +} // end anonymous namespace /// GroupByComplexity - Given a list of SCEV objects, order them by their /// complexity, and group objects of the same complexity together by value. @@ -659,24 +666,22 @@ static void GroupByComplexity(SmallVectorImpl &Ops, } } -namespace { -struct FindSCEVSize { - int Size; - FindSCEVSize() : Size(0) {} - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - bool isDone() const { - return false; - } -}; -} - // Returns the size of the SCEV S. static inline int sizeOfSCEV(const SCEV *S) { + struct FindSCEVSize { + int Size; + FindSCEVSize() : Size(0) {} + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. 
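// Every node reached counts toward Size, so e.g. ((%a + %b) * %c) has
// size 5: the mul, the add, and the three leaves.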
+ return true; + } + bool isDone() const { + return false; + } + }; + FindSCEVSize F; SCEVTraversal ST(F); ST.visitAll(S); @@ -755,8 +760,8 @@ public: void visitConstant(const SCEVConstant *Numerator) { if (const SCEVConstant *D = dyn_cast(Denominator)) { - APInt NumeratorVal = Numerator->getValue()->getValue(); - APInt DenominatorVal = D->getValue()->getValue(); + APInt NumeratorVal = Numerator->getAPInt(); + APInt DenominatorVal = D->getAPInt(); uint32_t NumeratorBW = NumeratorVal.getBitWidth(); uint32_t DenominatorBW = DenominatorVal.getBitWidth(); @@ -1366,7 +1371,7 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, if (!StartC) return false; - APInt StartAI = StartC->getValue()->getValue(); + APInt StartAI = StartC->getAPInt(); for (unsigned Delta : {-2, -1, 1, 2}) { const SCEV *PreStart = getConstant(StartAI - Delta); @@ -1627,8 +1632,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, auto *SMul = dyn_cast(SA->getOperand(1)); if (SMul && SC1) { if (auto *SC2 = dyn_cast(SMul->getOperand(0))) { - const APInt &C1 = SC1->getValue()->getValue(); - const APInt &C2 = SC2->getValue()->getValue(); + const APInt &C1 = SC1->getAPInt(); + const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) return getAddExpr(getSignExtendExpr(SC1, Ty), @@ -1753,8 +1758,8 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, auto *SC1 = dyn_cast(Start); auto *SC2 = dyn_cast(Step); if (SC1 && SC2) { - const APInt &C1 = SC1->getValue()->getValue(); - const APInt &C2 = SC2->getValue()->getValue(); + const APInt &C1 = SC1->getAPInt(); + const APInt &C2 = SC2->getAPInt(); if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) && C2.isPowerOf2()) { Start = getSignExtendExpr(Start, Ty); @@ -1794,7 +1799,7 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, // Sign-extend negative constants. if (const SCEVConstant *SC = dyn_cast(Op)) - if (SC->getValue()->getValue().isNegative()) + if (SC->getAPInt().isNegative()) return getSignExtendExpr(Op, Ty); // Peel off a truncate cast. @@ -1872,7 +1877,7 @@ CollectAddOperandsWithScales(DenseMap &M, // Pull a buried constant out to the outside. if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) Interesting = true; - AccumulatedConstant += Scale * C->getValue()->getValue(); + AccumulatedConstant += Scale * C->getAPInt(); } // Next comes everything else. We're especially interested in multiplies @@ -1881,7 +1886,7 @@ CollectAddOperandsWithScales(DenseMap &M, const SCEVMulExpr *Mul = dyn_cast(Ops[i]); if (Mul && isa(Mul->getOperand(0))) { APInt NewScale = - Scale * cast(Mul->getOperand(0))->getValue()->getValue(); + Scale * cast(Mul->getOperand(0))->getAPInt(); if (Mul->getNumOperands() == 2 && isa(Mul->getOperand(1))) { // A multiplication of a constant with another add; recurse. const SCEVAddExpr *Add = cast(Mul->getOperand(1)); @@ -1922,14 +1927,6 @@ CollectAddOperandsWithScales(DenseMap &M, return Interesting; } -namespace { - struct APIntCompare { - bool operator()(const APInt &LHS, const APInt &RHS) const { - return LHS.ult(RHS); - } - }; -} - // We're trying to construct a SCEV of type `Type' with `Ops' as operands and // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of // can't-overflow flags for the operation if possible. 
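// For example, an add that already carries <nsw> and whose operands are all
// known non-negative cannot wrap in the unsigned sense either, so <nuw> can
// be inferred on top of it.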
@@ -1950,11 +1947,11 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - auto IsKnownNonNegative = - std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1); + auto IsKnownNonNegative = [&](const SCEV *S) { + return SE->isKnownNonNegative(S); + }; - if (SignOrUnsignWrap == SCEV::FlagNSW && - std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative)) + if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) Flags = ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); @@ -1966,7 +1963,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, // (A + C) --> (A + C) if the addition does not sign overflow // (A + C) --> (A + C) if the addition does not unsign overflow - const APInt &C = cast(Ops[0])->getValue()->getValue(); + const APInt &C = cast(Ops[0])->getAPInt(); if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { auto NSWRegion = ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap); @@ -2012,8 +2009,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - Ops[0] = getConstant(LHSC->getValue()->getValue() + - RHSC->getValue()->getValue()); + Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt()); if (Ops.size() == 2) return Ops[0]; Ops.erase(Ops.begin()+1); // Erase the folded element LHSC = cast(Ops[0]); @@ -2142,22 +2138,26 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, Ops.data(), Ops.size(), APInt(BitWidth, 1), *this)) { + struct APIntCompare { + bool operator()(const APInt &LHS, const APInt &RHS) const { + return LHS.ult(RHS); + } + }; + // Some interesting folding opportunity is present, so its worthwhile to // re-generate the operands list. Group the operands by constant scale, // to avoid multiplying by the same constant scale multiple times. std::map, APIntCompare> MulOpLists; - for (SmallVectorImpl::const_iterator I = NewOps.begin(), - E = NewOps.end(); I != E; ++I) - MulOpLists[M.find(*I)->second].push_back(*I); + for (const SCEV *NewOp : NewOps) + MulOpLists[M.find(NewOp)->second].push_back(NewOp); // Re-generate the operands list. Ops.clear(); if (AccumulatedConstant != 0) Ops.push_back(getConstant(AccumulatedConstant)); - for (std::map, APIntCompare>::iterator - I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I) - if (I->first != 0) - Ops.push_back(getMulExpr(getConstant(I->first), - getAddExpr(I->second))); + for (auto &MulOp : MulOpLists) + if (MulOp.first != 0) + Ops.push_back(getMulExpr(getConstant(MulOp.first), + getAddExpr(MulOp.second))); if (Ops.empty()) return getZero(Ty); if (Ops.size() == 1) @@ -2298,8 +2298,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl &Ops, AddRec->op_end()); for (; OtherIdx != Ops.size() && isa(Ops[OtherIdx]); ++OtherIdx) - if (const SCEVAddRecExpr *OtherAddRec = - dyn_cast(Ops[OtherIdx])) + if (const auto *OtherAddRec = dyn_cast(Ops[OtherIdx])) if (OtherAddRec->getLoop() == AddRecLoop) { for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { @@ -2429,9 +2428,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, ++Idx; while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! 
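// E.g. (4 * 6 * %x) folds to (24 * %x); folding repeats while the next
// operand is still a SCEVConstant.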
- ConstantInt *Fold = ConstantInt::get(getContext(), - LHSC->getValue()->getValue() * - RHSC->getValue()->getValue()); + ConstantInt *Fold = + ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt()); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; @@ -2452,9 +2450,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, if (const SCEVAddExpr *Add = dyn_cast(Ops[1])) { SmallVector NewOps; bool AnyFolded = false; - for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), - E = Add->op_end(); I != E; ++I) { - const SCEV *Mul = getMulExpr(Ops[0], *I); + for (const SCEV *AddOp : Add->operands()) { + const SCEV *Mul = getMulExpr(Ops[0], AddOp); if (!isa(Mul)) AnyFolded = true; NewOps.push_back(Mul); } @@ -2463,10 +2460,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, } else if (const auto *AddRec = dyn_cast(Ops[1])) { // Negation preserves a recurrence's no self-wrap property. SmallVector Operands; - for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(), - E = AddRec->op_end(); I != E; ++I) { - Operands.push_back(getMulExpr(Ops[0], *I)); - } + for (const SCEV *AddRecOp : AddRec->operands()) + Operands.push_back(getMulExpr(Ops[0], AddRecOp)); + return getAddRecExpr(Operands, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW)); } @@ -2655,11 +2651,11 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, // its operands. // TODO: Generalize this to non-constants by using known-bits information. Type *Ty = LHS->getType(); - unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros(); + unsigned LZ = RHSC->getAPInt().countLeadingZeros(); unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; // For non-power-of-two values, effectively round the value up to the // nearest power of two. - if (!RHSC->getValue()->getValue().isPowerOf2()) + if (!RHSC->getAPInt().isPowerOf2()) ++MaxShiftAmt; IntegerType *ExtTy = IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); @@ -2667,8 +2663,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, if (const SCEVConstant *Step = dyn_cast(AR->getStepRecurrence(*this))) { // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. - const APInt &StepInt = Step->getValue()->getValue(); - const APInt &DivInt = RHSC->getValue()->getValue(); + const APInt &StepInt = Step->getAPInt(); + const APInt &DivInt = RHSC->getAPInt(); if (!StepInt.urem(DivInt) && getZeroExtendExpr(AR, ExtTy) == getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), @@ -2688,7 +2684,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), getZeroExtendExpr(Step, ExtTy), AR->getLoop(), SCEV::FlagAnyWrap)) { - const APInt &StartInt = StartC->getValue()->getValue(); + const APInt &StartInt = StartC->getAPInt(); const APInt &StartRem = StartInt.urem(StepInt); if (StartRem != 0) LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, @@ -2755,8 +2751,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, } static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue().abs(); - APInt B = C2->getValue()->getValue().abs(); + APInt A = C1->getAPInt().abs(); + APInt B = C2->getAPInt().abs(); uint32_t ABW = A.getBitWidth(); uint32_t BBW = B.getBitWidth(); @@ -2797,10 +2793,10 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // check. 
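// E.g. for (8 * %x) /u 6, gcd(8, 6) == 2, so both constants are divided
// by 2 and the expression becomes (4 * %x) /u 3.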
APInt Factor = gcd(LHSCst, RHSCst); if (!Factor.isIntN(1)) { - LHSCst = cast( - getConstant(LHSCst->getValue()->getValue().udiv(Factor))); - RHSCst = cast( - getConstant(RHSCst->getValue()->getValue().udiv(Factor))); + LHSCst = + cast(getConstant(LHSCst->getAPInt().udiv(Factor))); + RHSCst = + cast(getConstant(RHSCst->getAPInt().udiv(Factor))); SmallVector Operands; Operands.push_back(LHSCst); Operands.append(Mul->op_begin() + 1, Mul->op_end()); @@ -2884,9 +2880,8 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this // requirement. - bool AllInvariant = - std::all_of(Operands.begin(), Operands.end(), - [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); + bool AllInvariant = all_of( + Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); if (AllInvariant) { // Create a recurrence for the outer loop with the same step size. @@ -2897,9 +2892,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl &Operands, maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); - AllInvariant = std::all_of( - NestedOperands.begin(), NestedOperands.end(), - [&](const SCEV *Op) { return isLoopInvariant(Op, NestedLoop); }); + AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { + return isLoopInvariant(Op, NestedLoop); + }); if (AllInvariant) { // Ok, both add recurrences are valid after the transformation. @@ -3017,9 +3012,8 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(getContext(), - APIntOps::smax(LHSC->getValue()->getValue(), - RHSC->getValue()->getValue())); + ConstantInt *Fold = ConstantInt::get( + getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt())); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; @@ -3121,9 +3115,8 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl &Ops) { assert(Idx < Ops.size()); while (const SCEVConstant *RHSC = dyn_cast(Ops[Idx])) { // We found two constants, fold them together! - ConstantInt *Fold = ConstantInt::get(getContext(), - APIntOps::umax(LHSC->getValue()->getValue(), - RHSC->getValue()->getValue())); + ConstantInt *Fold = ConstantInt::get( + getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt())); Ops[0] = getConstant(Fold); Ops.erase(Ops.begin()+1); // Erase the folded element if (Ops.size() == 1) return Ops[0]; @@ -3212,8 +3205,7 @@ const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { // We can bypass creating a target-independent // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. - return getConstant(IntTy, - F.getParent()->getDataLayout().getTypeAllocSize(AllocTy)); + return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, @@ -3223,9 +3215,7 @@ const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, // constant expression and then folding it back into a ConstantInt. // This is just a compile-time optimization. 
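// I.e. the offset of FieldNo is read straight out of the DataLayout's
// StructLayout rather than materializing a ConstantExpr and folding it
// back down.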
return getConstant( - IntTy, - F.getParent()->getDataLayout().getStructLayout(STy)->getElementOffset( - FieldNo)); + IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); } const SCEV *ScalarEvolution::getUnknown(Value *V) { @@ -3267,7 +3257,7 @@ bool ScalarEvolution::isSCEVable(Type *Ty) const { /// for which isSCEVable must return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); - return F.getParent()->getDataLayout().getTypeSizeInBits(Ty); + return getDataLayout().getTypeSizeInBits(Ty); } /// getEffectiveSCEVType - Return a type with the same bitwidth as @@ -3282,14 +3272,15 @@ Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { // The only other support type is pointer. assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); - return F.getParent()->getDataLayout().getIntPtrType(Ty); + return getDataLayout().getIntPtrType(Ty); } const SCEV *ScalarEvolution::getCouldNotCompute() { return CouldNotCompute.get(); } -namespace { + +bool ScalarEvolution::checkValidity(const SCEV *S) const { // Helper class working with SCEVTraversal to figure out if a SCEV contains // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne // is set iff if find such SCEVUnknown. @@ -3311,9 +3302,7 @@ namespace { } bool isDone() const { return FindOne; } }; -} -bool ScalarEvolution::checkValidity(const SCEV *S) const { FindInvalidSCEVUnknown F; SCEVTraversal ST(F); ST.visitAll(S); @@ -3555,13 +3544,12 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { return getPointerBase(Cast->getOperand()); } else if (const SCEVNAryExpr *NAry = dyn_cast(V)) { const SCEV *PtrOp = nullptr; - for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); - I != E; ++I) { - if ((*I)->getType()->isPointerTy()) { + for (const SCEV *NAryOp : NAry->operands()) { + if (NAryOp->getType()->isPointerTy()) { // Cannot find the base of an expression with multiple pointer operands. if (PtrOp) return V; - PtrOp = *I; + PtrOp = NAryOp; } } if (!PtrOp) @@ -3625,6 +3613,73 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { } } +namespace { +class SCEVInitRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVInitRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + // Only allow AddRecExprs for this loop. + if (Expr->getLoop() == L) + return Expr->getStart(); + Valid = false; + return Expr; + } + + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; + +class SCEVShiftRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVShiftRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? 
Result : SE.getCouldNotCompute(); + } + + SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // Only allow AddRecExprs for this loop. + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + if (Expr->getLoop() == L && Expr->isAffine()) + return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); + Valid = false; + return Expr; + } + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; +} // end anonymous namespace + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -3737,30 +3792,28 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { return PHISCEV; } } - } else if (const auto *AddRec = dyn_cast(BEValue)) { + } else { // Otherwise, this could be a loop like this: // i = 0; for (j = 1; ..; ++j) { .... i = j; } // In this case, j = {1,+,1} and BEValue is j. // Because the other in-value of i (0) fits the evolution of BEValue // i really is an addrec evolution. - if (AddRec->getLoop() == L && AddRec->isAffine()) { + // + // We can generalize this saying that i is the shifted value of BEValue + // by one iteration: + // PHI(f(0), f({1,+,1})) --> f({0,+,1}) + const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); + if (Shifted != getCouldNotCompute() && + Start != getCouldNotCompute()) { const SCEV *StartVal = getSCEV(StartValueV); - - // If StartVal = j.start - j.stride, we can use StartVal as the - // initial step of the addrec evolution. - if (StartVal == - getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) { - // FIXME: For constant StartVal, we should be able to infer - // no-wrap flags. - const SCEV *PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1), - L, SCEV::FlagAnyWrap); - + if (Start == StartVal) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. ForgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; + ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; + return Shifted; } } } @@ -3879,6 +3932,11 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { if (PN->getNumIncomingValues() == 2) { const Loop *L = LI.getLoopFor(PN->getParent()); + // We don't want to break LCSSA, even in a SCEV expression tree. + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) + return nullptr; + // Try to match // // br %cond, label %left, label %right @@ -3918,8 +3976,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. 
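// A degenerate PHI (e.g. one whose live incoming values are all the same
// %v) simplifies to %v, and the simplified form is used only when the
// replacement keeps the function in LCSSA.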
- if (Value *V = SimplifyInstruction(PN, F.getParent()->getDataLayout(), &TLI, - &DT, &AC)) + if (Value *V = SimplifyInstruction(PN, getDataLayout(), &TLI, &DT, &AC)) if (LI.replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -4049,7 +4106,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVConstant *C = dyn_cast(S)) - return C->getValue()->getValue().countTrailingZeros(); + return C->getAPInt().countTrailingZeros(); if (const SCEVTruncateExpr *T = dyn_cast(S)) return std::min(GetMinTrailingZeros(T->getOperand()), @@ -4114,8 +4171,8 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones, F.getParent()->getDataLayout(), - 0, &AC, nullptr, &DT); + computeKnownBits(U->getValue(), Zeros, Ones, getDataLayout(), 0, &AC, + nullptr, &DT); return Zeros.countTrailingOnes(); } @@ -4150,7 +4207,7 @@ ScalarEvolution::getRange(const SCEV *S, return I->second; if (const SCEVConstant *C = dyn_cast(S)) - return setRange(C, SignHint, ConstantRange(C->getValue()->getValue())); + return setRange(C, SignHint, ConstantRange(C->getAPInt())); unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); @@ -4228,9 +4285,8 @@ ScalarEvolution::getRange(const SCEV *S, if (AddRec->getNoWrapFlags(SCEV::FlagNUW)) if (const SCEVConstant *C = dyn_cast(AddRec->getStart())) if (!C->getValue()->isZero()) - ConservativeResult = - ConservativeResult.intersectWith( - ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0))); + ConservativeResult = ConservativeResult.intersectWith( + ConstantRange(C->getAPInt(), APInt(BitWidth, 0))); // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. @@ -4328,7 +4384,7 @@ ScalarEvolution::getRange(const SCEV *S, // Split here to avoid paying the compile-time cost of calling both // computeKnownBits and ComputeNumSignBits. This restriction can be lifted // if needed. - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { // For a SCEVUnknown, ask ValueTracking. APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); @@ -4536,8 +4592,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { unsigned TZ = A.countTrailingZeros(); unsigned BitWidth = A.getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, - F.getParent()->getDataLayout(), 0, &AC, nullptr, &DT); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(), + 0, &AC, nullptr, &DT); APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); @@ -5396,6 +5452,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return ItCnt; } + ExitLimit ShiftEL = computeShiftCompareExitLimit( + ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); + if (ShiftEL.hasAnyInfo()) + return ShiftEL; + const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); @@ -5421,7 +5482,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (AddRec->getLoop() == L) { // Form the constant range. 
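// E.g. for (icmp ult %iv, 100) the range is [0, 100); the exit count is
// then the first iteration at which the addrec's value leaves that range.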
ConstantRange CompRange( - ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue())); + ICmpInst::makeConstantRange(Cond, RHSC->getAPInt())); const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa(Ret)) return Ret; @@ -5455,14 +5516,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, break; } default: -#if 0 - dbgs() << "computeBackedgeTakenCount "; - if (ExitCond->getOperand(0)->getType()->isUnsigned()) - dbgs() << "[unsigned] "; - dbgs() << *LHS << " " - << Instruction::getOpcodeName(Instruction::ICmp) - << " " << *RHS << "\n"; -#endif break; } return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); @@ -5575,11 +5628,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit( Result = ConstantExpr::getICmp(predicate, Result, RHS); if (!isa(Result)) break; // Couldn't decide for sure if (cast(Result)->getValue().isMinValue()) { -#if 0 - dbgs() << "\n***\n*** Computed loop count " << *ItCst - << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() - << "***\n"; -#endif ++NumArrayLenItCounts; return getConstant(ItCst); // Found terminating iteration! } @@ -5587,6 +5635,149 @@ ScalarEvolution::computeLoadConstantCompareExitLimit( return getCouldNotCompute(); } +ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( + Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) { + ConstantInt *RHS = dyn_cast(RHSV); + if (!RHS) + return getCouldNotCompute(); + + const BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return getCouldNotCompute(); + + const BasicBlock *Predecessor = L->getLoopPredecessor(); + if (!Predecessor) + return getCouldNotCompute(); + + // Return true if V is of the form "LHS `shift_op` ". + // Return LHS in OutLHS and shift_opt in OutOpCode. + auto MatchPositiveShift = + [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) { + + using namespace PatternMatch; + + ConstantInt *ShiftAmt; + if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::LShr; + else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::AShr; + else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt)))) + OutOpCode = Instruction::Shl; + else + return false; + + return ShiftAmt->getValue().isStrictlyPositive(); + }; + + // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in + // + // loop: + // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ] + // %iv.shifted = lshr i32 %iv, + // + // Return true on a succesful match. Return the corresponding PHI node (%iv + // above) in PNOut and the opcode of the shift operation in OpCodeOut. + auto MatchShiftRecurrence = + [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { + Optional PostShiftOpCode; + + { + Instruction::BinaryOps OpC; + Value *V; + + // If we encounter a shift instruction, "peel off" the shift operation, + // and remember that we did so. Later when we inspect %iv's backedge + // value, we will make sure that the backedge value uses the same + // operation. + // + // Note: the peeled shift operation does not have to be the same + // instruction as the one feeding into the PHI's backedge value. We only + // really care about it being the same *kind* of shift instruction -- + // that's all that is required for our later inferences to hold. 
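// E.g. peeling the outer lshr of (lshr (phi %iv), 2) and later seeing the
// backedge value (lshr %iv, 1) still matches: both are lshr, even though
// the shift amounts differ.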
+ if (MatchPositiveShift(LHS, V, OpC)) { + PostShiftOpCode = OpC; + LHS = V; + } + } + + PNOut = dyn_cast(LHS); + if (!PNOut || PNOut->getParent() != L->getHeader()) + return false; + + Value *BEValue = PNOut->getIncomingValueForBlock(Latch); + Value *OpLHS; + + return + // The backedge value for the PHI node must be a shift by a positive + // amount + MatchPositiveShift(BEValue, OpLHS, OpCodeOut) && + + // of the PHI node itself + OpLHS == PNOut && + + // and the kind of shift should be match the kind of shift we peeled + // off, if any. + (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); + }; + + PHINode *PN; + Instruction::BinaryOps OpCode; + if (!MatchShiftRecurrence(LHS, PN, OpCode)) + return getCouldNotCompute(); + + const DataLayout &DL = getDataLayout(); + + // The key rationale for this optimization is that for some kinds of shift + // recurrences, the value of the recurrence "stabilizes" to either 0 or -1 + // within a finite number of iterations. If the condition guarding the + // backedge (in the sense that the backedge is taken if the condition is true) + // is false for the value the shift recurrence stabilizes to, then we know + // that the backedge is taken only a finite number of times. + + ConstantInt *StableValue = nullptr; + switch (OpCode) { + default: + llvm_unreachable("Impossible case!"); + + case Instruction::AShr: { + // {K,ashr,} stabilizes to signum(K) in at most + // bitwidth(K) iterations. + Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); + bool KnownZero, KnownOne; + ComputeSignBit(FirstValue, KnownZero, KnownOne, DL, 0, nullptr, + Predecessor->getTerminator(), &DT); + auto *Ty = cast(RHS->getType()); + if (KnownZero) + StableValue = ConstantInt::get(Ty, 0); + else if (KnownOne) + StableValue = ConstantInt::get(Ty, -1, true); + else + return getCouldNotCompute(); + + break; + } + case Instruction::LShr: + case Instruction::Shl: + // Both {K,lshr,} and {K,shl,} + // stabilize to 0 in at most bitwidth(K) iterations. + StableValue = ConstantInt::get(cast(RHS->getType()), 0); + break; + } + + auto *Result = + ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); + assert(Result->getType()->isIntegerTy(1) && + "Otherwise cannot be an operand to a branch instruction"); + + if (Result->isZeroValue()) { + unsigned BitWidth = getTypeSizeInBits(RHS->getType()); + const SCEV *UpperBound = + getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); + return ExitLimit(getCouldNotCompute(), UpperBound); + } + + return getCouldNotCompute(); +} /// CanConstantFold - Return true if we can constant fold an instruction of the /// specified type, assuming that all operands were constants. @@ -5628,12 +5819,10 @@ getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, // Otherwise, we can evaluate this instruction if all of its operands are // constant or derived from a PHI node themselves. PHINode *PHI = nullptr; - for (Instruction::op_iterator OpI = UseInst->op_begin(), - OpE = UseInst->op_end(); OpI != OpE; ++OpI) { - - if (isa(*OpI)) continue; + for (Value *Op : UseInst->operands()) { + if (isa(Op)) continue; - Instruction *OpInst = dyn_cast(*OpI); + Instruction *OpInst = dyn_cast(Op); if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr; PHINode *P = dyn_cast(OpInst); @@ -5725,6 +5914,30 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, TLI); } + +// If every incoming value to PN except the one for BB is a specific Constant, +// return that, else return nullptr. 
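// E.g. for phi [ 0, %preheader ], [ %iv.next, %latch ] with BB == %latch,
// this returns the ConstantInt 0.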
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { + Constant *IncomingVal = nullptr; + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + if (PN->getIncomingBlock(i) == BB) + continue; + + auto *CurrentVal = dyn_cast(PN->getIncomingValue(i)); + if (!CurrentVal) + return nullptr; + + if (IncomingVal != CurrentVal) { + if (IncomingVal) + return nullptr; + IncomingVal = CurrentVal; + } + } + + return IncomingVal; +} + /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is /// in the header of its containing loop, we know the loop executes a /// constant number of times, and the PHI node is just a recurrence @@ -5750,25 +5963,10 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, if (!Latch) return nullptr; - // Since the loop has one latch, the PHI node must have two entries. One - // entry must be a constant (coming in from outside of the loop), and the - // second must be derived from the same PHI. - - BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0) - ? PN->getIncomingBlock(1) - : PN->getIncomingBlock(0); - - assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!"); - - // Note: not all PHI nodes in the same block have to have their incoming - // values in the same order, so we use the basic block to look up the incoming - // value, not an index. - for (auto &I : *Header) { PHINode *PHI = dyn_cast(&I); if (!PHI) break; - auto *StartCST = - dyn_cast(PHI->getIncomingValueForBlock(NonLatch)); + auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } @@ -5783,7 +5981,7 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); for (; ; ++IterationNum) { if (IterationNum == NumIterations) return RetVal = CurrentIterVals[PN]; // Got exit value! @@ -5847,21 +6045,11 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Should follow from NumIncomingValues == 2!"); - // NonLatch is the preheader, or something equivalent. - BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0) - ? PN->getIncomingBlock(1) - : PN->getIncomingBlock(0); - - // Note: not all PHI nodes in the same block have to have their incoming - // values in the same order, so we use the basic block to look up the incoming - // value, not an index. - for (auto &I : *Header) { PHINode *PHI = dyn_cast(&I); if (!PHI) break; - auto *StartCST = - dyn_cast(PHI->getIncomingValueForBlock(NonLatch)); + auto *StartCST = getOtherIncomingValue(PHI, Latch); if (!StartCST) continue; CurrentIterVals[PHI] = StartCST; } @@ -5872,7 +6060,7 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, // the loop symbolically to determine when the condition gets a value of // "ExitWhen". unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. 
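// Simulate the loop one iteration at a time, constant-folding every header
// PHI and the exit condition; stop as soon as the condition evaluates to
// ExitWhen, or give up once the budget is exhausted.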
- const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ auto *CondVal = dyn_cast_or_null( EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI)); @@ -5922,22 +6110,22 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, /// In the case that a relevant loop exit value cannot be computed, the /// original value V is returned. const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { + SmallVector, 2> &Values = + ValuesAtScopes[V]; // Check to see if we've folded this expression at this loop before. - SmallVector, 2> &Values = ValuesAtScopes[V]; - for (unsigned u = 0; u < Values.size(); u++) { - if (Values[u].first == L) - return Values[u].second ? Values[u].second : V; - } - Values.push_back(std::make_pair(L, static_cast(nullptr))); + for (auto &LS : Values) + if (LS.first == L) + return LS.second ? LS.second : V; + + Values.emplace_back(L, nullptr); + // Otherwise compute it. const SCEV *C = computeSCEVAtScope(V, L); - SmallVector, 2> &Values2 = ValuesAtScopes[V]; - for (unsigned u = Values2.size(); u > 0; u--) { - if (Values2[u - 1].first == L) { - Values2[u - 1].second = C; + for (auto &LS : reverse(ValuesAtScopes[V])) + if (LS.first == L) { + LS.second = C; break; } - } return C; } @@ -6061,9 +6249,8 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Okay, we know how many times the containing loop executes. If // this is a constant evolving PHI node, get the final value at // the specified iteration number. - Constant *RV = getConstantEvolutionLoopExitValue(PN, - BTCC->getValue()->getValue(), - LI); + Constant *RV = + getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI); if (RV) return getSCEV(RV); } } @@ -6104,7 +6291,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { // Check to see if getSCEVAtScope actually made an improvement. if (MadeImprovement) { Constant *C = nullptr; - const DataLayout &DL = F.getParent()->getDataLayout(); + const DataLayout &DL = getDataLayout(); if (const CmpInst *CI = dyn_cast(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], Operands[1], DL, &TLI); @@ -6304,10 +6491,10 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { return std::make_pair(CNC, CNC); } - uint32_t BitWidth = LC->getValue()->getValue().getBitWidth(); - const APInt &L = LC->getValue()->getValue(); - const APInt &M = MC->getValue()->getValue(); - const APInt &N = NC->getValue()->getValue(); + uint32_t BitWidth = LC->getAPInt().getBitWidth(); + const APInt &L = LC->getAPInt(); + const APInt &M = MC->getAPInt(); + const APInt &N = NC->getAPInt(); APInt Two(BitWidth, 2); APInt Four(BitWidth, 4); @@ -6386,10 +6573,6 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1 && R2) { -#if 0 - dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1 - << " sol#2: " << *R2 << "\n"; -#endif // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp(CmpInst::ICMP_ULT, @@ -6443,7 +6626,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // For negative steps (counting down to zero): // N = Start/-Step // First compute the unsigned distance from zero in the direction of Step. 
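// E.g. {10,+,-1} counts down and its distance is the start itself, 10;
// {-10,+,1} counts up and its distance is the negated start, also 10.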
- bool CountDown = StepC->getValue()->getValue().isNegative(); + bool CountDown = StepC->getAPInt().isNegative(); const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start); // Handle unitary steps, which cannot wraparound. @@ -6468,7 +6651,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // done by counting and comparing the number of trailing zeros of Step and // Distance. if (!CountDown) { - const APInt &StepV = StepC->getValue()->getValue(); + const APInt &StepV = StepC->getAPInt(); // StepV.isPowerOf2() returns true if StepV is an positive power of two. It // also returns true if StepV is maximally negative (eg, INT_MIN), but that // case is not handled as this code is guarded by !CountDown. @@ -6530,8 +6713,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast(Start)) - return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), - -StartC->getValue()->getValue(), + return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(), *this); return getCouldNotCompute(); } @@ -6654,7 +6836,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // If there's a constant operand, canonicalize comparisons with boundary // cases, and canonicalize *-or-equal comparisons to regular comparisons. if (const SCEVConstant *RC = dyn_cast(RHS)) { - const APInt &RA = RC->getValue()->getValue(); + const APInt &RA = RC->getAPInt(); switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -6845,16 +7027,14 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, Pred = ICmpInst::ICMP_ULT; Changed = true; } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) { - LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, - SCEV::FlagNUW); + LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); Pred = ICmpInst::ICMP_ULT; Changed = true; } break; case ICmpInst::ICMP_UGE: if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) { - RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, - SCEV::FlagNUW); + RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); Pred = ICmpInst::ICMP_UGT; Changed = true; } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) { @@ -7166,7 +7346,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, !isa(ConstOp) || NonConstOp != X) return false; - OutY = cast(ConstOp)->getValue()->getValue(); + OutY = cast(ConstOp)->getAPInt(); return (FlagsPresent & ExpectedFlags) == ExpectedFlags; }; @@ -7187,6 +7367,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && !C.isStrictlyPositive()) return true; + break; case ICmpInst::ICMP_SGT: std::swap(LHS, RHS); @@ -7199,6 +7380,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, // (X + C) s< X if C < 0 if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative()) return true; + break; } return false; @@ -7221,12 +7403,9 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the // interesting cases seen in practice. We can consider "upgrading" L >= 0 to // use isKnownPredicate later if needed. 
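// That is, prove 0 <=s LHS <s RHS: when RHS is known non-negative, both
// sides lie in [0, SINT_MAX], where the signed and unsigned orderings
// coincide, so LHS <u RHS follows.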
- if (isKnownNonNegative(RHS) && - isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && - isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS)) - return true; - - return false; + return isKnownNonNegative(RHS) && + isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) && + isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS); } /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is @@ -7379,6 +7558,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return false; } +namespace { /// RAII wrapper to prevent recursive application of isImpliedCond. /// ScalarEvolution's PendingLoopPredicates set must be empty unless we are /// currently evaluating isImpliedCond. @@ -7396,6 +7576,7 @@ struct MarkPendingLoopPredicate { LoopPreds.erase(Cond); } }; +} // end anonymous namespace /// isImpliedCond - Test whether the condition described by Pred, LHS, /// and RHS is true whenever the given Cond value evaluates to true. @@ -7527,7 +7708,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, APInt Min = ICmpInst::isSigned(Pred) ? getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); - if (Min == C->getValue()->getValue()) { + if (Min == C->getAPInt()) { // Given (V >= Min && V != Min) we conclude V >= (Min + 1). // This is true even if (Min + 1) wraps around -- in case of // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). @@ -7619,8 +7800,8 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less, } if (isa(Less) && isa(More)) { - const auto &M = cast(More)->getValue()->getValue(); - const auto &L = cast(Less)->getValue()->getValue(); + const auto &M = cast(More)->getAPInt(); + const auto &L = cast(Less)->getAPInt(); C = M - L; return true; } @@ -7630,14 +7811,14 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less, if (splitBinaryAdd(Less, L, R, Flags)) if (const auto *LC = dyn_cast(L)) if (R == More) { - C = -(LC->getValue()->getValue()); + C = -(LC->getAPInt()); return true; } if (splitBinaryAdd(More, L, R, Flags)) if (const auto *LC = dyn_cast(L)) if (R == Less) { - C = LC->getValue()->getValue(); + C = LC->getAPInt(); return true; } @@ -7766,8 +7947,7 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, const MaxExprType *MaxExpr = dyn_cast(MaybeMaxExpr); if (!MaxExpr) return false; - auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate); - return It != MaxExpr->op_end(); + return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end(); } @@ -7918,7 +8098,7 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, !isa(AddLHS->getOperand(0))) return false; - APInt ConstFoundRHS = cast(FoundRHS)->getValue()->getValue(); + APInt ConstFoundRHS = cast(FoundRHS)->getAPInt(); // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the // antecedent "`FoundLHS` `Pred` `FoundRHS`". 
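// E.g. the antecedent %x <u 8 places FoundLHS in [0, 8); adding the
// constant addend shifts that into a range for LHS, which is then checked
// for containment in the range satisfying `Pred` against RHS.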
@@ -7927,13 +8107,12 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range // for `LHS`: - APInt Addend = - cast(AddLHS->getOperand(0))->getValue()->getValue(); + APInt Addend = cast(AddLHS->getOperand(0))->getAPInt(); ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); // We can also compute the range of values for `LHS` that satisfy the // consequent, "`LHS` `Pred` `RHS`": - APInt ConstRHS = cast(RHS)->getValue()->getValue(); + APInt ConstRHS = cast(RHS)->getAPInt(); ConstantRange SatisfyingLHSRange = ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); @@ -8057,7 +8236,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, // overflow, in which case if RHS - Start is a constant, we don't need to // do a max operation since we can just figure it out statically if (NoWrap && isa(Diff)) { - APInt D = dyn_cast(Diff)->getValue()->getValue(); + APInt D = dyn_cast(Diff)->getAPInt(); if (D.isNegative()) End = Start; } else @@ -8138,7 +8317,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, // overflow, in which case if RHS - Start is a constant, we don't need to // do a max operation since we can just figure it out statically if (NoWrap && isa(Diff)) { - APInt D = dyn_cast(Diff)->getValue()->getValue(); + APInt D = dyn_cast(Diff)->getAPInt(); if (!D.isNegative()) End = Start; } else @@ -8198,15 +8377,14 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, getNoWrapFlags(FlagNW)); if (const auto *ShiftedAddRec = dyn_cast(Shifted)) return ShiftedAddRec->getNumIterationsInRange( - Range.subtract(SC->getValue()->getValue()), SE); + Range.subtract(SC->getAPInt()), SE); // This is strange and shouldn't happen. return SE.getCouldNotCompute(); } // The only time we can solve this is when we have all constant indices. // Otherwise, we cannot determine the overflow conditions. - if (std::any_of(op_begin(), op_end(), - [](const SCEV *Op) { return !isa(Op);})) + if (any_of(operands(), [](const SCEV *Op) { return !isa(Op); })) return SE.getCouldNotCompute(); // Okay at this point we know that all elements of the chrec are constants and @@ -8227,7 +8405,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, // If A is negative then the lower of the range is the last possible loop // value. Also note that we already checked for a full range. APInt One(BitWidth,1); - APInt A = cast(getOperand(1))->getValue()->getValue(); + APInt A = cast(getOperand(1))->getAPInt(); APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); // The exit value should be (End+A)/A. @@ -8259,15 +8437,13 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, FlagAnyWrap); // Next, solve the constructed addrec - std::pair Roots = - SolveQuadraticEquation(cast(NewAddRec), SE); + auto Roots = SolveQuadraticEquation(cast(NewAddRec), SE); const SCEVConstant *R1 = dyn_cast(Roots.first); const SCEVConstant *R2 = dyn_cast(Roots.second); if (R1) { // Pick the smallest positive root value. - if (ConstantInt *CB = - dyn_cast(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, - R1->getValue(), R2->getValue()))) { + if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp( + ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (!CB->getZExtValue()) std::swap(R1, R2); // R1 is the minimum root now. 
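// Evaluate the chrec at the smaller root; if its value is still inside
// Range, the range is exited one iteration later, so R1 + 1 is probed next.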
@@ -8280,7 +8456,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
         if (Range.contains(R1Val->getValue())) {
           // The next iteration must be out of the range...
           ConstantInt *NextVal =
-                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
+              ConstantInt::get(SE.getContext(), R1->getAPInt() + 1);
 
           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
           if (!Range.contains(R1Val->getValue()))
@@ -8291,7 +8467,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
         // If R1 was not in the range, then it is a good return value.  Make
         // sure that R1-1 WAS in the range though, just in case.
         ConstantInt *NextVal =
-               ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
+            ConstantInt::get(SE.getContext(), R1->getAPInt() - 1);
         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
         if (Range.contains(R1Val->getValue()))
           return R1;
@@ -8527,30 +8703,28 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
   return true;
 }
 
-namespace {
-struct FindParameter {
-  bool FoundParameter;
-  FindParameter() : FoundParameter(false) {}
-
-  bool follow(const SCEV *S) {
-    if (isa<SCEVUnknown>(S)) {
-      FoundParameter = true;
-      // Stop recursion: we found a parameter.
-      return false;
-    }
-    // Keep looking.
-    return true;
-  }
-  bool isDone() const {
-    // Stop recursion if we have found a parameter.
-    return FoundParameter;
-  }
-};
-}
-
 // Returns true when S contains at least a SCEVUnknown parameter.
 static inline bool containsParameters(const SCEV *S) {
+  struct FindParameter {
+    bool FoundParameter;
+    FindParameter() : FoundParameter(false) {}
+
+    bool follow(const SCEV *S) {
+      if (isa<SCEVUnknown>(S)) {
+        FoundParameter = true;
+        // Stop recursion: we found a parameter.
+        return false;
+      }
+      // Keep looking.
+      return true;
+    }
+    bool isDone() const {
+      // Stop recursion if we have found a parameter.
+      return FoundParameter;
+    }
+  };
+
   FindParameter F;
   SCEVTraversal<FindParameter> ST(F);
   ST.visitAll(S);
@@ -8893,6 +9067,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UnsignedRanges(std::move(Arg.UnsignedRanges)),
       SignedRanges(std::move(Arg.SignedRanges)),
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+      UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -9067,9 +9242,8 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
 
     // This recurrence is variant w.r.t. L if any of its operands
     // are variant.
-    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
-         I != E; ++I)
-      if (!isLoopInvariant(*I, L))
+    for (auto *Op : AR->operands())
+      if (!isLoopInvariant(Op, L))
         return LoopVariant;
 
     // Otherwise it's loop-invariant.
@@ -9079,11 +9253,9 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
   case scMulExpr:
   case scUMaxExpr:
   case scSMaxExpr: {
-    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
     bool HasVarying = false;
-    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
-         I != E; ++I) {
-      LoopDisposition D = getLoopDisposition(*I, L);
+    for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) {
+      LoopDisposition D = getLoopDisposition(Op, L);
       if (D == LoopVariant)
         return LoopVariant;
       if (D == LoopComputable)
@@ -9107,7 +9279,7 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
     // invariant if they are not contained in the specified loop.
     // Instructions are never considered invariant in the function body
     // (null loop) because they are defined within the "loop".
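FindParameter, now local to containsParameters, follows the SCEVTraversal visitor protocol: follow decides whether to descend into a node's operands, and isDone lets the walk stop early. Another visitor in the same shape, sketched for illustration (the counter is hypothetical, not part of this file):

    #include "llvm/Analysis/ScalarEvolutionExpressions.h"
    using namespace llvm;

    // Counts SCEVConstant nodes in an expression tree; never stops early.
    struct CountConstants {
      unsigned Count = 0;
      bool follow(const SCEV *S) {
        if (isa<SCEVConstant>(S))
          ++Count;
        return true; // Keep descending into operands.
      }
      bool isDone() const { return false; }
    };

    static unsigned countConstants(const SCEV *S) {
      CountConstants C;
      SCEVTraversal<CountConstants> T(C);
      T.visitAll(S);
      return C.Count;
    }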
-    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
+    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
     return LoopInvariant;
   case scCouldNotCompute:
@@ -9168,9 +9340,8 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
   case scSMaxExpr: {
     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
     bool Proper = true;
-    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
-         I != E; ++I) {
-      BlockDisposition D = getBlockDisposition(*I, BB);
+    for (const SCEV *NAryOp : NAry->operands()) {
+      BlockDisposition D = getBlockDisposition(NAryOp, BB);
       if (D == DoesNotDominateBlock)
         return DoesNotDominateBlock;
       if (D == DominatesBlock)
@@ -9214,24 +9385,22 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
 }
 
-namespace {
-// Search for a SCEV expression node within an expression tree.
-// Implements SCEVTraversal::Visitor.
-struct SCEVSearch {
-  const SCEV *Node;
-  bool IsFound;
+bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
+  // Search for a SCEV expression node within an expression tree.
+  // Implements SCEVTraversal::Visitor.
+  struct SCEVSearch {
+    const SCEV *Node;
+    bool IsFound;
 
-  SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
+    SCEVSearch(const SCEV *N): Node(N), IsFound(false) {}
 
-  bool follow(const SCEV *S) {
-    IsFound |= (S == Node);
-    return !IsFound;
-  }
-  bool isDone() const { return IsFound; }
-};
-}
+    bool follow(const SCEV *S) {
+      IsFound |= (S == Node);
+      return !IsFound;
+    }
+    bool isDone() const { return IsFound; }
+  };
 
-bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
   SCEVSearch Search(Op);
   visitAll(S, Search);
   return Search.IsFound;
@@ -9270,23 +9439,22 @@ static void replaceSubString(std::string &Str, StringRef From, StringRef To) {
 /// getLoopBackedgeTakenCounts - Helper method for verifyAnalysis.
 static void
 getLoopBackedgeTakenCounts(Loop *L, VerifyMap &Map, ScalarEvolution &SE) {
-  for (Loop::reverse_iterator I = L->rbegin(), E = L->rend(); I != E; ++I) {
-    getLoopBackedgeTakenCounts(*I, Map, SE); // recurse.
-
-    std::string &S = Map[L];
-    if (S.empty()) {
-      raw_string_ostream OS(S);
-      SE.getBackedgeTakenCount(L)->print(OS);
+  std::string &S = Map[L];
+  if (S.empty()) {
+    raw_string_ostream OS(S);
+    SE.getBackedgeTakenCount(L)->print(OS);
 
-      // false and 0 are semantically equivalent. This can happen in dead loops.
-      replaceSubString(OS.str(), "false", "0");
-      // Remove wrap flags, their use in SCEV is highly fragile.
-      // FIXME: Remove this when SCEV gets smarter about them.
-      replaceSubString(OS.str(), "<nw>", "");
-      replaceSubString(OS.str(), "<nsw>", "");
-      replaceSubString(OS.str(), "<nuw>", "");
-    }
+    // false and 0 are semantically equivalent. This can happen in dead loops.
+    replaceSubString(OS.str(), "false", "0");
+    // Remove wrap flags, their use in SCEV is highly fragile.
+    // FIXME: Remove this when SCEV gets smarter about them.
+    replaceSubString(OS.str(), "<nw>", "");
+    replaceSubString(OS.str(), "<nsw>", "");
+    replaceSubString(OS.str(), "<nuw>", "");
   }
+
+  for (auto *R : reverse(*L))
+    getLoopBackedgeTakenCounts(R, Map, SE); // recurse.
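The reworked getLoopBackedgeTakenCounts records the current loop's count before recursing into subloops via the reverse range adaptor, rather than recursing first. The same pre-order shape over a loop nest, sketched generically (visitLoop is a placeholder callback):

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/Analysis/LoopInfo.h"
    using namespace llvm;

    // Visit L itself, then each subloop, mirroring the traversal order above.
    template <typename Fn> static void preorderLoops(Loop *L, Fn visitLoop) {
      visitLoop(L);
      for (Loop *SubLoop : reverse(*L))
        preorderLoops(SubLoop, visitLoop);
    }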
 }
 
 void ScalarEvolution::verify() const {
@@ -9396,3 +9564,178 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+                                   const SCEVConstant *RHS) {
+  FoldingSetNodeID ID;
+  // Unique this node based on the arguments
+  ID.AddInteger(SCEVPredicate::P_Equal);
+  ID.AddPointer(LHS);
+  ID.AddPointer(RHS);
+  void *IP = nullptr;
+  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+    return S;
+  SCEVEqualPredicate *Eq = new (SCEVAllocator)
+      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+  UniquePreds.InsertNode(Eq, IP);
+  return Eq;
+}
+
+namespace {
+class SCEVPredicateRewriter
+    : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A) {
+    SCEVPredicateRewriter Rewriter(SE, A);
+    return Rewriter.visit(Scev);
+  }
+
+  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+      : SCEVRewriteVisitor(SE), P(P) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    auto ExprPreds = P.getPredicatesForExpr(Expr);
+    for (auto *Pred : ExprPreds)
+      if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
+        if (IPred->getLHS() == Expr)
+          return IPred->getRHS();
+
+    return Expr;
+  }
+
+private:
+  SCEVUnionPredicate &P;
+};
+} // end anonymous namespace
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+                                                   SCEVUnionPredicate &Preds) {
+  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+                             SCEVPredicateKind Kind)
+    : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+                                       const SCEVUnknown *LHS,
+                                       const SCEVConstant *RHS)
+    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVEqualPredicate>(N);
+
+  if (!Op)
+    return false;
+
+  return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
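getEqualPredicate uses the standard FoldingSet uniquing recipe: profile the arguments into a FoldingSetNodeID, probe with FindNodeOrInsertPos, and allocate only on a miss so equal predicates share one node. The same recipe on a toy node type (names are illustrative, not part of SCEV; SCEV allocates from its BumpPtrAllocator instead of plain new):

    #include "llvm/ADT/FoldingSet.h"
    using namespace llvm;

    struct ToyNode : FoldingSetNode {
      int Key;
      explicit ToyNode(int K) : Key(K) {}
      void Profile(FoldingSetNodeID &ID) const { ID.AddInteger(Key); }
    };

    static ToyNode *getOrCreate(FoldingSet<ToyNode> &Set, int Key) {
      FoldingSetNodeID ID;
      ID.AddInteger(Key);
      void *IP = nullptr;
      if (ToyNode *N = Set.FindNodeOrInsertPos(ID, IP))
        return N; // Reuse the existing, uniqued node.
      ToyNode *N = new ToyNode(Key); // Sketch only; never freed here.
      Set.InsertNode(N, IP);
      return N;
    }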
+SCEVUnionPredicate::SCEVUnionPredicate()
+    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+  return all_of(Preds,
+                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+  auto I = SCEVToPreds.find(Expr);
+  if (I == SCEVToPreds.end())
+    return ArrayRef<const SCEVPredicate *>();
+  return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
+    return all_of(Set->Preds,
+                  [this](const SCEVPredicate *I) { return this->implies(I); });
+
+  auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+  if (ScevPredsIt == SCEVToPreds.end())
+    return false;
+  auto &SCEVPreds = ScevPredsIt->second;
+
+  return any_of(SCEVPreds,
+                [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  for (auto Pred : Preds)
+    Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
+    for (auto Pred : Set->Preds)
+      add(Pred);
+    return;
+  }
+
+  if (implies(N))
+    return;
+
+  const SCEV *Key = N->getExpr();
+  assert(Key && "Only SCEVUnionPredicate doesn't have an "
+                "associated expression!");
+
+  SCEVToPreds[Key].push_back(N);
+  Preds.push_back(N);
+}
+
+PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE)
+    : SE(SE), Generation(0) {}
+
+const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
+  const SCEV *Expr = SE.getSCEV(V);
+  RewriteEntry &Entry = RewriteMap[Expr];
+
+  // If we already have an entry and the version matches, return it.
+  if (Entry.second && Generation == Entry.first)
+    return Entry.second;
+
+  // We found an entry but it's stale. Rewrite the stale entry
+  // according to the current predicate.
+  if (Entry.second)
+    Expr = Entry.second;
+
+  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds);
+  Entry = {Generation, NewSCEV};
+
+  return NewSCEV;
+}
+
+void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
+  if (Preds.implies(&Pred))
+    return;
+  Preds.add(&Pred);
+  updateGeneration();
+}
+
+const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
+  return Preds;
+}
+
+void PredicatedScalarEvolution::updateGeneration() {
+  // If the generation number wrapped, recompute everything.
+  if (++Generation == 0) {
+    for (auto &II : RewriteMap) {
+      const SCEV *Rewritten = II.second.second;
+      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)};
+    }
+  }
+}
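PredicatedScalarEvolution keys its rewrite cache on a generation counter that addPredicate bumps, so entries cached under an older predicate set are rewritten lazily on the next getSCEV rather than eagerly invalidated. A usage sketch, assuming SE, V, and Pred come from an already-analyzed function (illustrative only):

    // Hypothetical driver; SE, V and Pred are assumed to exist already.
    void example(ScalarEvolution &SE, Value *V, const SCEVPredicate &Pred) {
      PredicatedScalarEvolution PSE(SE);
      const SCEV *Before = PSE.getSCEV(V); // Cached at generation 0.
      PSE.addPredicate(Pred);              // Bumps the generation.
      const SCEV *After = PSE.getSCEV(V);  // Stale entry rewritten lazily.
      (void)Before;
      (void)After;
    }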