Clarified the SCEV getSmallConstantTripCount interface with in-your-face comments.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the scalar evolution analysis
11 // engine, which is used primarily to analyze expressions involving induction
12 // variables in loops.
13 //
14 // There are several aspects to this library.  First is the representation of
15 // scalar expressions, which are represented as subclasses of the SCEV class.
16 // These classes are used to represent certain types of subexpressions that we
17 // can handle. We only create one SCEV of a particular shape, so
18 // pointer-comparisons for equality are legal.
19 //
20 // One important aspect of the SCEV objects is that they are never cyclic, even
21 // if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
22 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
23 // recurrence) then we represent it directly as a recurrence node, otherwise we
24 // represent it as a SCEVUnknown node.
25 //
26 // In addition to being able to represent expressions of various types, we also
27 // have folders that are used to build the *canonical* representation for a
28 // particular expression.  These folders are capable of using a variety of
29 // rewrite rules to simplify the expressions.
30 //
31 // Once the folders are defined, we can implement the more interesting
32 // higher-level code, such as the code that recognizes PHI nodes of various
33 // types, computes the execution count of a loop, etc.
34 //
35 // TODO: We should use these routines and value representations to implement
36 // dependence analysis!
37 //
38 //===----------------------------------------------------------------------===//
39 //
40 // There are several good references for the techniques used in this analysis.
41 //
42 //  Chains of recurrences -- a method to expedite the evaluation
43 //  of closed-form functions
44 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
45 //
46 //  On computational properties of chains of recurrences
47 //  Eugene V. Zima
48 //
49 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
50 //  Robert A. van Engelen
51 //
52 //  Efficient Symbolic Analysis for Optimizing Compilers
53 //  Robert A. van Engelen
54 //
55 //  Using the chains of recurrences algebra for data dependence testing and
56 //  induction variable substitution
57 //  MS Thesis, Johnie Birch
58 //
59 //===----------------------------------------------------------------------===//
60
61 #define DEBUG_TYPE "scalar-evolution"
62 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
63 #include "llvm/Constants.h"
64 #include "llvm/DerivedTypes.h"
65 #include "llvm/GlobalVariable.h"
66 #include "llvm/GlobalAlias.h"
67 #include "llvm/Instructions.h"
68 #include "llvm/LLVMContext.h"
69 #include "llvm/Operator.h"
70 #include "llvm/Analysis/ConstantFolding.h"
71 #include "llvm/Analysis/Dominators.h"
72 #include "llvm/Analysis/InstructionSimplify.h"
73 #include "llvm/Analysis/LoopInfo.h"
74 #include "llvm/Analysis/ValueTracking.h"
75 #include "llvm/Assembly/Writer.h"
76 #include "llvm/Target/TargetData.h"
77 #include "llvm/Target/TargetLibraryInfo.h"
78 #include "llvm/Support/CommandLine.h"
79 #include "llvm/Support/ConstantRange.h"
80 #include "llvm/Support/Debug.h"
81 #include "llvm/Support/ErrorHandling.h"
82 #include "llvm/Support/GetElementPtrTypeIterator.h"
83 #include "llvm/Support/InstIterator.h"
84 #include "llvm/Support/MathExtras.h"
85 #include "llvm/Support/raw_ostream.h"
86 #include "llvm/ADT/Statistic.h"
87 #include "llvm/ADT/STLExtras.h"
88 #include "llvm/ADT/SmallPtrSet.h"
89 #include <algorithm>
90 using namespace llvm;
91
92 STATISTIC(NumArrayLenItCounts,
93           "Number of trip counts computed with array length");
94 STATISTIC(NumTripCountsComputed,
95           "Number of loops with predictable loop counts");
96 STATISTIC(NumTripCountsNotComputed,
97           "Number of loops without predictable loop counts");
98 STATISTIC(NumBruteForceTripCountsComputed,
99           "Number of loops with trip counts computed by force");
100
101 static cl::opt<unsigned>
102 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
103                         cl::desc("Maximum number of iterations SCEV will "
104                                  "symbolically execute a constant "
105                                  "derived loop"),
106                         cl::init(100));
107
108 INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
109                 "Scalar Evolution Analysis", false, true)
110 INITIALIZE_PASS_DEPENDENCY(LoopInfo)
111 INITIALIZE_PASS_DEPENDENCY(DominatorTree)
112 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
113 INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
114                 "Scalar Evolution Analysis", false, true)
115 char ScalarEvolution::ID = 0;
116
117 //===----------------------------------------------------------------------===//
118 //                           SCEV class definitions
119 //===----------------------------------------------------------------------===//
120
121 //===----------------------------------------------------------------------===//
122 // Implementation of the SCEV class.
123 //
124
125 void SCEV::dump() const {
126   print(dbgs());
127   dbgs() << '\n';
128 }
129
130 void SCEV::print(raw_ostream &OS) const {
131   switch (getSCEVType()) {
132   case scConstant:
133     WriteAsOperand(OS, cast<SCEVConstant>(this)->getValue(), false);
134     return;
135   case scTruncate: {
136     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
137     const SCEV *Op = Trunc->getOperand();
138     OS << "(trunc " << *Op->getType() << " " << *Op << " to "
139        << *Trunc->getType() << ")";
140     return;
141   }
142   case scZeroExtend: {
143     const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
144     const SCEV *Op = ZExt->getOperand();
145     OS << "(zext " << *Op->getType() << " " << *Op << " to "
146        << *ZExt->getType() << ")";
147     return;
148   }
149   case scSignExtend: {
150     const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
151     const SCEV *Op = SExt->getOperand();
152     OS << "(sext " << *Op->getType() << " " << *Op << " to "
153        << *SExt->getType() << ")";
154     return;
155   }
156   case scAddRecExpr: {
157     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
158     OS << "{" << *AR->getOperand(0);
159     for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
160       OS << ",+," << *AR->getOperand(i);
161     OS << "}<";
162     if (AR->getNoWrapFlags(FlagNUW))
163       OS << "nuw><";
164     if (AR->getNoWrapFlags(FlagNSW))
165       OS << "nsw><";
166     if (AR->getNoWrapFlags(FlagNW) &&
167         !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
168       OS << "nw><";
169     WriteAsOperand(OS, AR->getLoop()->getHeader(), /*PrintType=*/false);
170     OS << ">";
171     return;
172   }
173   case scAddExpr:
174   case scMulExpr:
175   case scUMaxExpr:
176   case scSMaxExpr: {
177     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
178     const char *OpStr = 0;
179     switch (NAry->getSCEVType()) {
180     case scAddExpr: OpStr = " + "; break;
181     case scMulExpr: OpStr = " * "; break;
182     case scUMaxExpr: OpStr = " umax "; break;
183     case scSMaxExpr: OpStr = " smax "; break;
184     }
185     OS << "(";
186     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
187          I != E; ++I) {
188       OS << **I;
189       if (llvm::next(I) != E)
190         OS << OpStr;
191     }
192     OS << ")";
193     switch (NAry->getSCEVType()) {
194     case scAddExpr:
195     case scMulExpr:
196       if (NAry->getNoWrapFlags(FlagNUW))
197         OS << "<nuw>";
198       if (NAry->getNoWrapFlags(FlagNSW))
199         OS << "<nsw>";
200     }
201     return;
202   }
203   case scUDivExpr: {
204     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
205     OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
206     return;
207   }
208   case scUnknown: {
209     const SCEVUnknown *U = cast<SCEVUnknown>(this);
210     Type *AllocTy;
211     if (U->isSizeOf(AllocTy)) {
212       OS << "sizeof(" << *AllocTy << ")";
213       return;
214     }
215     if (U->isAlignOf(AllocTy)) {
216       OS << "alignof(" << *AllocTy << ")";
217       return;
218     }
219
220     Type *CTy;
221     Constant *FieldNo;
222     if (U->isOffsetOf(CTy, FieldNo)) {
223       OS << "offsetof(" << *CTy << ", ";
224       WriteAsOperand(OS, FieldNo, false);
225       OS << ")";
226       return;
227     }
228
229     // Otherwise just print it normally.
230     WriteAsOperand(OS, U->getValue(), false);
231     return;
232   }
233   case scCouldNotCompute:
234     OS << "***COULDNOTCOMPUTE***";
235     return;
236   default: break;
237   }
238   llvm_unreachable("Unknown SCEV kind!");
239 }
240
241 Type *SCEV::getType() const {
242   switch (getSCEVType()) {
243   case scConstant:
244     return cast<SCEVConstant>(this)->getType();
245   case scTruncate:
246   case scZeroExtend:
247   case scSignExtend:
248     return cast<SCEVCastExpr>(this)->getType();
249   case scAddRecExpr:
250   case scMulExpr:
251   case scUMaxExpr:
252   case scSMaxExpr:
253     return cast<SCEVNAryExpr>(this)->getType();
254   case scAddExpr:
255     return cast<SCEVAddExpr>(this)->getType();
256   case scUDivExpr:
257     return cast<SCEVUDivExpr>(this)->getType();
258   case scUnknown:
259     return cast<SCEVUnknown>(this)->getType();
260   case scCouldNotCompute:
261     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
262     return 0;
263   default: break;
264   }
265   llvm_unreachable("Unknown SCEV kind!");
266   return 0;
267 }
268
269 bool SCEV::isZero() const {
270   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
271     return SC->getValue()->isZero();
272   return false;
273 }
274
275 bool SCEV::isOne() const {
276   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
277     return SC->getValue()->isOne();
278   return false;
279 }
280
281 bool SCEV::isAllOnesValue() const {
282   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
283     return SC->getValue()->isAllOnesValue();
284   return false;
285 }
286
287 /// isNonConstantNegative - Return true if the specified scev is negated, but
288 /// not a constant.
289 bool SCEV::isNonConstantNegative() const {
290   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
291   if (!Mul) return false;
292
293   // If there is a constant factor, it will be first.
294   const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
295   if (!SC) return false;
296
297   // Return true if the value is negative, this matches things like (-42 * V).
298   return SC->getValue()->getValue().isNegative();
299 }
300
301 SCEVCouldNotCompute::SCEVCouldNotCompute() :
302   SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}
303
304 bool SCEVCouldNotCompute::classof(const SCEV *S) {
305   return S->getSCEVType() == scCouldNotCompute;
306 }
307
308 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
309   FoldingSetNodeID ID;
310   ID.AddInteger(scConstant);
311   ID.AddPointer(V);
312   void *IP = 0;
313   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
314   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
315   UniqueSCEVs.InsertNode(S, IP);
316   return S;
317 }
318
319 const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
320   return getConstant(ConstantInt::get(getContext(), Val));
321 }
322
323 const SCEV *
324 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
325   IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
326   return getConstant(ConstantInt::get(ITy, V, isSigned));
327 }
328
329 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
330                            unsigned SCEVTy, const SCEV *op, Type *ty)
331   : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
332
333 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
334                                    const SCEV *op, Type *ty)
335   : SCEVCastExpr(ID, scTruncate, op, ty) {
336   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
337          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
338          "Cannot truncate non-integer value!");
339 }
340
341 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
342                                        const SCEV *op, Type *ty)
343   : SCEVCastExpr(ID, scZeroExtend, op, ty) {
344   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
345          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
346          "Cannot zero extend non-integer value!");
347 }
348
349 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
350                                        const SCEV *op, Type *ty)
351   : SCEVCastExpr(ID, scSignExtend, op, ty) {
352   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
353          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
354          "Cannot sign extend non-integer value!");
355 }
356
357 void SCEVUnknown::deleted() {
358   // Clear this SCEVUnknown from various maps.
359   SE->forgetMemoizedResults(this);
360
361   // Remove this SCEVUnknown from the uniquing map.
362   SE->UniqueSCEVs.RemoveNode(this);
363
364   // Release the value.
365   setValPtr(0);
366 }
367
368 void SCEVUnknown::allUsesReplacedWith(Value *New) {
369   // Clear this SCEVUnknown from various maps.
370   SE->forgetMemoizedResults(this);
371
372   // Remove this SCEVUnknown from the uniquing map.
373   SE->UniqueSCEVs.RemoveNode(this);
374
375   // Update this SCEVUnknown to point to the new value. This is needed
376   // because there may still be outstanding SCEVs which still point to
377   // this SCEVUnknown.
378   setValPtr(New);
379 }
380
381 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
382   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
383     if (VCE->getOpcode() == Instruction::PtrToInt)
384       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
385         if (CE->getOpcode() == Instruction::GetElementPtr &&
386             CE->getOperand(0)->isNullValue() &&
387             CE->getNumOperands() == 2)
388           if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
389             if (CI->isOne()) {
390               AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
391                                  ->getElementType();
392               return true;
393             }
394
395   return false;
396 }
397
398 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
399   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
400     if (VCE->getOpcode() == Instruction::PtrToInt)
401       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
402         if (CE->getOpcode() == Instruction::GetElementPtr &&
403             CE->getOperand(0)->isNullValue()) {
404           Type *Ty =
405             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
406           if (StructType *STy = dyn_cast<StructType>(Ty))
407             if (!STy->isPacked() &&
408                 CE->getNumOperands() == 3 &&
409                 CE->getOperand(1)->isNullValue()) {
410               if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
411                 if (CI->isOne() &&
412                     STy->getNumElements() == 2 &&
413                     STy->getElementType(0)->isIntegerTy(1)) {
414                   AllocTy = STy->getElementType(1);
415                   return true;
416                 }
417             }
418         }
419
420   return false;
421 }
422
423 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
424   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
425     if (VCE->getOpcode() == Instruction::PtrToInt)
426       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
427         if (CE->getOpcode() == Instruction::GetElementPtr &&
428             CE->getNumOperands() == 3 &&
429             CE->getOperand(0)->isNullValue() &&
430             CE->getOperand(1)->isNullValue()) {
431           Type *Ty =
432             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
433           // Ignore vector types here so that ScalarEvolutionExpander doesn't
434           // emit getelementptrs that index into vectors.
435           if (Ty->isStructTy() || Ty->isArrayTy()) {
436             CTy = Ty;
437             FieldNo = CE->getOperand(2);
438             return true;
439           }
440         }
441
442   return false;
443 }
444
445 //===----------------------------------------------------------------------===//
446 //                               SCEV Utilities
447 //===----------------------------------------------------------------------===//
448
449 namespace {
450   /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
451   /// than the complexity of the RHS.  This comparator is used to canonicalize
452   /// expressions.
453   class SCEVComplexityCompare {
454     const LoopInfo *const LI;
455   public:
456     explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
457
458     // Return true or false if LHS is less than, or at least RHS, respectively.
459     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
460       return compare(LHS, RHS) < 0;
461     }
462
463     // Return negative, zero, or positive, if LHS is less than, equal to, or
464     // greater than RHS, respectively. A three-way result allows recursive
465     // comparisons to be more efficient.
466     int compare(const SCEV *LHS, const SCEV *RHS) const {
467       // Fast-path: SCEVs are uniqued so we can do a quick equality check.
468       if (LHS == RHS)
469         return 0;
470
471       // Primarily, sort the SCEVs by their getSCEVType().
472       unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
473       if (LType != RType)
474         return (int)LType - (int)RType;
475
476       // Aside from the getSCEVType() ordering, the particular ordering
477       // isn't very important except that it's beneficial to be consistent,
478       // so that (a + b) and (b + a) don't end up as different expressions.
479       switch (LType) {
480       case scUnknown: {
481         const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
482         const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
483
484         // Sort SCEVUnknown values with some loose heuristics. TODO: This is
485         // not as complete as it could be.
486         const Value *LV = LU->getValue(), *RV = RU->getValue();
487
488         // Order pointer values after integer values. This helps SCEVExpander
489         // form GEPs.
490         bool LIsPointer = LV->getType()->isPointerTy(),
491              RIsPointer = RV->getType()->isPointerTy();
492         if (LIsPointer != RIsPointer)
493           return (int)LIsPointer - (int)RIsPointer;
494
495         // Compare getValueID values.
496         unsigned LID = LV->getValueID(),
497                  RID = RV->getValueID();
498         if (LID != RID)
499           return (int)LID - (int)RID;
500
501         // Sort arguments by their position.
502         if (const Argument *LA = dyn_cast<Argument>(LV)) {
503           const Argument *RA = cast<Argument>(RV);
504           unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
505           return (int)LArgNo - (int)RArgNo;
506         }
507
508         // For instructions, compare their loop depth, and their operand
509         // count.  This is pretty loose.
510         if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
511           const Instruction *RInst = cast<Instruction>(RV);
512
513           // Compare loop depths.
514           const BasicBlock *LParent = LInst->getParent(),
515                            *RParent = RInst->getParent();
516           if (LParent != RParent) {
517             unsigned LDepth = LI->getLoopDepth(LParent),
518                      RDepth = LI->getLoopDepth(RParent);
519             if (LDepth != RDepth)
520               return (int)LDepth - (int)RDepth;
521           }
522
523           // Compare the number of operands.
524           unsigned LNumOps = LInst->getNumOperands(),
525                    RNumOps = RInst->getNumOperands();
526           return (int)LNumOps - (int)RNumOps;
527         }
528
529         return 0;
530       }
531
532       case scConstant: {
533         const SCEVConstant *LC = cast<SCEVConstant>(LHS);
534         const SCEVConstant *RC = cast<SCEVConstant>(RHS);
535
536         // Compare constant values.
537         const APInt &LA = LC->getValue()->getValue();
538         const APInt &RA = RC->getValue()->getValue();
539         unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
540         if (LBitWidth != RBitWidth)
541           return (int)LBitWidth - (int)RBitWidth;
542         return LA.ult(RA) ? -1 : 1;
543       }
544
545       case scAddRecExpr: {
546         const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
547         const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
548
549         // Compare addrec loop depths.
550         const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
551         if (LLoop != RLoop) {
552           unsigned LDepth = LLoop->getLoopDepth(),
553                    RDepth = RLoop->getLoopDepth();
554           if (LDepth != RDepth)
555             return (int)LDepth - (int)RDepth;
556         }
557
558         // Addrec complexity grows with operand count.
559         unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
560         if (LNumOps != RNumOps)
561           return (int)LNumOps - (int)RNumOps;
562
563         // Lexicographically compare.
564         for (unsigned i = 0; i != LNumOps; ++i) {
565           long X = compare(LA->getOperand(i), RA->getOperand(i));
566           if (X != 0)
567             return X;
568         }
569
570         return 0;
571       }
572
573       case scAddExpr:
574       case scMulExpr:
575       case scSMaxExpr:
576       case scUMaxExpr: {
577         const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
578         const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
579
580         // Lexicographically compare n-ary expressions.
581         unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
582         for (unsigned i = 0; i != LNumOps; ++i) {
583           if (i >= RNumOps)
584             return 1;
585           long X = compare(LC->getOperand(i), RC->getOperand(i));
586           if (X != 0)
587             return X;
588         }
589         return (int)LNumOps - (int)RNumOps;
590       }
591
592       case scUDivExpr: {
593         const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
594         const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
595
596         // Lexicographically compare udiv expressions.
597         long X = compare(LC->getLHS(), RC->getLHS());
598         if (X != 0)
599           return X;
600         return compare(LC->getRHS(), RC->getRHS());
601       }
602
603       case scTruncate:
604       case scZeroExtend:
605       case scSignExtend: {
606         const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
607         const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
608
609         // Compare cast expressions by operand.
610         return compare(LC->getOperand(), RC->getOperand());
611       }
612
613       default:
614         break;
615       }
616
617       llvm_unreachable("Unknown SCEV kind!");
618       return 0;
619     }
620   };
621 }
622
623 /// GroupByComplexity - Given a list of SCEV objects, order them by their
624 /// complexity, and group objects of the same complexity together by value.
625 /// When this routine is finished, we know that any duplicates in the vector are
626 /// consecutive and that complexity is monotonically increasing.
627 ///
628 /// Note that we go take special precautions to ensure that we get deterministic
629 /// results from this routine.  In other words, we don't want the results of
630 /// this to depend on where the addresses of various SCEV objects happened to
631 /// land in memory.
632 ///
633 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
634                               LoopInfo *LI) {
635   if (Ops.size() < 2) return;  // Noop
636   if (Ops.size() == 2) {
637     // This is the common case, which also happens to be trivially simple.
638     // Special case it.
639     const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
640     if (SCEVComplexityCompare(LI)(RHS, LHS))
641       std::swap(LHS, RHS);
642     return;
643   }
644
645   // Do the rough sort by complexity.
646   std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
647
648   // Now that we are sorted by complexity, group elements of the same
649   // complexity.  Note that this is, at worst, N^2, but the vector is likely to
650   // be extremely short in practice.  Note that we take this approach because we
651   // do not want to depend on the addresses of the objects we are grouping.
652   for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
653     const SCEV *S = Ops[i];
654     unsigned Complexity = S->getSCEVType();
655
656     // If there are any objects of the same complexity and same value as this
657     // one, group them.
658     for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
659       if (Ops[j] == S) { // Found a duplicate.
660         // Move it to immediately after i'th element.
661         std::swap(Ops[i+1], Ops[j]);
662         ++i;   // no need to rescan it.
663         if (i == e-2) return;  // Done!
664       }
665     }
666   }
667 }
668
669
670
671 //===----------------------------------------------------------------------===//
672 //                      Simple SCEV method implementations
673 //===----------------------------------------------------------------------===//
674
675 /// BinomialCoefficient - Compute BC(It, K).  The result has width W.
676 /// Assume, K > 0.
677 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
678                                        ScalarEvolution &SE,
679                                        Type *ResultTy) {
680   // Handle the simplest case efficiently.
681   if (K == 1)
682     return SE.getTruncateOrZeroExtend(It, ResultTy);
683
684   // We are using the following formula for BC(It, K):
685   //
686   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
687   //
688   // Suppose, W is the bitwidth of the return value.  We must be prepared for
689   // overflow.  Hence, we must assure that the result of our computation is
690   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
691   // safe in modular arithmetic.
692   //
693   // However, this code doesn't use exactly that formula; the formula it uses
694   // is something like the following, where T is the number of factors of 2 in
695   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
696   // exponentiation:
697   //
698   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
699   //
700   // This formula is trivially equivalent to the previous formula.  However,
701   // this formula can be implemented much more efficiently.  The trick is that
702   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
703   // arithmetic.  To do exact division in modular arithmetic, all we have
704   // to do is multiply by the inverse.  Therefore, this step can be done at
705   // width W.
706   //
707   // The next issue is how to safely do the division by 2^T.  The way this
708   // is done is by doing the multiplication step at a width of at least W + T
709   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
710   // when we perform the division by 2^T (which is equivalent to a right shift
711   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
712   // truncated out after the division by 2^T.
713   //
714   // In comparison to just directly using the first formula, this technique
715   // is much more efficient; using the first formula requires W * K bits,
716   // but this formula less than W + K bits. Also, the first formula requires
717   // a division step, whereas this formula only requires multiplies and shifts.
718   //
719   // It doesn't matter whether the subtraction step is done in the calculation
720   // width or the input iteration count's width; if the subtraction overflows,
721   // the result must be zero anyway.  We prefer here to do it in the width of
722   // the induction variable because it helps a lot for certain cases; CodeGen
723   // isn't smart enough to ignore the overflow, which leads to much less
724   // efficient code if the width of the subtraction is wider than the native
725   // register width.
726   //
727   // (It's possible to not widen at all by pulling out factors of 2 before
728   // the multiplication; for example, K=2 can be calculated as
729   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
730   // extra arithmetic, so it's not an obvious win, and it gets
731   // much more complicated for K > 3.)
732
733   // Protection from insane SCEVs; this bound is conservative,
734   // but it probably doesn't matter.
735   if (K > 1000)
736     return SE.getCouldNotCompute();
737
738   unsigned W = SE.getTypeSizeInBits(ResultTy);
739
740   // Calculate K! / 2^T and T; we divide out the factors of two before
741   // multiplying for calculating K! / 2^T to avoid overflow.
742   // Other overflow doesn't matter because we only care about the bottom
743   // W bits of the result.
744   APInt OddFactorial(W, 1);
745   unsigned T = 1;
746   for (unsigned i = 3; i <= K; ++i) {
747     APInt Mult(W, i);
748     unsigned TwoFactors = Mult.countTrailingZeros();
749     T += TwoFactors;
750     Mult = Mult.lshr(TwoFactors);
751     OddFactorial *= Mult;
752   }
753
754   // We need at least W + T bits for the multiplication step
755   unsigned CalculationBits = W + T;
756
757   // Calculate 2^T, at width T+W.
758   APInt DivFactor = APInt(CalculationBits, 1).shl(T);
759
760   // Calculate the multiplicative inverse of K! / 2^T;
761   // this multiplication factor will perform the exact division by
762   // K! / 2^T.
763   APInt Mod = APInt::getSignedMinValue(W+1);
764   APInt MultiplyFactor = OddFactorial.zext(W+1);
765   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
766   MultiplyFactor = MultiplyFactor.trunc(W);
767
768   // Calculate the product, at width T+W
769   IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
770                                                       CalculationBits);
771   const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
772   for (unsigned i = 1; i != K; ++i) {
773     const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
774     Dividend = SE.getMulExpr(Dividend,
775                              SE.getTruncateOrZeroExtend(S, CalculationTy));
776   }
777
778   // Divide by 2^T
779   const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
780
781   // Truncate the result, and divide by K! / 2^T.
782
783   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
784                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
785 }
786
787 /// evaluateAtIteration - Return the value of this chain of recurrences at
788 /// the specified iteration number.  We can evaluate this recurrence by
789 /// multiplying each element in the chain by the binomial coefficient
790 /// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
791 ///
792 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
793 ///
794 /// where BC(It, k) stands for binomial coefficient.
795 ///
796 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
797                                                 ScalarEvolution &SE) const {
798   const SCEV *Result = getStart();
799   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
800     // The computation is correct in the face of overflow provided that the
801     // multiplication is performed _after_ the evaluation of the binomial
802     // coefficient.
803     const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
804     if (isa<SCEVCouldNotCompute>(Coeff))
805       return Coeff;
806
807     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
808   }
809   return Result;
810 }
811
812 //===----------------------------------------------------------------------===//
813 //                    SCEV Expression folder implementations
814 //===----------------------------------------------------------------------===//
815
816 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
817                                              Type *Ty) {
818   assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
819          "This is not a truncating conversion!");
820   assert(isSCEVable(Ty) &&
821          "This is not a conversion to a SCEVable type!");
822   Ty = getEffectiveSCEVType(Ty);
823
824   FoldingSetNodeID ID;
825   ID.AddInteger(scTruncate);
826   ID.AddPointer(Op);
827   ID.AddPointer(Ty);
828   void *IP = 0;
829   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
830
831   // Fold if the operand is constant.
832   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
833     return getConstant(
834       cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(),
835                                                getEffectiveSCEVType(Ty))));
836
837   // trunc(trunc(x)) --> trunc(x)
838   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
839     return getTruncateExpr(ST->getOperand(), Ty);
840
841   // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
842   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
843     return getTruncateOrSignExtend(SS->getOperand(), Ty);
844
845   // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
846   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
847     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
848
849   // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
850   // eliminate all the truncates.
851   if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
852     SmallVector<const SCEV *, 4> Operands;
853     bool hasTrunc = false;
854     for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
855       const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
856       hasTrunc = isa<SCEVTruncateExpr>(S);
857       Operands.push_back(S);
858     }
859     if (!hasTrunc)
860       return getAddExpr(Operands);
861     UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
862   }
863
864   // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
865   // eliminate all the truncates.
866   if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
867     SmallVector<const SCEV *, 4> Operands;
868     bool hasTrunc = false;
869     for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
870       const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
871       hasTrunc = isa<SCEVTruncateExpr>(S);
872       Operands.push_back(S);
873     }
874     if (!hasTrunc)
875       return getMulExpr(Operands);
876     UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
877   }
878
879   // If the input value is a chrec scev, truncate the chrec's operands.
880   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
881     SmallVector<const SCEV *, 4> Operands;
882     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
883       Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
884     return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
885   }
886
887   // As a special case, fold trunc(undef) to undef. We don't want to
888   // know too much about SCEVUnknowns, but this special case is handy
889   // and harmless.
890   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
891     if (isa<UndefValue>(U->getValue()))
892       return getSCEV(UndefValue::get(Ty));
893
894   // The cast wasn't folded; create an explicit cast node. We can reuse
895   // the existing insert position since if we get here, we won't have
896   // made any changes which would invalidate it.
897   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
898                                                  Op, Ty);
899   UniqueSCEVs.InsertNode(S, IP);
900   return S;
901 }
902
903 const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
904                                                Type *Ty) {
905   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
906          "This is not an extending conversion!");
907   assert(isSCEVable(Ty) &&
908          "This is not a conversion to a SCEVable type!");
909   Ty = getEffectiveSCEVType(Ty);
910
911   // Fold if the operand is constant.
912   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
913     return getConstant(
914       cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(),
915                                               getEffectiveSCEVType(Ty))));
916
917   // zext(zext(x)) --> zext(x)
918   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
919     return getZeroExtendExpr(SZ->getOperand(), Ty);
920
921   // Before doing any expensive analysis, check to see if we've already
922   // computed a SCEV for this Op and Ty.
923   FoldingSetNodeID ID;
924   ID.AddInteger(scZeroExtend);
925   ID.AddPointer(Op);
926   ID.AddPointer(Ty);
927   void *IP = 0;
928   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
929
930   // zext(trunc(x)) --> zext(x) or x or trunc(x)
931   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
932     // It's possible the bits taken off by the truncate were all zero bits. If
933     // so, we should be able to simplify this further.
934     const SCEV *X = ST->getOperand();
935     ConstantRange CR = getUnsignedRange(X);
936     unsigned TruncBits = getTypeSizeInBits(ST->getType());
937     unsigned NewBits = getTypeSizeInBits(Ty);
938     if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
939             CR.zextOrTrunc(NewBits)))
940       return getTruncateOrZeroExtend(X, Ty);
941   }
942
943   // If the input value is a chrec scev, and we can prove that the value
944   // did not overflow the old, smaller, value, we can zero extend all of the
945   // operands (often constants).  This allows analysis of something like
946   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
947   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
948     if (AR->isAffine()) {
949       const SCEV *Start = AR->getStart();
950       const SCEV *Step = AR->getStepRecurrence(*this);
951       unsigned BitWidth = getTypeSizeInBits(AR->getType());
952       const Loop *L = AR->getLoop();
953
954       // If we have special knowledge that this addrec won't overflow,
955       // we don't need to do any further analysis.
956       if (AR->getNoWrapFlags(SCEV::FlagNUW))
957         return getAddRecExpr(getZeroExtendExpr(Start, Ty),
958                              getZeroExtendExpr(Step, Ty),
959                              L, AR->getNoWrapFlags());
960
961       // Check whether the backedge-taken count is SCEVCouldNotCompute.
962       // Note that this serves two purposes: It filters out loops that are
963       // simply not analyzable, and it covers the case where this code is
964       // being called from within backedge-taken count analysis, such that
965       // attempting to ask for the backedge-taken count would likely result
966       // in infinite recursion. In the later case, the analysis code will
967       // cope with a conservative value, and it will take care to purge
968       // that value once it has finished.
969       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
970       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
971         // Manually compute the final value for AR, checking for
972         // overflow.
973
974         // Check whether the backedge-taken count can be losslessly casted to
975         // the addrec's type. The count is always unsigned.
976         const SCEV *CastedMaxBECount =
977           getTruncateOrZeroExtend(MaxBECount, Start->getType());
978         const SCEV *RecastedMaxBECount =
979           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
980         if (MaxBECount == RecastedMaxBECount) {
981           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
982           // Check whether Start+Step*MaxBECount has no unsigned overflow.
983           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
984           const SCEV *Add = getAddExpr(Start, ZMul);
985           const SCEV *OperandExtendedAdd =
986             getAddExpr(getZeroExtendExpr(Start, WideTy),
987                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
988                                   getZeroExtendExpr(Step, WideTy)));
989           if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
990             // Cache knowledge of AR NUW, which is propagated to this AddRec.
991             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
992             // Return the expression with the addrec on the outside.
993             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
994                                  getZeroExtendExpr(Step, Ty),
995                                  L, AR->getNoWrapFlags());
996           }
997           // Similar to above, only this time treat the step value as signed.
998           // This covers loops that count down.
999           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1000           Add = getAddExpr(Start, SMul);
1001           OperandExtendedAdd =
1002             getAddExpr(getZeroExtendExpr(Start, WideTy),
1003                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1004                                   getSignExtendExpr(Step, WideTy)));
1005           if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1006             // Cache knowledge of AR NW, which is propagated to this AddRec.
1007             // Negative step causes unsigned wrap, but it still can't self-wrap.
1008             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1009             // Return the expression with the addrec on the outside.
1010             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1011                                  getSignExtendExpr(Step, Ty),
1012                                  L, AR->getNoWrapFlags());
1013           }
1014         }
1015
1016         // If the backedge is guarded by a comparison with the pre-inc value
1017         // the addrec is safe. Also, if the entry is guarded by a comparison
1018         // with the start value and the backedge is guarded by a comparison
1019         // with the post-inc value, the addrec is safe.
1020         if (isKnownPositive(Step)) {
1021           const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1022                                       getUnsignedRange(Step).getUnsignedMax());
1023           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1024               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
1025                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
1026                                            AR->getPostIncExpr(*this), N))) {
1027             // Cache knowledge of AR NUW, which is propagated to this AddRec.
1028             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1029             // Return the expression with the addrec on the outside.
1030             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1031                                  getZeroExtendExpr(Step, Ty),
1032                                  L, AR->getNoWrapFlags());
1033           }
1034         } else if (isKnownNegative(Step)) {
1035           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1036                                       getSignedRange(Step).getSignedMin());
1037           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1038               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
1039                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
1040                                            AR->getPostIncExpr(*this), N))) {
1041             // Cache knowledge of AR NW, which is propagated to this AddRec.
1042             // Negative step causes unsigned wrap, but it still can't self-wrap.
1043             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1044             // Return the expression with the addrec on the outside.
1045             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1046                                  getSignExtendExpr(Step, Ty),
1047                                  L, AR->getNoWrapFlags());
1048           }
1049         }
1050       }
1051     }
1052
1053   // The cast wasn't folded; create an explicit cast node.
1054   // Recompute the insert position, as it may have been invalidated.
1055   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1056   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1057                                                    Op, Ty);
1058   UniqueSCEVs.InsertNode(S, IP);
1059   return S;
1060 }
1061
1062 // Get the limit of a recurrence such that incrementing by Step cannot cause
1063 // signed overflow as long as the value of the recurrence within the loop does
1064 // not exceed this limit before incrementing.
1065 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1066                                            ICmpInst::Predicate *Pred,
1067                                            ScalarEvolution *SE) {
1068   unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1069   if (SE->isKnownPositive(Step)) {
1070     *Pred = ICmpInst::ICMP_SLT;
1071     return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1072                            SE->getSignedRange(Step).getSignedMax());
1073   }
1074   if (SE->isKnownNegative(Step)) {
1075     *Pred = ICmpInst::ICMP_SGT;
1076     return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1077                        SE->getSignedRange(Step).getSignedMin());
1078   }
1079   return 0;
1080 }
1081
1082 // The recurrence AR has been shown to have no signed wrap. Typically, if we can
1083 // prove NSW for AR, then we can just as easily prove NSW for its preincrement
1084 // or postincrement sibling. This allows normalizing a sign extended AddRec as
1085 // such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a
1086 // result, the expression "Step + sext(PreIncAR)" is congruent with
1087 // "sext(PostIncAR)"
1088 static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR,
1089                                             Type *Ty,
1090                                             ScalarEvolution *SE) {
1091   const Loop *L = AR->getLoop();
1092   const SCEV *Start = AR->getStart();
1093   const SCEV *Step = AR->getStepRecurrence(*SE);
1094
1095   // Check for a simple looking step prior to loop entry.
1096   const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1097   if (!SA)
1098     return 0;
1099
1100   // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1101   // subtraction is expensive. For this purpose, perform a quick and dirty
1102   // difference, by checking for Step in the operand list.
1103   SmallVector<const SCEV *, 4> DiffOps;
1104   for (SCEVAddExpr::op_iterator I = SA->op_begin(), E = SA->op_end();
1105        I != E; ++I) {
1106     if (*I != Step)
1107       DiffOps.push_back(*I);
1108   }
1109   if (DiffOps.size() == SA->getNumOperands())
1110     return 0;
1111
1112   // This is a postinc AR. Check for overflow on the preinc recurrence using the
1113   // same three conditions that getSignExtendedExpr checks.
1114
1115   // 1. NSW flags on the step increment.
1116   const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
1117   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1118     SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1119
1120   if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
1121     return PreStart;
1122
1123   // 2. Direct overflow check on the step operation's expression.
1124   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1125   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1126   const SCEV *OperandExtendedStart =
1127     SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy),
1128                    SE->getSignExtendExpr(Step, WideTy));
1129   if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) {
1130     // Cache knowledge of PreAR NSW.
1131     if (PreAR)
1132       const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(SCEV::FlagNSW);
1133     // FIXME: this optimization needs a unit test
1134     DEBUG(dbgs() << "SCEV: untested prestart overflow check\n");
1135     return PreStart;
1136   }
1137
1138   // 3. Loop precondition.
1139   ICmpInst::Predicate Pred;
1140   const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE);
1141
1142   if (OverflowLimit &&
1143       SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
1144     return PreStart;
1145   }
1146   return 0;
1147 }
1148
1149 // Get the normalized sign-extended expression for this AddRec's Start.
1150 static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR,
1151                                             Type *Ty,
1152                                             ScalarEvolution *SE) {
1153   const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE);
1154   if (!PreStart)
1155     return SE->getSignExtendExpr(AR->getStart(), Ty);
1156
1157   return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty),
1158                         SE->getSignExtendExpr(PreStart, Ty));
1159 }
1160
1161 const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
1162                                                Type *Ty) {
1163   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1164          "This is not an extending conversion!");
1165   assert(isSCEVable(Ty) &&
1166          "This is not a conversion to a SCEVable type!");
1167   Ty = getEffectiveSCEVType(Ty);
1168
1169   // Fold if the operand is constant.
1170   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1171     return getConstant(
1172       cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(),
1173                                               getEffectiveSCEVType(Ty))));
1174
1175   // sext(sext(x)) --> sext(x)
1176   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1177     return getSignExtendExpr(SS->getOperand(), Ty);
1178
1179   // sext(zext(x)) --> zext(x)
1180   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1181     return getZeroExtendExpr(SZ->getOperand(), Ty);
1182
1183   // Before doing any expensive analysis, check to see if we've already
1184   // computed a SCEV for this Op and Ty.
1185   FoldingSetNodeID ID;
1186   ID.AddInteger(scSignExtend);
1187   ID.AddPointer(Op);
1188   ID.AddPointer(Ty);
1189   void *IP = 0;
1190   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1191
1192   // If the input value is provably positive, build a zext instead.
1193   if (isKnownNonNegative(Op))
1194     return getZeroExtendExpr(Op, Ty);
1195
1196   // sext(trunc(x)) --> sext(x) or x or trunc(x)
1197   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1198     // It's possible the bits taken off by the truncate were all sign bits. If
1199     // so, we should be able to simplify this further.
1200     const SCEV *X = ST->getOperand();
1201     ConstantRange CR = getSignedRange(X);
1202     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1203     unsigned NewBits = getTypeSizeInBits(Ty);
1204     if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1205             CR.sextOrTrunc(NewBits)))
1206       return getTruncateOrSignExtend(X, Ty);
1207   }
1208
1209   // If the input value is a chrec scev, and we can prove that the value
1210   // did not overflow the old, smaller, value, we can sign extend all of the
1211   // operands (often constants).  This allows analysis of something like
1212   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1213   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1214     if (AR->isAffine()) {
1215       const SCEV *Start = AR->getStart();
1216       const SCEV *Step = AR->getStepRecurrence(*this);
1217       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1218       const Loop *L = AR->getLoop();
1219
1220       // If we have special knowledge that this addrec won't overflow,
1221       // we don't need to do any further analysis.
1222       if (AR->getNoWrapFlags(SCEV::FlagNSW))
1223         return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1224                              getSignExtendExpr(Step, Ty),
1225                              L, SCEV::FlagNSW);
1226
1227       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1228       // Note that this serves two purposes: It filters out loops that are
1229       // simply not analyzable, and it covers the case where this code is
1230       // being called from within backedge-taken count analysis, such that
1231       // attempting to ask for the backedge-taken count would likely result
1232       // in infinite recursion. In the later case, the analysis code will
1233       // cope with a conservative value, and it will take care to purge
1234       // that value once it has finished.
1235       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1236       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1237         // Manually compute the final value for AR, checking for
1238         // overflow.
1239
1240         // Check whether the backedge-taken count can be losslessly casted to
1241         // the addrec's type. The count is always unsigned.
1242         const SCEV *CastedMaxBECount =
1243           getTruncateOrZeroExtend(MaxBECount, Start->getType());
1244         const SCEV *RecastedMaxBECount =
1245           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1246         if (MaxBECount == RecastedMaxBECount) {
1247           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1248           // Check whether Start+Step*MaxBECount has no signed overflow.
1249           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1250           const SCEV *Add = getAddExpr(Start, SMul);
1251           const SCEV *OperandExtendedAdd =
1252             getAddExpr(getSignExtendExpr(Start, WideTy),
1253                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1254                                   getSignExtendExpr(Step, WideTy)));
1255           if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1256             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1257             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1258             // Return the expression with the addrec on the outside.
1259             return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1260                                  getSignExtendExpr(Step, Ty),
1261                                  L, AR->getNoWrapFlags());
1262           }
1263           // Similar to above, only this time treat the step value as unsigned.
1264           // This covers loops that count up with an unsigned step.
1265           const SCEV *UMul = getMulExpr(CastedMaxBECount, Step);
1266           Add = getAddExpr(Start, UMul);
1267           OperandExtendedAdd =
1268             getAddExpr(getSignExtendExpr(Start, WideTy),
1269                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1270                                   getZeroExtendExpr(Step, WideTy)));
1271           if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1272             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1273             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1274             // Return the expression with the addrec on the outside.
1275             return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1276                                  getZeroExtendExpr(Step, Ty),
1277                                  L, AR->getNoWrapFlags());
1278           }
1279         }
1280
1281         // If the backedge is guarded by a comparison with the pre-inc value
1282         // the addrec is safe. Also, if the entry is guarded by a comparison
1283         // with the start value and the backedge is guarded by a comparison
1284         // with the post-inc value, the addrec is safe.
1285         ICmpInst::Predicate Pred;
1286         const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this);
1287         if (OverflowLimit &&
1288             (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1289              (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1290               isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1291                                           OverflowLimit)))) {
1292           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1293           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1294           return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1295                                getSignExtendExpr(Step, Ty),
1296                                L, AR->getNoWrapFlags());
1297         }
1298       }
1299     }
1300
1301   // The cast wasn't folded; create an explicit cast node.
1302   // Recompute the insert position, as it may have been invalidated.
1303   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1304   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1305                                                    Op, Ty);
1306   UniqueSCEVs.InsertNode(S, IP);
1307   return S;
1308 }
1309
1310 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
1311 /// unspecified bits out to the given type.
1312 ///
1313 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1314                                               Type *Ty) {
1315   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1316          "This is not an extending conversion!");
1317   assert(isSCEVable(Ty) &&
1318          "This is not a conversion to a SCEVable type!");
1319   Ty = getEffectiveSCEVType(Ty);
1320
1321   // Sign-extend negative constants.
1322   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1323     if (SC->getValue()->getValue().isNegative())
1324       return getSignExtendExpr(Op, Ty);
1325
1326   // Peel off a truncate cast.
1327   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
1328     const SCEV *NewOp = T->getOperand();
1329     if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
1330       return getAnyExtendExpr(NewOp, Ty);
1331     return getTruncateOrNoop(NewOp, Ty);
1332   }
1333
1334   // Next try a zext cast. If the cast is folded, use it.
1335   const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
1336   if (!isa<SCEVZeroExtendExpr>(ZExt))
1337     return ZExt;
1338
1339   // Next try a sext cast. If the cast is folded, use it.
1340   const SCEV *SExt = getSignExtendExpr(Op, Ty);
1341   if (!isa<SCEVSignExtendExpr>(SExt))
1342     return SExt;
1343
1344   // Force the cast to be folded into the operands of an addrec.
1345   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
1346     SmallVector<const SCEV *, 4> Ops;
1347     for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
1348          I != E; ++I)
1349       Ops.push_back(getAnyExtendExpr(*I, Ty));
1350     return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
1351   }
1352
1353   // As a special case, fold anyext(undef) to undef. We don't want to
1354   // know too much about SCEVUnknowns, but this special case is handy
1355   // and harmless.
1356   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
1357     if (isa<UndefValue>(U->getValue()))
1358       return getSCEV(UndefValue::get(Ty));
1359
1360   // If the expression is obviously signed, use the sext cast value.
1361   if (isa<SCEVSMaxExpr>(Op))
1362     return SExt;
1363
1364   // Absent any other information, use the zext cast value.
1365   return ZExt;
1366 }
1367
1368 /// CollectAddOperandsWithScales - Process the given Ops list, which is
1369 /// a list of operands to be added under the given scale, update the given
1370 /// map. This is a helper function for getAddRecExpr. As an example of
1371 /// what it does, given a sequence of operands that would form an add
1372 /// expression like this:
1373 ///
1374 ///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
1375 ///
1376 /// where A and B are constants, update the map with these values:
1377 ///
1378 ///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
1379 ///
1380 /// and add 13 + A*B*29 to AccumulatedConstant.
1381 /// This will allow getAddRecExpr to produce this:
1382 ///
1383 ///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
1384 ///
1385 /// This form often exposes folding opportunities that are hidden in
1386 /// the original operand list.
1387 ///
1388 /// Return true iff it appears that any interesting folding opportunities
1389 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
1390 /// the common case where no interesting opportunities are present, and
1391 /// is also used as a check to avoid infinite recursion.
1392 ///
1393 static bool
1394 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1395                              SmallVector<const SCEV *, 8> &NewOps,
1396                              APInt &AccumulatedConstant,
1397                              const SCEV *const *Ops, size_t NumOperands,
1398                              const APInt &Scale,
1399                              ScalarEvolution &SE) {
1400   bool Interesting = false;
1401
1402   // Iterate over the add operands. They are sorted, with constants first.
1403   unsigned i = 0;
1404   while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1405     ++i;
1406     // Pull a buried constant out to the outside.
1407     if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1408       Interesting = true;
1409     AccumulatedConstant += Scale * C->getValue()->getValue();
1410   }
1411
1412   // Next comes everything else. We're especially interested in multiplies
1413   // here, but they're in the middle, so just visit the rest with one loop.
1414   for (; i != NumOperands; ++i) {
1415     const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1416     if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1417       APInt NewScale =
1418         Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
1419       if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1420         // A multiplication of a constant with another add; recurse.
1421         const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1422         Interesting |=
1423           CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1424                                        Add->op_begin(), Add->getNumOperands(),
1425                                        NewScale, SE);
1426       } else {
1427         // A multiplication of a constant with some other value. Update
1428         // the map.
1429         SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1430         const SCEV *Key = SE.getMulExpr(MulOps);
1431         std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1432           M.insert(std::make_pair(Key, NewScale));
1433         if (Pair.second) {
1434           NewOps.push_back(Pair.first->first);
1435         } else {
1436           Pair.first->second += NewScale;
1437           // The map already had an entry for this value, which may indicate
1438           // a folding opportunity.
1439           Interesting = true;
1440         }
1441       }
1442     } else {
1443       // An ordinary operand. Update the map.
1444       std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1445         M.insert(std::make_pair(Ops[i], Scale));
1446       if (Pair.second) {
1447         NewOps.push_back(Pair.first->first);
1448       } else {
1449         Pair.first->second += Scale;
1450         // The map already had an entry for this value, which may indicate
1451         // a folding opportunity.
1452         Interesting = true;
1453       }
1454     }
1455   }
1456
1457   return Interesting;
1458 }
1459
1460 namespace {
1461   struct APIntCompare {
1462     bool operator()(const APInt &LHS, const APInt &RHS) const {
1463       return LHS.ult(RHS);
1464     }
1465   };
1466 }
1467
1468 /// getAddExpr - Get a canonical add expression, or something simpler if
1469 /// possible.
1470 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
1471                                         SCEV::NoWrapFlags Flags) {
1472   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
1473          "only nuw or nsw allowed");
1474   assert(!Ops.empty() && "Cannot get empty add!");
1475   if (Ops.size() == 1) return Ops[0];
1476 #ifndef NDEBUG
1477   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1478   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1479     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1480            "SCEVAddExpr operand types don't match!");
1481 #endif
1482
1483   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1484   // And vice-versa.
1485   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1486   SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1487   if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1488     bool All = true;
1489     for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1490          E = Ops.end(); I != E; ++I)
1491       if (!isKnownNonNegative(*I)) {
1492         All = false;
1493         break;
1494       }
1495     if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1496   }
1497
1498   // Sort by complexity, this groups all similar expression types together.
1499   GroupByComplexity(Ops, LI);
1500
1501   // If there are any constants, fold them together.
1502   unsigned Idx = 0;
1503   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1504     ++Idx;
1505     assert(Idx < Ops.size());
1506     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1507       // We found two constants, fold them together!
1508       Ops[0] = getConstant(LHSC->getValue()->getValue() +
1509                            RHSC->getValue()->getValue());
1510       if (Ops.size() == 2) return Ops[0];
1511       Ops.erase(Ops.begin()+1);  // Erase the folded element
1512       LHSC = cast<SCEVConstant>(Ops[0]);
1513     }
1514
1515     // If we are left with a constant zero being added, strip it off.
1516     if (LHSC->getValue()->isZero()) {
1517       Ops.erase(Ops.begin());
1518       --Idx;
1519     }
1520
1521     if (Ops.size() == 1) return Ops[0];
1522   }
1523
1524   // Okay, check to see if the same value occurs in the operand list more than
1525   // once.  If so, merge them together into an multiply expression.  Since we
1526   // sorted the list, these values are required to be adjacent.
1527   Type *Ty = Ops[0]->getType();
1528   bool FoundMatch = false;
1529   for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
1530     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
1531       // Scan ahead to count how many equal operands there are.
1532       unsigned Count = 2;
1533       while (i+Count != e && Ops[i+Count] == Ops[i])
1534         ++Count;
1535       // Merge the values into a multiply.
1536       const SCEV *Scale = getConstant(Ty, Count);
1537       const SCEV *Mul = getMulExpr(Scale, Ops[i]);
1538       if (Ops.size() == Count)
1539         return Mul;
1540       Ops[i] = Mul;
1541       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
1542       --i; e -= Count - 1;
1543       FoundMatch = true;
1544     }
1545   if (FoundMatch)
1546     return getAddExpr(Ops, Flags);
1547
1548   // Check for truncates. If all the operands are truncated from the same
1549   // type, see if factoring out the truncate would permit the result to be
1550   // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
1551   // if the contents of the resulting outer trunc fold to something simple.
1552   for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
1553     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
1554     Type *DstType = Trunc->getType();
1555     Type *SrcType = Trunc->getOperand()->getType();
1556     SmallVector<const SCEV *, 8> LargeOps;
1557     bool Ok = true;
1558     // Check all the operands to see if they can be represented in the
1559     // source type of the truncate.
1560     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1561       if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1562         if (T->getOperand()->getType() != SrcType) {
1563           Ok = false;
1564           break;
1565         }
1566         LargeOps.push_back(T->getOperand());
1567       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1568         LargeOps.push_back(getAnyExtendExpr(C, SrcType));
1569       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1570         SmallVector<const SCEV *, 8> LargeMulOps;
1571         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1572           if (const SCEVTruncateExpr *T =
1573                 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1574             if (T->getOperand()->getType() != SrcType) {
1575               Ok = false;
1576               break;
1577             }
1578             LargeMulOps.push_back(T->getOperand());
1579           } else if (const SCEVConstant *C =
1580                        dyn_cast<SCEVConstant>(M->getOperand(j))) {
1581             LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
1582           } else {
1583             Ok = false;
1584             break;
1585           }
1586         }
1587         if (Ok)
1588           LargeOps.push_back(getMulExpr(LargeMulOps));
1589       } else {
1590         Ok = false;
1591         break;
1592       }
1593     }
1594     if (Ok) {
1595       // Evaluate the expression in the larger type.
1596       const SCEV *Fold = getAddExpr(LargeOps, Flags);
1597       // If it folds to something simple, use it. Otherwise, don't.
1598       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1599         return getTruncateExpr(Fold, DstType);
1600     }
1601   }
1602
1603   // Skip past any other cast SCEVs.
1604   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1605     ++Idx;
1606
1607   // If there are add operands they would be next.
1608   if (Idx < Ops.size()) {
1609     bool DeletedAdd = false;
1610     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1611       // If we have an add, expand the add operands onto the end of the operands
1612       // list.
1613       Ops.erase(Ops.begin()+Idx);
1614       Ops.append(Add->op_begin(), Add->op_end());
1615       DeletedAdd = true;
1616     }
1617
1618     // If we deleted at least one add, we added operands to the end of the list,
1619     // and they are not necessarily sorted.  Recurse to resort and resimplify
1620     // any operands we just acquired.
1621     if (DeletedAdd)
1622       return getAddExpr(Ops);
1623   }
1624
1625   // Skip over the add expression until we get to a multiply.
1626   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1627     ++Idx;
1628
1629   // Check to see if there are any folding opportunities present with
1630   // operands multiplied by constant values.
1631   if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
1632     uint64_t BitWidth = getTypeSizeInBits(Ty);
1633     DenseMap<const SCEV *, APInt> M;
1634     SmallVector<const SCEV *, 8> NewOps;
1635     APInt AccumulatedConstant(BitWidth, 0);
1636     if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1637                                      Ops.data(), Ops.size(),
1638                                      APInt(BitWidth, 1), *this)) {
1639       // Some interesting folding opportunity is present, so its worthwhile to
1640       // re-generate the operands list. Group the operands by constant scale,
1641       // to avoid multiplying by the same constant scale multiple times.
1642       std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
1643       for (SmallVector<const SCEV *, 8>::const_iterator I = NewOps.begin(),
1644            E = NewOps.end(); I != E; ++I)
1645         MulOpLists[M.find(*I)->second].push_back(*I);
1646       // Re-generate the operands list.
1647       Ops.clear();
1648       if (AccumulatedConstant != 0)
1649         Ops.push_back(getConstant(AccumulatedConstant));
1650       for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1651            I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1652         if (I->first != 0)
1653           Ops.push_back(getMulExpr(getConstant(I->first),
1654                                    getAddExpr(I->second)));
1655       if (Ops.empty())
1656         return getConstant(Ty, 0);
1657       if (Ops.size() == 1)
1658         return Ops[0];
1659       return getAddExpr(Ops);
1660     }
1661   }
1662
1663   // If we are adding something to a multiply expression, make sure the
1664   // something is not already an operand of the multiply.  If so, merge it into
1665   // the multiply.
1666   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1667     const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1668     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1669       const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1670       if (isa<SCEVConstant>(MulOpSCEV))
1671         continue;
1672       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1673         if (MulOpSCEV == Ops[AddOp]) {
1674           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1675           const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
1676           if (Mul->getNumOperands() != 2) {
1677             // If the multiply has more than two operands, we must get the
1678             // Y*Z term.
1679             SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1680                                                 Mul->op_begin()+MulOp);
1681             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1682             InnerMul = getMulExpr(MulOps);
1683           }
1684           const SCEV *One = getConstant(Ty, 1);
1685           const SCEV *AddOne = getAddExpr(One, InnerMul);
1686           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
1687           if (Ops.size() == 2) return OuterMul;
1688           if (AddOp < Idx) {
1689             Ops.erase(Ops.begin()+AddOp);
1690             Ops.erase(Ops.begin()+Idx-1);
1691           } else {
1692             Ops.erase(Ops.begin()+Idx);
1693             Ops.erase(Ops.begin()+AddOp-1);
1694           }
1695           Ops.push_back(OuterMul);
1696           return getAddExpr(Ops);
1697         }
1698
1699       // Check this multiply against other multiplies being added together.
1700       for (unsigned OtherMulIdx = Idx+1;
1701            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1702            ++OtherMulIdx) {
1703         const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1704         // If MulOp occurs in OtherMul, we can fold the two multiplies
1705         // together.
1706         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1707              OMulOp != e; ++OMulOp)
1708           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1709             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1710             const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
1711             if (Mul->getNumOperands() != 2) {
1712               SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1713                                                   Mul->op_begin()+MulOp);
1714               MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1715               InnerMul1 = getMulExpr(MulOps);
1716             }
1717             const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1718             if (OtherMul->getNumOperands() != 2) {
1719               SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1720                                                   OtherMul->op_begin()+OMulOp);
1721               MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
1722               InnerMul2 = getMulExpr(MulOps);
1723             }
1724             const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1725             const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1726             if (Ops.size() == 2) return OuterMul;
1727             Ops.erase(Ops.begin()+Idx);
1728             Ops.erase(Ops.begin()+OtherMulIdx-1);
1729             Ops.push_back(OuterMul);
1730             return getAddExpr(Ops);
1731           }
1732       }
1733     }
1734   }
1735
1736   // If there are any add recurrences in the operands list, see if any other
1737   // added values are loop invariant.  If so, we can fold them into the
1738   // recurrence.
1739   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1740     ++Idx;
1741
1742   // Scan over all recurrences, trying to fold loop invariants into them.
1743   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1744     // Scan all of the other operands to this add and add them to the vector if
1745     // they are loop invariant w.r.t. the recurrence.
1746     SmallVector<const SCEV *, 8> LIOps;
1747     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1748     const Loop *AddRecLoop = AddRec->getLoop();
1749     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1750       if (isLoopInvariant(Ops[i], AddRecLoop)) {
1751         LIOps.push_back(Ops[i]);
1752         Ops.erase(Ops.begin()+i);
1753         --i; --e;
1754       }
1755
1756     // If we found some loop invariants, fold them into the recurrence.
1757     if (!LIOps.empty()) {
1758       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1759       LIOps.push_back(AddRec->getStart());
1760
1761       SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1762                                              AddRec->op_end());
1763       AddRecOps[0] = getAddExpr(LIOps);
1764
1765       // Build the new addrec. Propagate the NUW and NSW flags if both the
1766       // outer add and the inner addrec are guaranteed to have no overflow.
1767       // Always propagate NW.
1768       Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
1769       const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
1770
1771       // If all of the other operands were loop invariant, we are done.
1772       if (Ops.size() == 1) return NewRec;
1773
1774       // Otherwise, add the folded AddRec by the non-invariant parts.
1775       for (unsigned i = 0;; ++i)
1776         if (Ops[i] == AddRec) {
1777           Ops[i] = NewRec;
1778           break;
1779         }
1780       return getAddExpr(Ops);
1781     }
1782
1783     // Okay, if there weren't any loop invariants to be folded, check to see if
1784     // there are multiple AddRec's with the same loop induction variable being
1785     // added together.  If so, we can fold them.
1786     for (unsigned OtherIdx = Idx+1;
1787          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1788          ++OtherIdx)
1789       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
1790         // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
1791         SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1792                                                AddRec->op_end());
1793         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1794              ++OtherIdx)
1795           if (const SCEVAddRecExpr *OtherAddRec =
1796                 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
1797             if (OtherAddRec->getLoop() == AddRecLoop) {
1798               for (unsigned i = 0, e = OtherAddRec->getNumOperands();
1799                    i != e; ++i) {
1800                 if (i >= AddRecOps.size()) {
1801                   AddRecOps.append(OtherAddRec->op_begin()+i,
1802                                    OtherAddRec->op_end());
1803                   break;
1804                 }
1805                 AddRecOps[i] = getAddExpr(AddRecOps[i],
1806                                           OtherAddRec->getOperand(i));
1807               }
1808               Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
1809             }
1810         // Step size has changed, so we cannot guarantee no self-wraparound.
1811         Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
1812         return getAddExpr(Ops);
1813       }
1814
1815     // Otherwise couldn't fold anything into this recurrence.  Move onto the
1816     // next one.
1817   }
1818
1819   // Okay, it looks like we really DO need an add expr.  Check to see if we
1820   // already have one, otherwise create a new one.
1821   FoldingSetNodeID ID;
1822   ID.AddInteger(scAddExpr);
1823   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1824     ID.AddPointer(Ops[i]);
1825   void *IP = 0;
1826   SCEVAddExpr *S =
1827     static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1828   if (!S) {
1829     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
1830     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
1831     S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
1832                                         O, Ops.size());
1833     UniqueSCEVs.InsertNode(S, IP);
1834   }
1835   S->setNoWrapFlags(Flags);
1836   return S;
1837 }
1838
1839 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
1840   uint64_t k = i*j;
1841   if (j > 1 && k / j != i) Overflow = true;
1842   return k;
1843 }
1844
1845 /// Compute the result of "n choose k", the binomial coefficient.  If an
1846 /// intermediate computation overflows, Overflow will be set and the return will
1847 /// be garbage. Overflow is not cleared on absense of overflow.
1848 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
1849   // We use the multiplicative formula:
1850   //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
1851   // At each iteration, we take the n-th term of the numeral and divide by the
1852   // (k-n)th term of the denominator.  This division will always produce an
1853   // integral result, and helps reduce the chance of overflow in the
1854   // intermediate computations. However, we can still overflow even when the
1855   // final result would fit.
1856
1857   if (n == 0 || n == k) return 1;
1858   if (k > n) return 0;
1859
1860   if (k > n/2)
1861     k = n-k;
1862
1863   uint64_t r = 1;
1864   for (uint64_t i = 1; i <= k; ++i) {
1865     r = umul_ov(r, n-(i-1), Overflow);
1866     r /= i;
1867   }
1868   return r;
1869 }
1870
1871 /// getMulExpr - Get a canonical multiply expression, or something simpler if
1872 /// possible.
1873 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
1874                                         SCEV::NoWrapFlags Flags) {
1875   assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
1876          "only nuw or nsw allowed");
1877   assert(!Ops.empty() && "Cannot get empty mul!");
1878   if (Ops.size() == 1) return Ops[0];
1879 #ifndef NDEBUG
1880   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1881   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1882     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1883            "SCEVMulExpr operand types don't match!");
1884 #endif
1885
1886   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1887   // And vice-versa.
1888   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1889   SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1890   if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1891     bool All = true;
1892     for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1893          E = Ops.end(); I != E; ++I)
1894       if (!isKnownNonNegative(*I)) {
1895         All = false;
1896         break;
1897       }
1898     if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1899   }
1900
1901   // Sort by complexity, this groups all similar expression types together.
1902   GroupByComplexity(Ops, LI);
1903
1904   // If there are any constants, fold them together.
1905   unsigned Idx = 0;
1906   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1907
1908     // C1*(C2+V) -> C1*C2 + C1*V
1909     if (Ops.size() == 2)
1910       if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1911         if (Add->getNumOperands() == 2 &&
1912             isa<SCEVConstant>(Add->getOperand(0)))
1913           return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1914                             getMulExpr(LHSC, Add->getOperand(1)));
1915
1916     ++Idx;
1917     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1918       // We found two constants, fold them together!
1919       ConstantInt *Fold = ConstantInt::get(getContext(),
1920                                            LHSC->getValue()->getValue() *
1921                                            RHSC->getValue()->getValue());
1922       Ops[0] = getConstant(Fold);
1923       Ops.erase(Ops.begin()+1);  // Erase the folded element
1924       if (Ops.size() == 1) return Ops[0];
1925       LHSC = cast<SCEVConstant>(Ops[0]);
1926     }
1927
1928     // If we are left with a constant one being multiplied, strip it off.
1929     if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1930       Ops.erase(Ops.begin());
1931       --Idx;
1932     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1933       // If we have a multiply of zero, it will always be zero.
1934       return Ops[0];
1935     } else if (Ops[0]->isAllOnesValue()) {
1936       // If we have a mul by -1 of an add, try distributing the -1 among the
1937       // add operands.
1938       if (Ops.size() == 2) {
1939         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
1940           SmallVector<const SCEV *, 4> NewOps;
1941           bool AnyFolded = false;
1942           for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
1943                  E = Add->op_end(); I != E; ++I) {
1944             const SCEV *Mul = getMulExpr(Ops[0], *I);
1945             if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
1946             NewOps.push_back(Mul);
1947           }
1948           if (AnyFolded)
1949             return getAddExpr(NewOps);
1950         }
1951         else if (const SCEVAddRecExpr *
1952                  AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
1953           // Negation preserves a recurrence's no self-wrap property.
1954           SmallVector<const SCEV *, 4> Operands;
1955           for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
1956                  E = AddRec->op_end(); I != E; ++I) {
1957             Operands.push_back(getMulExpr(Ops[0], *I));
1958           }
1959           return getAddRecExpr(Operands, AddRec->getLoop(),
1960                                AddRec->getNoWrapFlags(SCEV::FlagNW));
1961         }
1962       }
1963     }
1964
1965     if (Ops.size() == 1)
1966       return Ops[0];
1967   }
1968
1969   // Skip over the add expression until we get to a multiply.
1970   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1971     ++Idx;
1972
1973   // If there are mul operands inline them all into this expression.
1974   if (Idx < Ops.size()) {
1975     bool DeletedMul = false;
1976     while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1977       // If we have an mul, expand the mul operands onto the end of the operands
1978       // list.
1979       Ops.erase(Ops.begin()+Idx);
1980       Ops.append(Mul->op_begin(), Mul->op_end());
1981       DeletedMul = true;
1982     }
1983
1984     // If we deleted at least one mul, we added operands to the end of the list,
1985     // and they are not necessarily sorted.  Recurse to resort and resimplify
1986     // any operands we just acquired.
1987     if (DeletedMul)
1988       return getMulExpr(Ops);
1989   }
1990
1991   // If there are any add recurrences in the operands list, see if any other
1992   // added values are loop invariant.  If so, we can fold them into the
1993   // recurrence.
1994   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1995     ++Idx;
1996
1997   // Scan over all recurrences, trying to fold loop invariants into them.
1998   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1999     // Scan all of the other operands to this mul and add them to the vector if
2000     // they are loop invariant w.r.t. the recurrence.
2001     SmallVector<const SCEV *, 8> LIOps;
2002     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2003     const Loop *AddRecLoop = AddRec->getLoop();
2004     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2005       if (isLoopInvariant(Ops[i], AddRecLoop)) {
2006         LIOps.push_back(Ops[i]);
2007         Ops.erase(Ops.begin()+i);
2008         --i; --e;
2009       }
2010
2011     // If we found some loop invariants, fold them into the recurrence.
2012     if (!LIOps.empty()) {
2013       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
2014       SmallVector<const SCEV *, 4> NewOps;
2015       NewOps.reserve(AddRec->getNumOperands());
2016       const SCEV *Scale = getMulExpr(LIOps);
2017       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2018         NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
2019
2020       // Build the new addrec. Propagate the NUW and NSW flags if both the
2021       // outer mul and the inner addrec are guaranteed to have no overflow.
2022       //
2023       // No self-wrap cannot be guaranteed after changing the step size, but
2024       // will be inferred if either NUW or NSW is true.
2025       Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2026       const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2027
2028       // If all of the other operands were loop invariant, we are done.
2029       if (Ops.size() == 1) return NewRec;
2030
2031       // Otherwise, multiply the folded AddRec by the non-invariant parts.
2032       for (unsigned i = 0;; ++i)
2033         if (Ops[i] == AddRec) {
2034           Ops[i] = NewRec;
2035           break;
2036         }
2037       return getMulExpr(Ops);
2038     }
2039
2040     // Okay, if there weren't any loop invariants to be folded, check to see if
2041     // there are multiple AddRec's with the same loop induction variable being
2042     // multiplied together.  If so, we can fold them.
2043     for (unsigned OtherIdx = Idx+1;
2044          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2045          ++OtherIdx) {
2046       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2047         // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2048         // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2049         //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
2050         //   ]]],+,...up to x=2n}.
2051         // Note that the arguments to choose() are always integers with values
2052         // known at compile time, never SCEV objects.
2053         //
2054         // The implementation avoids pointless extra computations when the two
2055         // addrec's are of different length (mathematically, it's equivalent to
2056         // an infinite stream of zeros on the right).
2057         bool OpsModified = false;
2058         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2059              ++OtherIdx)
2060           if (const SCEVAddRecExpr *OtherAddRec =
2061                 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
2062             if (OtherAddRec->getLoop() == AddRecLoop) {
2063               bool Overflow = false;
2064               Type *Ty = AddRec->getType();
2065               bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2066               SmallVector<const SCEV*, 7> AddRecOps;
2067               for (int x = 0, xe = AddRec->getNumOperands() +
2068                      OtherAddRec->getNumOperands() - 1;
2069                    x != xe && !Overflow; ++x) {
2070                 const SCEV *Term = getConstant(Ty, 0);
2071                 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2072                   uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2073                   for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2074                          ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2075                        z < ze && !Overflow; ++z) {
2076                     uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2077                     uint64_t Coeff;
2078                     if (LargerThan64Bits)
2079                       Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2080                     else
2081                       Coeff = Coeff1*Coeff2;
2082                     const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2083                     const SCEV *Term1 = AddRec->getOperand(y-z);
2084                     const SCEV *Term2 = OtherAddRec->getOperand(z);
2085                     Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
2086                   }
2087                 }
2088                 AddRecOps.push_back(Term);
2089               }
2090               if (!Overflow) {
2091                 const SCEV *NewAddRec = getAddRecExpr(AddRecOps,
2092                                                       AddRec->getLoop(),
2093                                                       SCEV::FlagAnyWrap);
2094                 if (Ops.size() == 2) return NewAddRec;
2095                 Ops[Idx] = AddRec = cast<SCEVAddRecExpr>(NewAddRec);
2096                 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2097                 OpsModified = true;
2098               }
2099             }
2100         if (OpsModified)
2101           return getMulExpr(Ops);
2102       }
2103     }
2104
2105     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2106     // next one.
2107   }
2108
2109   // Okay, it looks like we really DO need an mul expr.  Check to see if we
2110   // already have one, otherwise create a new one.
2111   FoldingSetNodeID ID;
2112   ID.AddInteger(scMulExpr);
2113   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2114     ID.AddPointer(Ops[i]);
2115   void *IP = 0;
2116   SCEVMulExpr *S =
2117     static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2118   if (!S) {
2119     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2120     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2121     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2122                                         O, Ops.size());
2123     UniqueSCEVs.InsertNode(S, IP);
2124   }
2125   S->setNoWrapFlags(Flags);
2126   return S;
2127 }
2128
2129 /// getUDivExpr - Get a canonical unsigned division expression, or something
2130 /// simpler if possible.
2131 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2132                                          const SCEV *RHS) {
2133   assert(getEffectiveSCEVType(LHS->getType()) ==
2134          getEffectiveSCEVType(RHS->getType()) &&
2135          "SCEVUDivExpr operand types don't match!");
2136
2137   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2138     if (RHSC->getValue()->equalsInt(1))
2139       return LHS;                               // X udiv 1 --> x
2140     // If the denominator is zero, the result of the udiv is undefined. Don't
2141     // try to analyze it, because the resolution chosen here may differ from
2142     // the resolution chosen in other parts of the compiler.
2143     if (!RHSC->getValue()->isZero()) {
2144       // Determine if the division can be folded into the operands of
2145       // its operands.
2146       // TODO: Generalize this to non-constants by using known-bits information.
2147       Type *Ty = LHS->getType();
2148       unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
2149       unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2150       // For non-power-of-two values, effectively round the value up to the
2151       // nearest power of two.
2152       if (!RHSC->getValue()->getValue().isPowerOf2())
2153         ++MaxShiftAmt;
2154       IntegerType *ExtTy =
2155         IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
2156       if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2157         if (const SCEVConstant *Step =
2158             dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
2159           // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
2160           const APInt &StepInt = Step->getValue()->getValue();
2161           const APInt &DivInt = RHSC->getValue()->getValue();
2162           if (!StepInt.urem(DivInt) &&
2163               getZeroExtendExpr(AR, ExtTy) ==
2164               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2165                             getZeroExtendExpr(Step, ExtTy),
2166                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2167             SmallVector<const SCEV *, 4> Operands;
2168             for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
2169               Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
2170             return getAddRecExpr(Operands, AR->getLoop(),
2171                                  SCEV::FlagNW);
2172           }
2173           /// Get a canonical UDivExpr for a recurrence.
2174           /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
2175           // We can currently only fold X%N if X is constant.
2176           const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
2177           if (StartC && !DivInt.urem(StepInt) &&
2178               getZeroExtendExpr(AR, ExtTy) ==
2179               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2180                             getZeroExtendExpr(Step, ExtTy),
2181                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2182             const APInt &StartInt = StartC->getValue()->getValue();
2183             const APInt &StartRem = StartInt.urem(StepInt);
2184             if (StartRem != 0)
2185               LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
2186                                   AR->getLoop(), SCEV::FlagNW);
2187           }
2188         }
2189       // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
2190       if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2191         SmallVector<const SCEV *, 4> Operands;
2192         for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
2193           Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
2194         if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2195           // Find an operand that's safely divisible.
2196           for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
2197             const SCEV *Op = M->getOperand(i);
2198             const SCEV *Div = getUDivExpr(Op, RHSC);
2199             if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
2200               Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
2201                                                       M->op_end());
2202               Operands[i] = Div;
2203               return getMulExpr(Operands);
2204             }
2205           }
2206       }
2207       // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
2208       if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
2209         SmallVector<const SCEV *, 4> Operands;
2210         for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
2211           Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
2212         if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
2213           Operands.clear();
2214           for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2215             const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
2216             if (isa<SCEVUDivExpr>(Op) ||
2217                 getMulExpr(Op, RHS) != A->getOperand(i))
2218               break;
2219             Operands.push_back(Op);
2220           }
2221           if (Operands.size() == A->getNumOperands())
2222             return getAddExpr(Operands);
2223         }
2224       }
2225
2226       // Fold if both operands are constant.
2227       if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
2228         Constant *LHSCV = LHSC->getValue();
2229         Constant *RHSCV = RHSC->getValue();
2230         return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
2231                                                                    RHSCV)));
2232       }
2233     }
2234   }
2235
2236   FoldingSetNodeID ID;
2237   ID.AddInteger(scUDivExpr);
2238   ID.AddPointer(LHS);
2239   ID.AddPointer(RHS);
2240   void *IP = 0;
2241   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2242   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
2243                                              LHS, RHS);
2244   UniqueSCEVs.InsertNode(S, IP);
2245   return S;
2246 }
2247
2248
2249 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
2250 /// Simplify the expression as much as possible.
2251 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
2252                                            const Loop *L,
2253                                            SCEV::NoWrapFlags Flags) {
2254   SmallVector<const SCEV *, 4> Operands;
2255   Operands.push_back(Start);
2256   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
2257     if (StepChrec->getLoop() == L) {
2258       Operands.append(StepChrec->op_begin(), StepChrec->op_end());
2259       return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
2260     }
2261
2262   Operands.push_back(Step);
2263   return getAddRecExpr(Operands, L, Flags);
2264 }
2265
2266 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
2267 /// Simplify the expression as much as possible.
2268 const SCEV *
2269 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
2270                                const Loop *L, SCEV::NoWrapFlags Flags) {
2271   if (Operands.size() == 1) return Operands[0];
2272 #ifndef NDEBUG
2273   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
2274   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
2275     assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
2276            "SCEVAddRecExpr operand types don't match!");
2277   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2278     assert(isLoopInvariant(Operands[i], L) &&
2279            "SCEVAddRecExpr operand is not loop-invariant!");
2280 #endif
2281
2282   if (Operands.back()->isZero()) {
2283     Operands.pop_back();
2284     return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
2285   }
2286
2287   // It's tempting to want to call getMaxBackedgeTakenCount count here and
2288   // use that information to infer NUW and NSW flags. However, computing a
2289   // BE count requires calling getAddRecExpr, so we may not yet have a
2290   // meaningful BE count at this point (and if we don't, we'd be stuck
2291   // with a SCEVCouldNotCompute as the cached BE count).
2292
2293   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2294   // And vice-versa.
2295   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2296   SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
2297   if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
2298     bool All = true;
2299     for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
2300          E = Operands.end(); I != E; ++I)
2301       if (!isKnownNonNegative(*I)) {
2302         All = false;
2303         break;
2304       }
2305     if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2306   }
2307
2308   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
2309   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
2310     const Loop *NestedLoop = NestedAR->getLoop();
2311     if (L->contains(NestedLoop) ?
2312         (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
2313         (!NestedLoop->contains(L) &&
2314          DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
2315       SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
2316                                                   NestedAR->op_end());
2317       Operands[0] = NestedAR->getStart();
2318       // AddRecs require their operands be loop-invariant with respect to their
2319       // loops. Don't perform this transformation if it would break this
2320       // requirement.
2321       bool AllInvariant = true;
2322       for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2323         if (!isLoopInvariant(Operands[i], L)) {
2324           AllInvariant = false;
2325           break;
2326         }
2327       if (AllInvariant) {
2328         // Create a recurrence for the outer loop with the same step size.
2329         //
2330         // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
2331         // inner recurrence has the same property.
2332         SCEV::NoWrapFlags OuterFlags =
2333           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
2334
2335         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
2336         AllInvariant = true;
2337         for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
2338           if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
2339             AllInvariant = false;
2340             break;
2341           }
2342         if (AllInvariant) {
2343           // Ok, both add recurrences are valid after the transformation.
2344           //
2345           // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
2346           // the outer recurrence has the same property.
2347           SCEV::NoWrapFlags InnerFlags =
2348             maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
2349           return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
2350         }
2351       }
2352       // Reset Operands to its original state.
2353       Operands[0] = NestedAR;
2354     }
2355   }
2356
2357   // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2358   // already have one, otherwise create a new one.
2359   FoldingSetNodeID ID;
2360   ID.AddInteger(scAddRecExpr);
2361   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2362     ID.AddPointer(Operands[i]);
2363   ID.AddPointer(L);
2364   void *IP = 0;
2365   SCEVAddRecExpr *S =
2366     static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2367   if (!S) {
2368     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
2369     std::uninitialized_copy(Operands.begin(), Operands.end(), O);
2370     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
2371                                            O, Operands.size(), L);
2372     UniqueSCEVs.InsertNode(S, IP);
2373   }
2374   S->setNoWrapFlags(Flags);
2375   return S;
2376 }
2377
2378 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
2379                                          const SCEV *RHS) {
2380   SmallVector<const SCEV *, 2> Ops;
2381   Ops.push_back(LHS);
2382   Ops.push_back(RHS);
2383   return getSMaxExpr(Ops);
2384 }
2385
2386 const SCEV *
2387 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2388   assert(!Ops.empty() && "Cannot get empty smax!");
2389   if (Ops.size() == 1) return Ops[0];
2390 #ifndef NDEBUG
2391   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2392   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2393     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2394            "SCEVSMaxExpr operand types don't match!");
2395 #endif
2396
2397   // Sort by complexity, this groups all similar expression types together.
2398   GroupByComplexity(Ops, LI);
2399
2400   // If there are any constants, fold them together.
2401   unsigned Idx = 0;
2402   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2403     ++Idx;
2404     assert(Idx < Ops.size());
2405     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2406       // We found two constants, fold them together!
2407       ConstantInt *Fold = ConstantInt::get(getContext(),
2408                               APIntOps::smax(LHSC->getValue()->getValue(),
2409                                              RHSC->getValue()->getValue()));
2410       Ops[0] = getConstant(Fold);
2411       Ops.erase(Ops.begin()+1);  // Erase the folded element
2412       if (Ops.size() == 1) return Ops[0];
2413       LHSC = cast<SCEVConstant>(Ops[0]);
2414     }
2415
2416     // If we are left with a constant minimum-int, strip it off.
2417     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
2418       Ops.erase(Ops.begin());
2419       --Idx;
2420     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
2421       // If we have an smax with a constant maximum-int, it will always be
2422       // maximum-int.
2423       return Ops[0];
2424     }
2425
2426     if (Ops.size() == 1) return Ops[0];
2427   }
2428
2429   // Find the first SMax
2430   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
2431     ++Idx;
2432
2433   // Check to see if one of the operands is an SMax. If so, expand its operands
2434   // onto our operand list, and recurse to simplify.
2435   if (Idx < Ops.size()) {
2436     bool DeletedSMax = false;
2437     while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
2438       Ops.erase(Ops.begin()+Idx);
2439       Ops.append(SMax->op_begin(), SMax->op_end());
2440       DeletedSMax = true;
2441     }
2442
2443     if (DeletedSMax)
2444       return getSMaxExpr(Ops);
2445   }
2446
2447   // Okay, check to see if the same value occurs in the operand list twice.  If
2448   // so, delete one.  Since we sorted the list, these values are required to
2449   // be adjacent.
2450   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2451     //  X smax Y smax Y  -->  X smax Y
2452     //  X smax Y         -->  X, if X is always greater than Y
2453     if (Ops[i] == Ops[i+1] ||
2454         isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
2455       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2456       --i; --e;
2457     } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
2458       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2459       --i; --e;
2460     }
2461
2462   if (Ops.size() == 1) return Ops[0];
2463
2464   assert(!Ops.empty() && "Reduced smax down to nothing!");
2465
2466   // Okay, it looks like we really DO need an smax expr.  Check to see if we
2467   // already have one, otherwise create a new one.
2468   FoldingSetNodeID ID;
2469   ID.AddInteger(scSMaxExpr);
2470   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2471     ID.AddPointer(Ops[i]);
2472   void *IP = 0;
2473   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2474   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2475   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2476   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
2477                                              O, Ops.size());
2478   UniqueSCEVs.InsertNode(S, IP);
2479   return S;
2480 }
2481
2482 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
2483                                          const SCEV *RHS) {
2484   SmallVector<const SCEV *, 2> Ops;
2485   Ops.push_back(LHS);
2486   Ops.push_back(RHS);
2487   return getUMaxExpr(Ops);
2488 }
2489
2490 const SCEV *
2491 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2492   assert(!Ops.empty() && "Cannot get empty umax!");
2493   if (Ops.size() == 1) return Ops[0];
2494 #ifndef NDEBUG
2495   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2496   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2497     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2498            "SCEVUMaxExpr operand types don't match!");
2499 #endif
2500
2501   // Sort by complexity, this groups all similar expression types together.
2502   GroupByComplexity(Ops, LI);
2503
2504   // If there are any constants, fold them together.
2505   unsigned Idx = 0;
2506   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2507     ++Idx;
2508     assert(Idx < Ops.size());
2509     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2510       // We found two constants, fold them together!
2511       ConstantInt *Fold = ConstantInt::get(getContext(),
2512                               APIntOps::umax(LHSC->getValue()->getValue(),
2513                                              RHSC->getValue()->getValue()));
2514       Ops[0] = getConstant(Fold);
2515       Ops.erase(Ops.begin()+1);  // Erase the folded element
2516       if (Ops.size() == 1) return Ops[0];
2517       LHSC = cast<SCEVConstant>(Ops[0]);
2518     }
2519
2520     // If we are left with a constant minimum-int, strip it off.
2521     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
2522       Ops.erase(Ops.begin());
2523       --Idx;
2524     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
2525       // If we have an umax with a constant maximum-int, it will always be
2526       // maximum-int.
2527       return Ops[0];
2528     }
2529
2530     if (Ops.size() == 1) return Ops[0];
2531   }
2532
2533   // Find the first UMax
2534   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
2535     ++Idx;
2536
2537   // Check to see if one of the operands is a UMax. If so, expand its operands
2538   // onto our operand list, and recurse to simplify.
2539   if (Idx < Ops.size()) {
2540     bool DeletedUMax = false;
2541     while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
2542       Ops.erase(Ops.begin()+Idx);
2543       Ops.append(UMax->op_begin(), UMax->op_end());
2544       DeletedUMax = true;
2545     }
2546
2547     if (DeletedUMax)
2548       return getUMaxExpr(Ops);
2549   }
2550
2551   // Okay, check to see if the same value occurs in the operand list twice.  If
2552   // so, delete one.  Since we sorted the list, these values are required to
2553   // be adjacent.
2554   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2555     //  X umax Y umax Y  -->  X umax Y
2556     //  X umax Y         -->  X, if X is always greater than Y
2557     if (Ops[i] == Ops[i+1] ||
2558         isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
2559       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2560       --i; --e;
2561     } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
2562       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2563       --i; --e;
2564     }
2565
2566   if (Ops.size() == 1) return Ops[0];
2567
2568   assert(!Ops.empty() && "Reduced umax down to nothing!");
2569
2570   // Okay, it looks like we really DO need a umax expr.  Check to see if we
2571   // already have one, otherwise create a new one.
2572   FoldingSetNodeID ID;
2573   ID.AddInteger(scUMaxExpr);
2574   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2575     ID.AddPointer(Ops[i]);
2576   void *IP = 0;
2577   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2578   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2579   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2580   SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
2581                                              O, Ops.size());
2582   UniqueSCEVs.InsertNode(S, IP);
2583   return S;
2584 }
2585
2586 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
2587                                          const SCEV *RHS) {
2588   // ~smax(~x, ~y) == smin(x, y).
2589   return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2590 }
2591
2592 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
2593                                          const SCEV *RHS) {
2594   // ~umax(~x, ~y) == umin(x, y)
2595   return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2596 }
2597
2598 const SCEV *ScalarEvolution::getSizeOfExpr(Type *AllocTy) {
2599   // If we have TargetData, we can bypass creating a target-independent
2600   // constant expression and then folding it back into a ConstantInt.
2601   // This is just a compile-time optimization.
2602   if (TD)
2603     return getConstant(TD->getIntPtrType(getContext()),
2604                        TD->getTypeAllocSize(AllocTy));
2605
2606   Constant *C = ConstantExpr::getSizeOf(AllocTy);
2607   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2608     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2609       C = Folded;
2610   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2611   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2612 }
2613
2614 const SCEV *ScalarEvolution::getAlignOfExpr(Type *AllocTy) {
2615   Constant *C = ConstantExpr::getAlignOf(AllocTy);
2616   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2617     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2618       C = Folded;
2619   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2620   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2621 }
2622
2623 const SCEV *ScalarEvolution::getOffsetOfExpr(StructType *STy,
2624                                              unsigned FieldNo) {
2625   // If we have TargetData, we can bypass creating a target-independent
2626   // constant expression and then folding it back into a ConstantInt.
2627   // This is just a compile-time optimization.
2628   if (TD)
2629     return getConstant(TD->getIntPtrType(getContext()),
2630                        TD->getStructLayout(STy)->getElementOffset(FieldNo));
2631
2632   Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
2633   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2634     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2635       C = Folded;
2636   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
2637   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2638 }
2639
2640 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *CTy,
2641                                              Constant *FieldNo) {
2642   Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo);
2643   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2644     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2645       C = Folded;
2646   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy));
2647   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2648 }
2649
2650 const SCEV *ScalarEvolution::getUnknown(Value *V) {
2651   // Don't attempt to do anything other than create a SCEVUnknown object
2652   // here.  createSCEV only calls getUnknown after checking for all other
2653   // interesting possibilities, and any other code that calls getUnknown
2654   // is doing so in order to hide a value from SCEV canonicalization.
2655
2656   FoldingSetNodeID ID;
2657   ID.AddInteger(scUnknown);
2658   ID.AddPointer(V);
2659   void *IP = 0;
2660   if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
2661     assert(cast<SCEVUnknown>(S)->getValue() == V &&
2662            "Stale SCEVUnknown in uniquing map!");
2663     return S;
2664   }
2665   SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
2666                                             FirstUnknown);
2667   FirstUnknown = cast<SCEVUnknown>(S);
2668   UniqueSCEVs.InsertNode(S, IP);
2669   return S;
2670 }
2671
2672 //===----------------------------------------------------------------------===//
2673 //            Basic SCEV Analysis and PHI Idiom Recognition Code
2674 //
2675
2676 /// isSCEVable - Test if values of the given type are analyzable within
2677 /// the SCEV framework. This primarily includes integer types, and it
2678 /// can optionally include pointer types if the ScalarEvolution class
2679 /// has access to target-specific information.
2680 bool ScalarEvolution::isSCEVable(Type *Ty) const {
2681   // Integers and pointers are always SCEVable.
2682   return Ty->isIntegerTy() || Ty->isPointerTy();
2683 }
2684
2685 /// getTypeSizeInBits - Return the size in bits of the specified type,
2686 /// for which isSCEVable must return true.
2687 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
2688   assert(isSCEVable(Ty) && "Type is not SCEVable!");
2689
2690   // If we have a TargetData, use it!
2691   if (TD)
2692     return TD->getTypeSizeInBits(Ty);
2693
2694   // Integer types have fixed sizes.
2695   if (Ty->isIntegerTy())
2696     return Ty->getPrimitiveSizeInBits();
2697
2698   // The only other support type is pointer. Without TargetData, conservatively
2699   // assume pointers are 64-bit.
2700   assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
2701   return 64;
2702 }
2703
2704 /// getEffectiveSCEVType - Return a type with the same bitwidth as
2705 /// the given type and which represents how SCEV will treat the given
2706 /// type, for which isSCEVable must return true. For pointer types,
2707 /// this is the pointer-sized integer type.
2708 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
2709   assert(isSCEVable(Ty) && "Type is not SCEVable!");
2710
2711   if (Ty->isIntegerTy())
2712     return Ty;
2713
2714   // The only other support type is pointer.
2715   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
2716   if (TD) return TD->getIntPtrType(getContext());
2717
2718   // Without TargetData, conservatively assume pointers are 64-bit.
2719   return Type::getInt64Ty(getContext());
2720 }
2721
2722 const SCEV *ScalarEvolution::getCouldNotCompute() {
2723   return &CouldNotCompute;
2724 }
2725
2726 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2727 /// expression and create a new one.
2728 const SCEV *ScalarEvolution::getSCEV(Value *V) {
2729   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2730
2731   ValueExprMapType::const_iterator I = ValueExprMap.find(V);
2732   if (I != ValueExprMap.end()) return I->second;
2733   const SCEV *S = createSCEV(V);
2734
2735   // The process of creating a SCEV for V may have caused other SCEVs
2736   // to have been created, so it's necessary to insert the new entry
2737   // from scratch, rather than trying to remember the insert position
2738   // above.
2739   ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2740   return S;
2741 }
2742
2743 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2744 ///
2745 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
2746   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2747     return getConstant(
2748                cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
2749
2750   Type *Ty = V->getType();
2751   Ty = getEffectiveSCEVType(Ty);
2752   return getMulExpr(V,
2753                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
2754 }
2755
2756 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2757 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
2758   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2759     return getConstant(
2760                 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
2761
2762   Type *Ty = V->getType();
2763   Ty = getEffectiveSCEVType(Ty);
2764   const SCEV *AllOnes =
2765                    getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
2766   return getMinusSCEV(AllOnes, V);
2767 }
2768
2769 /// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1.
2770 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
2771                                           SCEV::NoWrapFlags Flags) {
2772   assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");
2773
2774   // Fast path: X - X --> 0.
2775   if (LHS == RHS)
2776     return getConstant(LHS->getType(), 0);
2777
2778   // X - Y --> X + -Y
2779   return getAddExpr(LHS, getNegativeSCEV(RHS), Flags);
2780 }
2781
2782 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2783 /// input value to the specified type.  If the type must be extended, it is zero
2784 /// extended.
2785 const SCEV *
2786 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
2787   Type *SrcTy = V->getType();
2788   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2789          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2790          "Cannot truncate or zero extend with non-integer arguments!");
2791   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2792     return V;  // No conversion
2793   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2794     return getTruncateExpr(V, Ty);
2795   return getZeroExtendExpr(V, Ty);
2796 }
2797
2798 /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2799 /// input value to the specified type.  If the type must be extended, it is sign
2800 /// extended.
2801 const SCEV *
2802 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
2803                                          Type *Ty) {
2804   Type *SrcTy = V->getType();
2805   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2806          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2807          "Cannot truncate or zero extend with non-integer arguments!");
2808   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2809     return V;  // No conversion
2810   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2811     return getTruncateExpr(V, Ty);
2812   return getSignExtendExpr(V, Ty);
2813 }
2814
2815 /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2816 /// input value to the specified type.  If the type must be extended, it is zero
2817 /// extended.  The conversion must not be narrowing.
2818 const SCEV *
2819 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
2820   Type *SrcTy = V->getType();
2821   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2822          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2823          "Cannot noop or zero extend with non-integer arguments!");
2824   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2825          "getNoopOrZeroExtend cannot truncate!");
2826   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2827     return V;  // No conversion
2828   return getZeroExtendExpr(V, Ty);
2829 }
2830
2831 /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2832 /// input value to the specified type.  If the type must be extended, it is sign
2833 /// extended.  The conversion must not be narrowing.
2834 const SCEV *
2835 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
2836   Type *SrcTy = V->getType();
2837   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2838          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2839          "Cannot noop or sign extend with non-integer arguments!");
2840   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2841          "getNoopOrSignExtend cannot truncate!");
2842   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2843     return V;  // No conversion
2844   return getSignExtendExpr(V, Ty);
2845 }
2846
2847 /// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2848 /// the input value to the specified type. If the type must be extended,
2849 /// it is extended with unspecified bits. The conversion must not be
2850 /// narrowing.
2851 const SCEV *
2852 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
2853   Type *SrcTy = V->getType();
2854   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2855          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2856          "Cannot noop or any extend with non-integer arguments!");
2857   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2858          "getNoopOrAnyExtend cannot truncate!");
2859   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2860     return V;  // No conversion
2861   return getAnyExtendExpr(V, Ty);
2862 }
2863
2864 /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2865 /// input value to the specified type.  The conversion must not be widening.
2866 const SCEV *
2867 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
2868   Type *SrcTy = V->getType();
2869   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2870          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2871          "Cannot truncate or noop with non-integer arguments!");
2872   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2873          "getTruncateOrNoop cannot extend!");
2874   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2875     return V;  // No conversion
2876   return getTruncateExpr(V, Ty);
2877 }
2878
2879 /// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2880 /// the types using zero-extension, and then perform a umax operation
2881 /// with them.
2882 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2883                                                         const SCEV *RHS) {
2884   const SCEV *PromotedLHS = LHS;
2885   const SCEV *PromotedRHS = RHS;
2886
2887   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2888     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2889   else
2890     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2891
2892   return getUMaxExpr(PromotedLHS, PromotedRHS);
2893 }
2894
2895 /// getUMinFromMismatchedTypes - Promote the operands to the wider of
2896 /// the types using zero-extension, and then perform a umin operation
2897 /// with them.
2898 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2899                                                         const SCEV *RHS) {
2900   const SCEV *PromotedLHS = LHS;
2901   const SCEV *PromotedRHS = RHS;
2902
2903   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2904     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2905   else
2906     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2907
2908   return getUMinExpr(PromotedLHS, PromotedRHS);
2909 }
2910
2911 /// getPointerBase - Transitively follow the chain of pointer-type operands
2912 /// until reaching a SCEV that does not have a single pointer operand. This
2913 /// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
2914 /// but corner cases do exist.
2915 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
2916   // A pointer operand may evaluate to a nonpointer expression, such as null.
2917   if (!V->getType()->isPointerTy())
2918     return V;
2919
2920   if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
2921     return getPointerBase(Cast->getOperand());
2922   }
2923   else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
2924     const SCEV *PtrOp = 0;
2925     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
2926          I != E; ++I) {
2927       if ((*I)->getType()->isPointerTy()) {
2928         // Cannot find the base of an expression with multiple pointer operands.
2929         if (PtrOp)
2930           return V;
2931         PtrOp = *I;
2932       }
2933     }
2934     if (!PtrOp)
2935       return V;
2936     return getPointerBase(PtrOp);
2937   }
2938   return V;
2939 }
2940
2941 /// PushDefUseChildren - Push users of the given Instruction
2942 /// onto the given Worklist.
2943 static void
2944 PushDefUseChildren(Instruction *I,
2945                    SmallVectorImpl<Instruction *> &Worklist) {
2946   // Push the def-use children onto the Worklist stack.
2947   for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2948        UI != UE; ++UI)
2949     Worklist.push_back(cast<Instruction>(*UI));
2950 }
2951
2952 /// ForgetSymbolicValue - This looks up computed SCEV values for all
2953 /// instructions that depend on the given instruction and removes them from
2954 /// the ValueExprMapType map if they reference SymName. This is used during PHI
2955 /// resolution.
2956 void
2957 ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
2958   SmallVector<Instruction *, 16> Worklist;
2959   PushDefUseChildren(PN, Worklist);
2960
2961   SmallPtrSet<Instruction *, 8> Visited;
2962   Visited.insert(PN);
2963   while (!Worklist.empty()) {
2964     Instruction *I = Worklist.pop_back_val();
2965     if (!Visited.insert(I)) continue;
2966
2967     ValueExprMapType::iterator It =
2968       ValueExprMap.find(static_cast<Value *>(I));
2969     if (It != ValueExprMap.end()) {
2970       const SCEV *Old = It->second;
2971
2972       // Short-circuit the def-use traversal if the symbolic name
2973       // ceases to appear in expressions.
2974       if (Old != SymName && !hasOperand(Old, SymName))
2975         continue;
2976
2977       // SCEVUnknown for a PHI either means that it has an unrecognized
2978       // structure, it's a PHI that's in the progress of being computed
2979       // by createNodeForPHI, or it's a single-value PHI. In the first case,
2980       // additional loop trip count information isn't going to change anything.
2981       // In the second case, createNodeForPHI will perform the necessary
2982       // updates on its own when it gets to that point. In the third, we do
2983       // want to forget the SCEVUnknown.
2984       if (!isa<PHINode>(I) ||
2985           !isa<SCEVUnknown>(Old) ||
2986           (I != PN && Old == SymName)) {
2987         forgetMemoizedResults(Old);
2988         ValueExprMap.erase(It);
2989       }
2990     }
2991
2992     PushDefUseChildren(I, Worklist);
2993   }
2994 }
2995
2996 /// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2997 /// a loop header, making it a potential recurrence, or it doesn't.
2998 ///
2999 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
3000   if (const Loop *L = LI->getLoopFor(PN->getParent()))
3001     if (L->getHeader() == PN->getParent()) {
3002       // The loop may have multiple entrances or multiple exits; we can analyze
3003       // this phi as an addrec if it has a unique entry value and a unique
3004       // backedge value.
3005       Value *BEValueV = 0, *StartValueV = 0;
3006       for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
3007         Value *V = PN->getIncomingValue(i);
3008         if (L->contains(PN->getIncomingBlock(i))) {
3009           if (!BEValueV) {
3010             BEValueV = V;
3011           } else if (BEValueV != V) {
3012             BEValueV = 0;
3013             break;
3014           }
3015         } else if (!StartValueV) {
3016           StartValueV = V;
3017         } else if (StartValueV != V) {
3018           StartValueV = 0;
3019           break;
3020         }
3021       }
3022       if (BEValueV && StartValueV) {
3023         // While we are analyzing this PHI node, handle its value symbolically.
3024         const SCEV *SymbolicName = getUnknown(PN);
3025         assert(ValueExprMap.find(PN) == ValueExprMap.end() &&
3026                "PHI node already processed?");
3027         ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
3028
3029         // Using this symbolic name for the PHI, analyze the value coming around
3030         // the back-edge.
3031         const SCEV *BEValue = getSCEV(BEValueV);
3032
3033         // NOTE: If BEValue is loop invariant, we know that the PHI node just
3034         // has a special value for the first iteration of the loop.
3035
3036         // If the value coming around the backedge is an add with the symbolic
3037         // value we just inserted, then we found a simple induction variable!
3038         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
3039           // If there is a single occurrence of the symbolic value, replace it
3040           // with a recurrence.
3041           unsigned FoundIndex = Add->getNumOperands();
3042           for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3043             if (Add->getOperand(i) == SymbolicName)
3044               if (FoundIndex == e) {
3045                 FoundIndex = i;
3046                 break;
3047               }
3048
3049           if (FoundIndex != Add->getNumOperands()) {
3050             // Create an add with everything but the specified operand.
3051             SmallVector<const SCEV *, 8> Ops;
3052             for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3053               if (i != FoundIndex)
3054                 Ops.push_back(Add->getOperand(i));
3055             const SCEV *Accum = getAddExpr(Ops);
3056
3057             // This is not a valid addrec if the step amount is varying each
3058             // loop iteration, but is not itself an addrec in this loop.
3059             if (isLoopInvariant(Accum, L) ||
3060                 (isa<SCEVAddRecExpr>(Accum) &&
3061                  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
3062               SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
3063
3064               // If the increment doesn't overflow, then neither the addrec nor
3065               // the post-increment will overflow.
3066               if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
3067                 if (OBO->hasNoUnsignedWrap())
3068                   Flags = setFlags(Flags, SCEV::FlagNUW);
3069                 if (OBO->hasNoSignedWrap())
3070                   Flags = setFlags(Flags, SCEV::FlagNSW);
3071               } else if (const GEPOperator *GEP =
3072                          dyn_cast<GEPOperator>(BEValueV)) {
3073                 // If the increment is an inbounds GEP, then we know the address
3074                 // space cannot be wrapped around. We cannot make any guarantee
3075                 // about signed or unsigned overflow because pointers are
3076                 // unsigned but we may have a negative index from the base
3077                 // pointer.
3078                 if (GEP->isInBounds())
3079                   Flags = setFlags(Flags, SCEV::FlagNW);
3080               }
3081
3082               const SCEV *StartVal = getSCEV(StartValueV);
3083               const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
3084
3085               // Since the no-wrap flags are on the increment, they apply to the
3086               // post-incremented value as well.
3087               if (isLoopInvariant(Accum, L))
3088                 (void)getAddRecExpr(getAddExpr(StartVal, Accum),
3089                                     Accum, L, Flags);
3090
3091               // Okay, for the entire analysis of this edge we assumed the PHI
3092               // to be symbolic.  We now need to go back and purge all of the
3093               // entries for the scalars that use the symbolic expression.
3094               ForgetSymbolicName(PN, SymbolicName);
3095               ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3096               return PHISCEV;
3097             }
3098           }
3099         } else if (const SCEVAddRecExpr *AddRec =
3100                      dyn_cast<SCEVAddRecExpr>(BEValue)) {
3101           // Otherwise, this could be a loop like this:
3102           //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
3103           // In this case, j = {1,+,1}  and BEValue is j.
3104           // Because the other in-value of i (0) fits the evolution of BEValue
3105           // i really is an addrec evolution.
3106           if (AddRec->getLoop() == L && AddRec->isAffine()) {
3107             const SCEV *StartVal = getSCEV(StartValueV);
3108
3109             // If StartVal = j.start - j.stride, we can use StartVal as the
3110             // initial step of the addrec evolution.
3111             if (StartVal == getMinusSCEV(AddRec->getOperand(0),
3112                                          AddRec->getOperand(1))) {
3113               // FIXME: For constant StartVal, we should be able to infer
3114               // no-wrap flags.
3115               const SCEV *PHISCEV =
3116                 getAddRecExpr(StartVal, AddRec->getOperand(1), L,
3117                               SCEV::FlagAnyWrap);
3118
3119               // Okay, for the entire analysis of this edge we assumed the PHI
3120               // to be symbolic.  We now need to go back and purge all of the
3121               // entries for the scalars that use the symbolic expression.
3122               ForgetSymbolicName(PN, SymbolicName);
3123               ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3124               return PHISCEV;
3125             }
3126           }
3127         }
3128       }
3129     }
3130
3131   // If the PHI has a single incoming value, follow that value, unless the
3132   // PHI's incoming blocks are in a different loop, in which case doing so
3133   // risks breaking LCSSA form. Instcombine would normally zap these, but
3134   // it doesn't have DominatorTree information, so it may miss cases.
3135   if (Value *V = SimplifyInstruction(PN, TD, TLI, DT))
3136     if (LI->replacementPreservesLCSSAForm(PN, V))
3137       return getSCEV(V);
3138
3139   // If it's not a loop phi, we can't handle it yet.
3140   return getUnknown(PN);
3141 }
3142
3143 /// createNodeForGEP - Expand GEP instructions into add and multiply
3144 /// operations. This allows them to be analyzed by regular SCEV code.
3145 ///
3146 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
3147
3148   // Don't blindly transfer the inbounds flag from the GEP instruction to the
3149   // Add expression, because the Instruction may be guarded by control flow
3150   // and the no-overflow bits may not be valid for the expression in any
3151   // context.
3152   bool isInBounds = GEP->isInBounds();
3153
3154   Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
3155   Value *Base = GEP->getOperand(0);
3156   // Don't attempt to analyze GEPs over unsized objects.
3157   if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
3158     return getUnknown(GEP);
3159   const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
3160   gep_type_iterator GTI = gep_type_begin(GEP);
3161   for (GetElementPtrInst::op_iterator I = llvm::next(GEP->op_begin()),
3162                                       E = GEP->op_end();
3163        I != E; ++I) {
3164     Value *Index = *I;
3165     // Compute the (potentially symbolic) offset in bytes for this index.
3166     if (StructType *STy = dyn_cast<StructType>(*GTI++)) {
3167       // For a struct, add the member offset.
3168       unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
3169       const SCEV *FieldOffset = getOffsetOfExpr(STy, FieldNo);
3170
3171       // Add the field offset to the running total offset.
3172       TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3173     } else {
3174       // For an array, add the element offset, explicitly scaled.
3175       const SCEV *ElementSize = getSizeOfExpr(*GTI);
3176       const SCEV *IndexS = getSCEV(Index);
3177       // Getelementptr indices are signed.
3178       IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
3179
3180       // Multiply the index by the element size to compute the element offset.
3181       const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize,
3182                                            isInBounds ? SCEV::FlagNSW :
3183                                            SCEV::FlagAnyWrap);
3184
3185       // Add the element offset to the running total offset.
3186       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3187     }
3188   }
3189
3190   // Get the SCEV for the GEP base.
3191   const SCEV *BaseS = getSCEV(Base);
3192
3193   // Add the total offset from all the GEP indices to the base.
3194   return getAddExpr(BaseS, TotalOffset,
3195                     isInBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap);
3196 }
3197
3198 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
3199 /// guaranteed to end in (at every loop iteration).  It is, at the same time,
3200 /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
3201 /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
3202 uint32_t
3203 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
3204   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3205     return C->getValue()->getValue().countTrailingZeros();
3206
3207   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
3208     return std::min(GetMinTrailingZeros(T->getOperand()),
3209                     (uint32_t)getTypeSizeInBits(T->getType()));
3210
3211   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
3212     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3213     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3214              getTypeSizeInBits(E->getType()) : OpRes;
3215   }
3216
3217   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
3218     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3219     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3220              getTypeSizeInBits(E->getType()) : OpRes;
3221   }
3222
3223   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
3224     // The result is the min of all operands results.
3225     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3226     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3227       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3228     return MinOpRes;
3229   }
3230
3231   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
3232     // The result is the sum of all operands results.
3233     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
3234     uint32_t BitWidth = getTypeSizeInBits(M->getType());
3235     for (unsigned i = 1, e = M->getNumOperands();
3236          SumOpRes != BitWidth && i != e; ++i)
3237       SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
3238                           BitWidth);
3239     return SumOpRes;
3240   }
3241
3242   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
3243     // The result is the min of all operands results.
3244     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3245     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3246       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3247     return MinOpRes;
3248   }
3249
3250   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
3251     // The result is the min of all operands results.
3252     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3253     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3254       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3255     return MinOpRes;
3256   }
3257
3258   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
3259     // The result is the min of all operands results.
3260     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3261     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3262       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3263     return MinOpRes;
3264   }
3265
3266   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3267     // For a SCEVUnknown, ask ValueTracking.
3268     unsigned BitWidth = getTypeSizeInBits(U->getType());
3269     APInt Mask = APInt::getAllOnesValue(BitWidth);
3270     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3271     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
3272     return Zeros.countTrailingOnes();
3273   }
3274
3275   // SCEVUDivExpr
3276   return 0;
3277 }
3278
3279 /// getUnsignedRange - Determine the unsigned range for a particular SCEV.
3280 ///
3281 ConstantRange
3282 ScalarEvolution::getUnsignedRange(const SCEV *S) {
3283   // See if we've computed this range already.
3284   DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S);
3285   if (I != UnsignedRanges.end())
3286     return I->second;
3287
3288   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3289     return setUnsignedRange(C, ConstantRange(C->getValue()->getValue()));
3290
3291   unsigned BitWidth = getTypeSizeInBits(S->getType());
3292   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3293
3294   // If the value has known zeros, the maximum unsigned value will have those
3295   // known zeros as well.
3296   uint32_t TZ = GetMinTrailingZeros(S);
3297   if (TZ != 0)
3298     ConservativeResult =
3299       ConstantRange(APInt::getMinValue(BitWidth),
3300                     APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
3301
3302   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3303     ConstantRange X = getUnsignedRange(Add->getOperand(0));
3304     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3305       X = X.add(getUnsignedRange(Add->getOperand(i)));
3306     return setUnsignedRange(Add, ConservativeResult.intersectWith(X));
3307   }
3308
3309   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3310     ConstantRange X = getUnsignedRange(Mul->getOperand(0));
3311     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3312       X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
3313     return setUnsignedRange(Mul, ConservativeResult.intersectWith(X));
3314   }
3315
3316   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3317     ConstantRange X = getUnsignedRange(SMax->getOperand(0));
3318     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3319       X = X.smax(getUnsignedRange(SMax->getOperand(i)));
3320     return setUnsignedRange(SMax, ConservativeResult.intersectWith(X));
3321   }
3322
3323   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3324     ConstantRange X = getUnsignedRange(UMax->getOperand(0));
3325     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3326       X = X.umax(getUnsignedRange(UMax->getOperand(i)));
3327     return setUnsignedRange(UMax, ConservativeResult.intersectWith(X));
3328   }
3329
3330   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3331     ConstantRange X = getUnsignedRange(UDiv->getLHS());
3332     ConstantRange Y = getUnsignedRange(UDiv->getRHS());
3333     return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3334   }
3335
3336   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3337     ConstantRange X = getUnsignedRange(ZExt->getOperand());
3338     return setUnsignedRange(ZExt,
3339       ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3340   }
3341
3342   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3343     ConstantRange X = getUnsignedRange(SExt->getOperand());
3344     return setUnsignedRange(SExt,
3345       ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3346   }
3347
3348   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3349     ConstantRange X = getUnsignedRange(Trunc->getOperand());
3350     return setUnsignedRange(Trunc,
3351       ConservativeResult.intersectWith(X.truncate(BitWidth)));
3352   }
3353
3354   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3355     // If there's no unsigned wrap, the value will never be less than its
3356     // initial value.
3357     if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
3358       if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
3359         if (!C->getValue()->isZero())
3360           ConservativeResult =
3361             ConservativeResult.intersectWith(
3362               ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
3363
3364     // TODO: non-affine addrec
3365     if (AddRec->isAffine()) {
3366       Type *Ty = AddRec->getType();
3367       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3368       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3369           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3370         MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3371
3372         const SCEV *Start = AddRec->getStart();
3373         const SCEV *Step = AddRec->getStepRecurrence(*this);
3374
3375         ConstantRange StartRange = getUnsignedRange(Start);
3376         ConstantRange StepRange = getSignedRange(Step);
3377         ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3378         ConstantRange EndRange =
3379           StartRange.add(MaxBECountRange.multiply(StepRange));
3380
3381         // Check for overflow. This must be done with ConstantRange arithmetic
3382         // because we could be called from within the ScalarEvolution overflow
3383         // checking code.
3384         ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1);
3385         ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3386         ConstantRange ExtMaxBECountRange =
3387           MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3388         ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1);
3389         if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3390             ExtEndRange)
3391           return setUnsignedRange(AddRec, ConservativeResult);
3392
3393         APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
3394                                    EndRange.getUnsignedMin());
3395         APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
3396                                    EndRange.getUnsignedMax());
3397         if (Min.isMinValue() && Max.isMaxValue())
3398           return setUnsignedRange(AddRec, ConservativeResult);
3399         return setUnsignedRange(AddRec,
3400           ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3401       }
3402     }
3403
3404     return setUnsignedRange(AddRec, ConservativeResult);
3405   }
3406
3407   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3408     // For a SCEVUnknown, ask ValueTracking.
3409     APInt Mask = APInt::getAllOnesValue(BitWidth);
3410     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3411     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
3412     if (Ones == ~Zeros + 1)
3413       return setUnsignedRange(U, ConservativeResult);
3414     return setUnsignedRange(U,
3415       ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)));
3416   }
3417
3418   return setUnsignedRange(S, ConservativeResult);
3419 }
3420
3421 /// getSignedRange - Determine the signed range for a particular SCEV.
3422 ///
3423 ConstantRange
3424 ScalarEvolution::getSignedRange(const SCEV *S) {
3425   // See if we've computed this range already.
3426   DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
3427   if (I != SignedRanges.end())
3428     return I->second;
3429
3430   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3431     return setSignedRange(C, ConstantRange(C->getValue()->getValue()));
3432
3433   unsigned BitWidth = getTypeSizeInBits(S->getType());
3434   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3435
3436   // If the value has known zeros, the maximum signed value will have those
3437   // known zeros as well.
3438   uint32_t TZ = GetMinTrailingZeros(S);
3439   if (TZ != 0)
3440     ConservativeResult =
3441       ConstantRange(APInt::getSignedMinValue(BitWidth),
3442                     APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
3443
3444   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3445     ConstantRange X = getSignedRange(Add->getOperand(0));
3446     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3447       X = X.add(getSignedRange(Add->getOperand(i)));
3448     return setSignedRange(Add, ConservativeResult.intersectWith(X));
3449   }
3450
3451   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3452     ConstantRange X = getSignedRange(Mul->getOperand(0));
3453     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3454       X = X.multiply(getSignedRange(Mul->getOperand(i)));
3455     return setSignedRange(Mul, ConservativeResult.intersectWith(X));
3456   }
3457
3458   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3459     ConstantRange X = getSignedRange(SMax->getOperand(0));
3460     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3461       X = X.smax(getSignedRange(SMax->getOperand(i)));
3462     return setSignedRange(SMax, ConservativeResult.intersectWith(X));
3463   }
3464
3465   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3466     ConstantRange X = getSignedRange(UMax->getOperand(0));
3467     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3468       X = X.umax(getSignedRange(UMax->getOperand(i)));
3469     return setSignedRange(UMax, ConservativeResult.intersectWith(X));
3470   }
3471
3472   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3473     ConstantRange X = getSignedRange(UDiv->getLHS());
3474     ConstantRange Y = getSignedRange(UDiv->getRHS());
3475     return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3476   }
3477
3478   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3479     ConstantRange X = getSignedRange(ZExt->getOperand());
3480     return setSignedRange(ZExt,
3481       ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3482   }
3483
3484   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3485     ConstantRange X = getSignedRange(SExt->getOperand());
3486     return setSignedRange(SExt,
3487       ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3488   }
3489
3490   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3491     ConstantRange X = getSignedRange(Trunc->getOperand());
3492     return setSignedRange(Trunc,
3493       ConservativeResult.intersectWith(X.truncate(BitWidth)));
3494   }
3495
3496   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3497     // If there's no signed wrap, and all the operands have the same sign or
3498     // zero, the value won't ever change sign.
3499     if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
3500       bool AllNonNeg = true;
3501       bool AllNonPos = true;
3502       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3503         if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3504         if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3505       }
3506       if (AllNonNeg)
3507         ConservativeResult = ConservativeResult.intersectWith(
3508           ConstantRange(APInt(BitWidth, 0),
3509                         APInt::getSignedMinValue(BitWidth)));
3510       else if (AllNonPos)
3511         ConservativeResult = ConservativeResult.intersectWith(
3512           ConstantRange(APInt::getSignedMinValue(BitWidth),
3513                         APInt(BitWidth, 1)));
3514     }
3515
3516     // TODO: non-affine addrec
3517     if (AddRec->isAffine()) {
3518       Type *Ty = AddRec->getType();
3519       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3520       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3521           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3522         MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3523
3524         const SCEV *Start = AddRec->getStart();
3525         const SCEV *Step = AddRec->getStepRecurrence(*this);
3526
3527         ConstantRange StartRange = getSignedRange(Start);
3528         ConstantRange StepRange = getSignedRange(Step);
3529         ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3530         ConstantRange EndRange =
3531           StartRange.add(MaxBECountRange.multiply(StepRange));
3532
3533         // Check for overflow. This must be done with ConstantRange arithmetic
3534         // because we could be called from within the ScalarEvolution overflow
3535         // checking code.
3536         ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1);
3537         ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3538         ConstantRange ExtMaxBECountRange =
3539           MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3540         ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1);
3541         if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3542             ExtEndRange)
3543           return setSignedRange(AddRec, ConservativeResult);
3544
3545         APInt Min = APIntOps::smin(StartRange.getSignedMin(),
3546                                    EndRange.getSignedMin());
3547         APInt Max = APIntOps::smax(StartRange.getSignedMax(),
3548                                    EndRange.getSignedMax());
3549         if (Min.isMinSignedValue() && Max.isMaxSignedValue())
3550           return setSignedRange(AddRec, ConservativeResult);
3551         return setSignedRange(AddRec,
3552           ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3553       }
3554     }
3555
3556     return setSignedRange(AddRec, ConservativeResult);
3557   }
3558
3559   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3560     // For a SCEVUnknown, ask ValueTracking.
3561     if (!U->getValue()->getType()->isIntegerTy() && !TD)
3562       return setSignedRange(U, ConservativeResult);
3563     unsigned NS = ComputeNumSignBits(U->getValue(), TD);
3564     if (NS == 1)
3565       return setSignedRange(U, ConservativeResult);
3566     return setSignedRange(U, ConservativeResult.intersectWith(
3567       ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
3568                     APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)));
3569   }
3570
3571   return setSignedRange(S, ConservativeResult);
3572 }
3573
3574 /// createSCEV - We know that there is no SCEV for the specified value.
3575 /// Analyze the expression.
3576 ///
3577 const SCEV *ScalarEvolution::createSCEV(Value *V) {
3578   if (!isSCEVable(V->getType()))
3579     return getUnknown(V);
3580
3581   unsigned Opcode = Instruction::UserOp1;
3582   if (Instruction *I = dyn_cast<Instruction>(V)) {
3583     Opcode = I->getOpcode();
3584
3585     // Don't attempt to analyze instructions in blocks that aren't
3586     // reachable. Such instructions don't matter, and they aren't required
3587     // to obey basic rules for definitions dominating uses which this
3588     // analysis depends on.
3589     if (!DT->isReachableFromEntry(I->getParent()))
3590       return getUnknown(V);
3591   } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
3592     Opcode = CE->getOpcode();
3593   else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
3594     return getConstant(CI);
3595   else if (isa<ConstantPointerNull>(V))
3596     return getConstant(V->getType(), 0);
3597   else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
3598     return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
3599   else
3600     return getUnknown(V);
3601
3602   Operator *U = cast<Operator>(V);
3603   switch (Opcode) {
3604   case Instruction::Add: {
3605     // The simple thing to do would be to just call getSCEV on both operands
3606     // and call getAddExpr with the result. However if we're looking at a
3607     // bunch of things all added together, this can be quite inefficient,
3608     // because it leads to N-1 getAddExpr calls for N ultimate operands.
3609     // Instead, gather up all the operands and make a single getAddExpr call.
3610     // LLVM IR canonical form means we need only traverse the left operands.
3611     //
3612     // Don't apply this instruction's NSW or NUW flags to the new
3613     // expression. The instruction may be guarded by control flow that the
3614     // no-wrap behavior depends on. Non-control-equivalent instructions can be
3615     // mapped to the same SCEV expression, and it would be incorrect to transfer
3616     // NSW/NUW semantics to those operations.
3617     SmallVector<const SCEV *, 4> AddOps;
3618     AddOps.push_back(getSCEV(U->getOperand(1)));
3619     for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) {
3620       unsigned Opcode = Op->getValueID() - Value::InstructionVal;
3621       if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
3622         break;
3623       U = cast<Operator>(Op);
3624       const SCEV *Op1 = getSCEV(U->getOperand(1));
3625       if (Opcode == Instruction::Sub)
3626         AddOps.push_back(getNegativeSCEV(Op1));
3627       else
3628         AddOps.push_back(Op1);
3629     }
3630     AddOps.push_back(getSCEV(U->getOperand(0)));
3631     return getAddExpr(AddOps);
3632   }
3633   case Instruction::Mul: {
3634     // Don't transfer NSW/NUW for the same reason as AddExpr.
3635     SmallVector<const SCEV *, 4> MulOps;
3636     MulOps.push_back(getSCEV(U->getOperand(1)));
3637     for (Value *Op = U->getOperand(0);
3638          Op->getValueID() == Instruction::Mul + Value::InstructionVal;
3639          Op = U->getOperand(0)) {
3640       U = cast<Operator>(Op);
3641       MulOps.push_back(getSCEV(U->getOperand(1)));
3642     }
3643     MulOps.push_back(getSCEV(U->getOperand(0)));
3644     return getMulExpr(MulOps);
3645   }
3646   case Instruction::UDiv:
3647     return getUDivExpr(getSCEV(U->getOperand(0)),
3648                        getSCEV(U->getOperand(1)));
3649   case Instruction::Sub:
3650     return getMinusSCEV(getSCEV(U->getOperand(0)),
3651                         getSCEV(U->getOperand(1)));
3652   case Instruction::And:
3653     // For an expression like x&255 that merely masks off the high bits,
3654     // use zext(trunc(x)) as the SCEV expression.
3655     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3656       if (CI->isNullValue())
3657         return getSCEV(U->getOperand(1));
3658       if (CI->isAllOnesValue())
3659         return getSCEV(U->getOperand(0));
3660       const APInt &A = CI->getValue();
3661
3662       // Instcombine's ShrinkDemandedConstant may strip bits out of
3663       // constants, obscuring what would otherwise be a low-bits mask.
3664       // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
3665       // knew about to reconstruct a low-bits mask value.
3666       unsigned LZ = A.countLeadingZeros();
3667       unsigned BitWidth = A.getBitWidth();
3668       APInt AllOnes = APInt::getAllOnesValue(BitWidth);
3669       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
3670       ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
3671
3672       APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
3673
3674       if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
3675         return
3676           getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
3677                                 IntegerType::get(getContext(), BitWidth - LZ)),
3678                             U->getType());
3679     }
3680     break;
3681
3682   case Instruction::Or:
3683     // If the RHS of the Or is a constant, we may have something like:
3684     // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
3685     // optimizations will transparently handle this case.
3686     //
3687     // In order for this transformation to be safe, the LHS must be of the
3688     // form X*(2^n) and the Or constant must be less than 2^n.
3689     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3690       const SCEV *LHS = getSCEV(U->getOperand(0));
3691       const APInt &CIVal = CI->getValue();
3692       if (GetMinTrailingZeros(LHS) >=
3693           (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
3694         // Build a plain add SCEV.
3695         const SCEV *S = getAddExpr(LHS, getSCEV(CI));
3696         // If the LHS of the add was an addrec and it has no-wrap flags,
3697         // transfer the no-wrap flags, since an or won't introduce a wrap.
3698         if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
3699           const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
3700           const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
3701             OldAR->getNoWrapFlags());
3702         }
3703         return S;
3704       }
3705     }
3706     break;
3707   case Instruction::Xor:
3708     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3709       // If the RHS of the xor is a signbit, then this is just an add.
3710       // Instcombine turns add of signbit into xor as a strength reduction step.
3711       if (CI->getValue().isSignBit())
3712         return getAddExpr(getSCEV(U->getOperand(0)),
3713                           getSCEV(U->getOperand(1)));
3714
3715       // If the RHS of xor is -1, then this is a not operation.
3716       if (CI->isAllOnesValue())
3717         return getNotSCEV(getSCEV(U->getOperand(0)));
3718
3719       // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
3720       // This is a variant of the check for xor with -1, and it handles
3721       // the case where instcombine has trimmed non-demanded bits out
3722       // of an xor with -1.
3723       if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
3724         if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
3725           if (BO->getOpcode() == Instruction::And &&
3726               LCI->getValue() == CI->getValue())
3727             if (const SCEVZeroExtendExpr *Z =
3728                   dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
3729               Type *UTy = U->getType();
3730               const SCEV *Z0 = Z->getOperand();
3731               Type *Z0Ty = Z0->getType();
3732               unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
3733
3734               // If C is a low-bits mask, the zero extend is serving to
3735               // mask off the high bits. Complement the operand and
3736               // re-apply the zext.
3737               if (APIntOps::isMask(Z0TySize, CI->getValue()))
3738                 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
3739
3740               // If C is a single bit, it may be in the sign-bit position
3741               // before the zero-extend. In this case, represent the xor
3742               // using an add, which is equivalent, and re-apply the zext.
3743               APInt Trunc = CI->getValue().trunc(Z0TySize);
3744               if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
3745                   Trunc.isSignBit())
3746                 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
3747                                          UTy);
3748             }
3749     }
3750     break;
3751
3752   case Instruction::Shl:
3753     // Turn shift left of a constant amount into a multiply.
3754     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3755       uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3756
3757       // If the shift count is not less than the bitwidth, the result of
3758       // the shift is undefined. Don't try to analyze it, because the
3759       // resolution chosen here may differ from the resolution chosen in
3760       // other parts of the compiler.
3761       if (SA->getValue().uge(BitWidth))
3762         break;
3763
3764       Constant *X = ConstantInt::get(getContext(),
3765         APInt(BitWidth, 1).shl(SA->getZExtValue()));
3766       return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3767     }
3768     break;
3769
3770   case Instruction::LShr:
3771     // Turn logical shift right of a constant into a unsigned divide.
3772     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3773       uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3774
3775       // If the shift count is not less than the bitwidth, the result of
3776       // the shift is undefined. Don't try to analyze it, because the
3777       // resolution chosen here may differ from the resolution chosen in
3778       // other parts of the compiler.
3779       if (SA->getValue().uge(BitWidth))
3780         break;
3781
3782       Constant *X = ConstantInt::get(getContext(),
3783         APInt(BitWidth, 1).shl(SA->getZExtValue()));
3784       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3785     }
3786     break;
3787
3788   case Instruction::AShr:
3789     // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
3790     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
3791       if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
3792         if (L->getOpcode() == Instruction::Shl &&
3793             L->getOperand(1) == U->getOperand(1)) {
3794           uint64_t BitWidth = getTypeSizeInBits(U->getType());
3795
3796           // If the shift count is not less than the bitwidth, the result of
3797           // the shift is undefined. Don't try to analyze it, because the
3798           // resolution chosen here may differ from the resolution chosen in
3799           // other parts of the compiler.
3800           if (CI->getValue().uge(BitWidth))
3801             break;
3802
3803           uint64_t Amt = BitWidth - CI->getZExtValue();
3804           if (Amt == BitWidth)
3805             return getSCEV(L->getOperand(0));       // shift by zero --> noop
3806           return
3807             getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
3808                                               IntegerType::get(getContext(),
3809                                                                Amt)),
3810                               U->getType());
3811         }
3812     break;
3813
3814   case Instruction::Trunc:
3815     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
3816
3817   case Instruction::ZExt:
3818     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3819
3820   case Instruction::SExt:
3821     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3822
3823   case Instruction::BitCast:
3824     // BitCasts are no-op casts so we just eliminate the cast.
3825     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
3826       return getSCEV(U->getOperand(0));
3827     break;
3828
3829   // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
3830   // lead to pointer expressions which cannot safely be expanded to GEPs,
3831   // because ScalarEvolution doesn't respect the GEP aliasing rules when
3832   // simplifying integer expressions.
3833
3834   case Instruction::GetElementPtr:
3835     return createNodeForGEP(cast<GEPOperator>(U));
3836
3837   case Instruction::PHI:
3838     return createNodeForPHI(cast<PHINode>(U));
3839
3840   case Instruction::Select:
3841     // This could be a smax or umax that was lowered earlier.
3842     // Try to recover it.
3843     if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
3844       Value *LHS = ICI->getOperand(0);
3845       Value *RHS = ICI->getOperand(1);
3846       switch (ICI->getPredicate()) {
3847       case ICmpInst::ICMP_SLT:
3848       case ICmpInst::ICMP_SLE:
3849         std::swap(LHS, RHS);
3850         // fall through
3851       case ICmpInst::ICMP_SGT:
3852       case ICmpInst::ICMP_SGE:
3853         // a >s b ? a+x : b+x  ->  smax(a, b)+x
3854         // a >s b ? b+x : a+x  ->  smin(a, b)+x
3855         if (LHS->getType() == U->getType()) {
3856           const SCEV *LS = getSCEV(LHS);
3857           const SCEV *RS = getSCEV(RHS);
3858           const SCEV *LA = getSCEV(U->getOperand(1));
3859           const SCEV *RA = getSCEV(U->getOperand(2));
3860           const SCEV *LDiff = getMinusSCEV(LA, LS);
3861           const SCEV *RDiff = getMinusSCEV(RA, RS);
3862           if (LDiff == RDiff)
3863             return getAddExpr(getSMaxExpr(LS, RS), LDiff);
3864           LDiff = getMinusSCEV(LA, RS);
3865           RDiff = getMinusSCEV(RA, LS);
3866           if (LDiff == RDiff)
3867             return getAddExpr(getSMinExpr(LS, RS), LDiff);
3868         }
3869         break;
3870       case ICmpInst::ICMP_ULT:
3871       case ICmpInst::ICMP_ULE:
3872         std::swap(LHS, RHS);
3873         // fall through
3874       case ICmpInst::ICMP_UGT:
3875       case ICmpInst::ICMP_UGE:
3876         // a >u b ? a+x : b+x  ->  umax(a, b)+x
3877         // a >u b ? b+x : a+x  ->  umin(a, b)+x
3878         if (LHS->getType() == U->getType()) {
3879           const SCEV *LS = getSCEV(LHS);
3880           const SCEV *RS = getSCEV(RHS);
3881           const SCEV *LA = getSCEV(U->getOperand(1));
3882           const SCEV *RA = getSCEV(U->getOperand(2));
3883           const SCEV *LDiff = getMinusSCEV(LA, LS);
3884           const SCEV *RDiff = getMinusSCEV(RA, RS);
3885           if (LDiff == RDiff)
3886             return getAddExpr(getUMaxExpr(LS, RS), LDiff);
3887           LDiff = getMinusSCEV(LA, RS);
3888           RDiff = getMinusSCEV(RA, LS);
3889           if (LDiff == RDiff)
3890             return getAddExpr(getUMinExpr(LS, RS), LDiff);
3891         }
3892         break;
3893       case ICmpInst::ICMP_NE:
3894         // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
3895         if (LHS->getType() == U->getType() &&
3896             isa<ConstantInt>(RHS) &&
3897             cast<ConstantInt>(RHS)->isZero()) {
3898           const SCEV *One = getConstant(LHS->getType(), 1);
3899           const SCEV *LS = getSCEV(LHS);
3900           const SCEV *LA = getSCEV(U->getOperand(1));
3901           const SCEV *RA = getSCEV(U->getOperand(2));
3902           const SCEV *LDiff = getMinusSCEV(LA, LS);
3903           const SCEV *RDiff = getMinusSCEV(RA, One);
3904           if (LDiff == RDiff)
3905             return getAddExpr(getUMaxExpr(One, LS), LDiff);
3906         }
3907         break;
3908       case ICmpInst::ICMP_EQ:
3909         // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
3910         if (LHS->getType() == U->getType() &&
3911             isa<ConstantInt>(RHS) &&
3912             cast<ConstantInt>(RHS)->isZero()) {
3913           const SCEV *One = getConstant(LHS->getType(), 1);
3914           const SCEV *LS = getSCEV(LHS);
3915           const SCEV *LA = getSCEV(U->getOperand(1));
3916           const SCEV *RA = getSCEV(U->getOperand(2));
3917           const SCEV *LDiff = getMinusSCEV(LA, One);
3918           const SCEV *RDiff = getMinusSCEV(RA, LS);
3919           if (LDiff == RDiff)
3920             return getAddExpr(getUMaxExpr(One, LS), LDiff);
3921         }
3922         break;
3923       default:
3924         break;
3925       }
3926     }
3927
3928   default: // We cannot analyze this expression.
3929     break;
3930   }
3931
3932   return getUnknown(V);
3933 }
3934
3935
3936
3937 //===----------------------------------------------------------------------===//
3938 //                   Iteration Count Computation Code
3939 //
3940
3941 /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
3942 /// normal unsigned value. Returns 0 if the trip count is unknown or not
3943 /// constant. Will also return 0 if the maximum trip count is very large (>=
3944 /// 2^32).
3945 ///
3946 /// This "trip count" assumes that control exits via ExitingBlock. More
3947 /// precisely, it is the number of times that control may reach ExitingBlock
3948 /// before taking the branch. For loops with multiple exits, it may not be the
3949 /// number times that the loop header executes because the loop may exit
3950 /// prematurely via another branch.
3951 unsigned ScalarEvolution::
3952 getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock) {
3953   const SCEVConstant *ExitCount =
3954     dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
3955   if (!ExitCount)
3956     return 0;
3957
3958   ConstantInt *ExitConst = ExitCount->getValue();
3959
3960   // Guard against huge trip counts.
3961   if (ExitConst->getValue().getActiveBits() > 32)
3962     return 0;
3963
3964   // In case of integer overflow, this returns 0, which is correct.
3965   return ((unsigned)ExitConst->getZExtValue()) + 1;
3966 }
3967
3968 /// getSmallConstantTripMultiple - Returns the largest constant divisor of the
3969 /// trip count of this loop as a normal unsigned value, if possible. This
3970 /// means that the actual trip count is always a multiple of the returned
3971 /// value (don't forget the trip count could very well be zero as well!).
3972 ///
3973 /// Returns 1 if the trip count is unknown or not guaranteed to be the
3974 /// multiple of a constant (which is also the case if the trip count is simply
3975 /// constant, use getSmallConstantTripCount for that case), Will also return 1
3976 /// if the trip count is very large (>= 2^32).
3977 ///
3978 /// As explained in the comments for getSmallConstantTripCount, this assumes
3979 /// that control exits the loop via ExitingBlock.
3980 unsigned ScalarEvolution::
3981 getSmallConstantTripMultiple(Loop *L, BasicBlock *ExitingBlock) {
3982   const SCEV *ExitCount = getExitCount(L, ExitingBlock);
3983   if (ExitCount == getCouldNotCompute())
3984     return 1;
3985
3986   // Get the trip count from the BE count by adding 1.
3987   const SCEV *TCMul = getAddExpr(ExitCount,
3988                                  getConstant(ExitCount->getType(), 1));
3989   // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
3990   // to factor simple cases.
3991   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
3992     TCMul = Mul->getOperand(0);
3993
3994   const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
3995   if (!MulC)
3996     return 1;
3997
3998   ConstantInt *Result = MulC->getValue();
3999
4000   // Guard against huge trip counts.
4001   if (!Result || Result->getValue().getActiveBits() > 32)
4002     return 1;
4003
4004   return (unsigned)Result->getZExtValue();
4005 }
4006
4007 // getExitCount - Get the expression for the number of loop iterations for which
4008 // this loop is guaranteed not to exit via ExitintBlock. Otherwise return
4009 // SCEVCouldNotCompute.
4010 const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
4011   return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
4012 }
4013
4014 /// getBackedgeTakenCount - If the specified loop has a predictable
4015 /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
4016 /// object. The backedge-taken count is the number of times the loop header
4017 /// will be branched to from within the loop. This is one less than the
4018 /// trip count of the loop, since it doesn't count the first iteration,
4019 /// when the header is branched to from outside the loop.
4020 ///
4021 /// Note that it is not valid to call this method on a loop without a
4022 /// loop-invariant backedge-taken count (see
4023 /// hasLoopInvariantBackedgeTakenCount).
4024 ///
4025 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
4026   return getBackedgeTakenInfo(L).getExact(this);
4027 }
4028
4029 /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
4030 /// return the least SCEV value that is known never to be less than the
4031 /// actual backedge taken count.
4032 const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
4033   return getBackedgeTakenInfo(L).getMax(this);
4034 }
4035
4036 /// PushLoopPHIs - Push PHI nodes in the header of the given loop
4037 /// onto the given Worklist.
4038 static void
4039 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
4040   BasicBlock *Header = L->getHeader();
4041
4042   // Push all Loop-header PHIs onto the Worklist stack.
4043   for (BasicBlock::iterator I = Header->begin();
4044        PHINode *PN = dyn_cast<PHINode>(I); ++I)
4045     Worklist.push_back(PN);
4046 }
4047
4048 const ScalarEvolution::BackedgeTakenInfo &
4049 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
4050   // Initially insert an invalid entry for this loop. If the insertion
4051   // succeeds, proceed to actually compute a backedge-taken count and
4052   // update the value. The temporary CouldNotCompute value tells SCEV
4053   // code elsewhere that it shouldn't attempt to request a new
4054   // backedge-taken count, which could result in infinite recursion.
4055   std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
4056     BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
4057   if (!Pair.second)
4058     return Pair.first->second;
4059
4060   // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it
4061   // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
4062   // must be cleared in this scope.
4063   BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L);
4064
4065   if (Result.getExact(this) != getCouldNotCompute()) {
4066     assert(isLoopInvariant(Result.getExact(this), L) &&
4067            isLoopInvariant(Result.getMax(this), L) &&
4068            "Computed backedge-taken count isn't loop invariant for loop!");
4069     ++NumTripCountsComputed;
4070   }
4071   else if (Result.getMax(this) == getCouldNotCompute() &&
4072            isa<PHINode>(L->getHeader()->begin())) {
4073     // Only count loops that have phi nodes as not being computable.
4074     ++NumTripCountsNotComputed;
4075   }
4076
4077   // Now that we know more about the trip count for this loop, forget any
4078   // existing SCEV values for PHI nodes in this loop since they are only
4079   // conservative estimates made without the benefit of trip count
4080   // information. This is similar to the code in forgetLoop, except that
4081   // it handles SCEVUnknown PHI nodes specially.
4082   if (Result.hasAnyInfo()) {
4083     SmallVector<Instruction *, 16> Worklist;
4084     PushLoopPHIs(L, Worklist);
4085
4086     SmallPtrSet<Instruction *, 8> Visited;
4087     while (!Worklist.empty()) {
4088       Instruction *I = Worklist.pop_back_val();
4089       if (!Visited.insert(I)) continue;
4090
4091       ValueExprMapType::iterator It =
4092         ValueExprMap.find(static_cast<Value *>(I));
4093       if (It != ValueExprMap.end()) {
4094         const SCEV *Old = It->second;
4095
4096         // SCEVUnknown for a PHI either means that it has an unrecognized
4097         // structure, or it's a PHI that's in the progress of being computed
4098         // by createNodeForPHI.  In the former case, additional loop trip
4099         // count information isn't going to change anything. In the later
4100         // case, createNodeForPHI will perform the necessary updates on its
4101         // own when it gets to that point.
4102         if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
4103           forgetMemoizedResults(Old);
4104           ValueExprMap.erase(It);
4105         }
4106         if (PHINode *PN = dyn_cast<PHINode>(I))
4107           ConstantEvolutionLoopExitValue.erase(PN);
4108       }
4109
4110       PushDefUseChildren(I, Worklist);
4111     }
4112   }
4113
4114   // Re-lookup the insert position, since the call to
4115   // ComputeBackedgeTakenCount above could result in a
4116   // recusive call to getBackedgeTakenInfo (on a different
4117   // loop), which would invalidate the iterator computed
4118   // earlier.
4119   return BackedgeTakenCounts.find(L)->second = Result;
4120 }
4121
4122 /// forgetLoop - This method should be called by the client when it has
4123 /// changed a loop in a way that may effect ScalarEvolution's ability to
4124 /// compute a trip count, or if the loop is deleted.
4125 void ScalarEvolution::forgetLoop(const Loop *L) {
4126   // Drop any stored trip count value.
4127   DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
4128     BackedgeTakenCounts.find(L);
4129   if (BTCPos != BackedgeTakenCounts.end()) {
4130     BTCPos->second.clear();
4131     BackedgeTakenCounts.erase(BTCPos);
4132   }
4133
4134   // Drop information about expressions based on loop-header PHIs.
4135   SmallVector<Instruction *, 16> Worklist;
4136   PushLoopPHIs(L, Worklist);
4137
4138   SmallPtrSet<Instruction *, 8> Visited;
4139   while (!Worklist.empty()) {
4140     Instruction *I = Worklist.pop_back_val();
4141     if (!Visited.insert(I)) continue;
4142
4143     ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
4144     if (It != ValueExprMap.end()) {
4145       forgetMemoizedResults(It->second);
4146       ValueExprMap.erase(It);
4147       if (PHINode *PN = dyn_cast<PHINode>(I))
4148         ConstantEvolutionLoopExitValue.erase(PN);
4149     }
4150
4151     PushDefUseChildren(I, Worklist);
4152   }
4153
4154   // Forget all contained loops too, to avoid dangling entries in the
4155   // ValuesAtScopes map.
4156   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
4157     forgetLoop(*I);
4158 }
4159
4160 /// forgetValue - This method should be called by the client when it has
4161 /// changed a value in a way that may effect its value, or which may
4162 /// disconnect it from a def-use chain linking it to a loop.
4163 void ScalarEvolution::forgetValue(Value *V) {
4164   Instruction *I = dyn_cast<Instruction>(V);
4165   if (!I) return;
4166
4167   // Drop information about expressions based on loop-header PHIs.
4168   SmallVector<Instruction *, 16> Worklist;
4169   Worklist.push_back(I);
4170
4171   SmallPtrSet<Instruction *, 8> Visited;
4172   while (!Worklist.empty()) {
4173     I = Worklist.pop_back_val();
4174     if (!Visited.insert(I)) continue;
4175
4176     ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
4177     if (It != ValueExprMap.end()) {
4178       forgetMemoizedResults(It->second);
4179       ValueExprMap.erase(It);
4180       if (PHINode *PN = dyn_cast<PHINode>(I))
4181         ConstantEvolutionLoopExitValue.erase(PN);
4182     }
4183
4184     PushDefUseChildren(I, Worklist);
4185   }
4186 }
4187
4188 /// getExact - Get the exact loop backedge taken count considering all loop
4189 /// exits. A computable result can only be return for loops with a single exit.
4190 /// Returning the minimum taken count among all exits is incorrect because one
4191 /// of the loop's exit limit's may have been skipped. HowFarToZero assumes that
4192 /// the limit of each loop test is never skipped. This is a valid assumption as
4193 /// long as the loop exits via that test. For precise results, it is the
4194 /// caller's responsibility to specify the relevant loop exit using
4195 /// getExact(ExitingBlock, SE).
4196 const SCEV *
4197 ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
4198   // If any exits were not computable, the loop is not computable.
4199   if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();
4200
4201   // We need exactly one computable exit.
4202   if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute();
4203   assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");
4204
4205   const SCEV *BECount = 0;
4206   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4207        ENT != 0; ENT = ENT->getNextExit()) {
4208
4209     assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
4210
4211     if (!BECount)
4212       BECount = ENT->ExactNotTaken;
4213     else if (BECount != ENT->ExactNotTaken)
4214       return SE->getCouldNotCompute();
4215   }
4216   assert(BECount && "Invalid not taken count for loop exit");
4217   return BECount;
4218 }
4219
4220 /// getExact - Get the exact not taken count for this loop exit.
4221 const SCEV *
4222 ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
4223                                              ScalarEvolution *SE) const {
4224   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4225        ENT != 0; ENT = ENT->getNextExit()) {
4226
4227     if (ENT->ExitingBlock == ExitingBlock)
4228       return ENT->ExactNotTaken;
4229   }
4230   return SE->getCouldNotCompute();
4231 }
4232
4233 /// getMax - Get the max backedge taken count for the loop.
4234 const SCEV *
4235 ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
4236   return Max ? Max : SE->getCouldNotCompute();
4237 }
4238
4239 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
4240 /// computable exit into a persistent ExitNotTakenInfo array.
4241 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
4242   SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
4243   bool Complete, const SCEV *MaxCount) : Max(MaxCount) {
4244
4245   if (!Complete)
4246     ExitNotTaken.setIncomplete();
4247
4248   unsigned NumExits = ExitCounts.size();
4249   if (NumExits == 0) return;
4250
4251   ExitNotTaken.ExitingBlock = ExitCounts[0].first;
4252   ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
4253   if (NumExits == 1) return;
4254
4255   // Handle the rare case of multiple computable exits.
4256   ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];
4257
4258   ExitNotTakenInfo *PrevENT = &ExitNotTaken;
4259   for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) {
4260     PrevENT->setNextExit(ENT);
4261     ENT->ExitingBlock = ExitCounts[i].first;
4262     ENT->ExactNotTaken = ExitCounts[i].second;
4263   }
4264 }
4265
4266 /// clear - Invalidate this result and free the ExitNotTakenInfo array.
4267 void ScalarEvolution::BackedgeTakenInfo::clear() {
4268   ExitNotTaken.ExitingBlock = 0;
4269   ExitNotTaken.ExactNotTaken = 0;
4270   delete[] ExitNotTaken.getNextExit();
4271 }
4272
4273 /// ComputeBackedgeTakenCount - Compute the number of times the backedge
4274 /// of the specified loop will execute.
4275 ScalarEvolution::BackedgeTakenInfo
4276 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
4277   SmallVector<BasicBlock *, 8> ExitingBlocks;
4278   L->getExitingBlocks(ExitingBlocks);
4279
4280   // Examine all exits and pick the most conservative values.
4281   const SCEV *MaxBECount = getCouldNotCompute();
4282   bool CouldComputeBECount = true;
4283   SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts;
4284   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
4285     ExitLimit EL = ComputeExitLimit(L, ExitingBlocks[i]);
4286     if (EL.Exact == getCouldNotCompute())
4287       // We couldn't compute an exact value for this exit, so
4288       // we won't be able to compute an exact value for the loop.
4289       CouldComputeBECount = false;
4290     else
4291       ExitCounts.push_back(std::make_pair(ExitingBlocks[i], EL.Exact));
4292
4293     if (MaxBECount == getCouldNotCompute())
4294       MaxBECount = EL.Max;
4295     else if (EL.Max != getCouldNotCompute()) {
4296       // We cannot take the "min" MaxBECount, because non-unit stride loops may
4297       // skip some loop tests. Taking the max over the exits is sufficiently
4298       // conservative.  TODO: We could do better taking into consideration
4299       // that (1) the loop has unit stride (2) the last loop test is
4300       // less-than/greater-than (3) any loop test is less-than/greater-than AND
4301       // falls-through some constant times less then the other tests.
4302       MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, EL.Max);
4303     }
4304   }
4305
4306   return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
4307 }
4308
4309 /// ComputeExitLimit - Compute the number of times the backedge of the specified
4310 /// loop will execute if it exits via the specified block.
4311 ScalarEvolution::ExitLimit
4312 ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
4313
4314   // Okay, we've chosen an exiting block.  See what condition causes us to
4315   // exit at this block.
4316   //
4317   // FIXME: we should be able to handle switch instructions (with a single exit)
4318   BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
4319   if (ExitBr == 0) return getCouldNotCompute();
4320   assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
4321
4322   // At this point, we know we have a conditional branch that determines whether
4323   // the loop is exited.  However, we don't know if the branch is executed each
4324   // time through the loop.  If not, then the execution count of the branch will
4325   // not be equal to the trip count of the loop.
4326   //
4327   // Currently we check for this by checking to see if the Exit branch goes to
4328   // the loop header.  If so, we know it will always execute the same number of
4329   // times as the loop.  We also handle the case where the exit block *is* the
4330   // loop header.  This is common for un-rotated loops.
4331   //
4332   // If both of those tests fail, walk up the unique predecessor chain to the
4333   // header, stopping if there is an edge that doesn't exit the loop. If the
4334   // header is reached, the execution count of the branch will be equal to the
4335   // trip count of the loop.
4336   //
4337   //  More extensive analysis could be done to handle more cases here.
4338   //
4339   if (ExitBr->getSuccessor(0) != L->getHeader() &&
4340       ExitBr->getSuccessor(1) != L->getHeader() &&
4341       ExitBr->getParent() != L->getHeader()) {
4342     // The simple checks failed, try climbing the unique predecessor chain
4343     // up to the header.
4344     bool Ok = false;
4345     for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
4346       BasicBlock *Pred = BB->getUniquePredecessor();
4347       if (!Pred)
4348         return getCouldNotCompute();
4349       TerminatorInst *PredTerm = Pred->getTerminator();
4350       for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
4351         BasicBlock *PredSucc = PredTerm->getSuccessor(i);
4352         if (PredSucc == BB)
4353           continue;
4354         // If the predecessor has a successor that isn't BB and isn't
4355         // outside the loop, assume the worst.
4356         if (L->contains(PredSucc))
4357           return getCouldNotCompute();
4358       }
4359       if (Pred == L->getHeader()) {
4360         Ok = true;
4361         break;
4362       }
4363       BB = Pred;
4364     }
4365     if (!Ok)
4366       return getCouldNotCompute();
4367   }
4368
4369   // Proceed to the next level to examine the exit condition expression.
4370   return ComputeExitLimitFromCond(L, ExitBr->getCondition(),
4371                                   ExitBr->getSuccessor(0),
4372                                   ExitBr->getSuccessor(1));
4373 }
4374
4375 /// ComputeExitLimitFromCond - Compute the number of times the
4376 /// backedge of the specified loop will execute if its exit condition
4377 /// were a conditional branch of ExitCond, TBB, and FBB.
4378 ScalarEvolution::ExitLimit
4379 ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
4380                                           Value *ExitCond,
4381                                           BasicBlock *TBB,
4382                                           BasicBlock *FBB) {
4383   // Check if the controlling expression for this loop is an And or Or.
4384   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
4385     if (BO->getOpcode() == Instruction::And) {
4386       // Recurse on the operands of the and.
4387       ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB);
4388       ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB);
4389       const SCEV *BECount = getCouldNotCompute();
4390       const SCEV *MaxBECount = getCouldNotCompute();
4391       if (L->contains(TBB)) {
4392         // Both conditions must be true for the loop to continue executing.
4393         // Choose the less conservative count.
4394         if (EL0.Exact == getCouldNotCompute() ||
4395             EL1.Exact == getCouldNotCompute())
4396           BECount = getCouldNotCompute();
4397         else
4398           BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
4399         if (EL0.Max == getCouldNotCompute())
4400           MaxBECount = EL1.Max;
4401         else if (EL1.Max == getCouldNotCompute())
4402           MaxBECount = EL0.Max;
4403         else
4404           MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
4405       } else {
4406         // Both conditions must be true at the same time for the loop to exit.
4407         // For now, be conservative.
4408         assert(L->contains(FBB) && "Loop block has no successor in loop!");
4409         if (EL0.Max == EL1.Max)
4410           MaxBECount = EL0.Max;
4411         if (EL0.Exact == EL1.Exact)
4412           BECount = EL0.Exact;
4413       }
4414
4415       return ExitLimit(BECount, MaxBECount);
4416     }
4417     if (BO->getOpcode() == Instruction::Or) {
4418       // Recurse on the operands of the or.
4419       ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB);
4420       ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB);
4421       const SCEV *BECount = getCouldNotCompute();
4422       const SCEV *MaxBECount = getCouldNotCompute();
4423       if (L->contains(FBB)) {
4424         // Both conditions must be false for the loop to continue executing.
4425         // Choose the less conservative count.
4426         if (EL0.Exact == getCouldNotCompute() ||
4427             EL1.Exact == getCouldNotCompute())
4428           BECount = getCouldNotCompute();
4429         else
4430           BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
4431         if (EL0.Max == getCouldNotCompute())
4432           MaxBECount = EL1.Max;
4433         else if (EL1.Max == getCouldNotCompute())
4434           MaxBECount = EL0.Max;
4435         else
4436           MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
4437       } else {
4438         // Both conditions must be false at the same time for the loop to exit.
4439         // For now, be conservative.
4440         assert(L->contains(TBB) && "Loop block has no successor in loop!");
4441         if (EL0.Max == EL1.Max)
4442           MaxBECount = EL0.Max;
4443         if (EL0.Exact == EL1.Exact)
4444           BECount = EL0.Exact;
4445       }
4446
4447       return ExitLimit(BECount, MaxBECount);
4448     }
4449   }
4450
4451   // With an icmp, it may be feasible to compute an exact backedge-taken count.
4452   // Proceed to the next level to examine the icmp.
4453   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
4454     return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB);
4455
4456   // Check for a constant condition. These are normally stripped out by
4457   // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
4458   // preserve the CFG and is temporarily leaving constant conditions
4459   // in place.
4460   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
4461     if (L->contains(FBB) == !CI->getZExtValue())
4462       // The backedge is always taken.
4463       return getCouldNotCompute();
4464     else
4465       // The backedge is never taken.
4466       return getConstant(CI->getType(), 0);
4467   }
4468
4469   // If it's not an integer or pointer comparison then compute it the hard way.
4470   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
4471 }
4472
4473 /// ComputeExitLimitFromICmp - Compute the number of times the
4474 /// backedge of the specified loop will execute if its exit condition
4475 /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
4476 ScalarEvolution::ExitLimit
4477 ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L,
4478                                           ICmpInst *ExitCond,
4479                                           BasicBlock *TBB,
4480                                           BasicBlock *FBB) {
4481
4482   // If the condition was exit on true, convert the condition to exit on false
4483   ICmpInst::Predicate Cond;
4484   if (!L->contains(FBB))
4485     Cond = ExitCond->getPredicate();
4486   else
4487     Cond = ExitCond->getInversePredicate();
4488
4489   // Handle common loops like: for (X = "string"; *X; ++X)
4490   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
4491     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
4492       ExitLimit ItCnt =
4493         ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
4494       if (ItCnt.hasAnyInfo())
4495         return ItCnt;
4496     }
4497
4498   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
4499   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
4500
4501   // Try to evaluate any dependencies out of the loop.
4502   LHS = getSCEVAtScope(LHS, L);
4503   RHS = getSCEVAtScope(RHS, L);
4504
4505   // At this point, we would like to compute how many iterations of the
4506   // loop the predicate will return true for these inputs.
4507   if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
4508     // If there is a loop-invariant, force it into the RHS.
4509     std::swap(LHS, RHS);
4510     Cond = ICmpInst::getSwappedPredicate(Cond);
4511   }
4512
4513   // Simplify the operands before analyzing them.
4514   (void)SimplifyICmpOperands(Cond, LHS, RHS);
4515
4516   // If we have a comparison of a chrec against a constant, try to use value
4517   // ranges to answer this query.
4518   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
4519     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
4520       if (AddRec->getLoop() == L) {
4521         // Form the constant range.
4522         ConstantRange CompRange(
4523             ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
4524
4525         const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
4526         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
4527       }
4528
4529   switch (Cond) {
4530   case ICmpInst::ICMP_NE: {                     // while (X != Y)
4531     // Convert to: while (X-Y != 0)
4532     ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L);
4533     if (EL.hasAnyInfo()) return EL;
4534     break;
4535   }
4536   case ICmpInst::ICMP_EQ: {                     // while (X == Y)
4537     // Convert to: while (X-Y == 0)
4538     ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
4539     if (EL.hasAnyInfo()) return EL;
4540     break;
4541   }
4542   case ICmpInst::ICMP_SLT: {
4543     ExitLimit EL = HowManyLessThans(LHS, RHS, L, true);
4544     if (EL.hasAnyInfo()) return EL;
4545     break;
4546   }
4547   case ICmpInst::ICMP_SGT: {
4548     ExitLimit EL = HowManyLessThans(getNotSCEV(LHS),
4549                                              getNotSCEV(RHS), L, true);
4550     if (EL.hasAnyInfo()) return EL;
4551     break;
4552   }
4553   case ICmpInst::ICMP_ULT: {
4554     ExitLimit EL = HowManyLessThans(LHS, RHS, L, false);
4555     if (EL.hasAnyInfo()) return EL;
4556     break;
4557   }
4558   case ICmpInst::ICMP_UGT: {
4559     ExitLimit EL = HowManyLessThans(getNotSCEV(LHS),
4560                                              getNotSCEV(RHS), L, false);
4561     if (EL.hasAnyInfo()) return EL;
4562     break;
4563   }
4564   default:
4565 #if 0
4566     dbgs() << "ComputeBackedgeTakenCount ";
4567     if (ExitCond->getOperand(0)->getType()->isUnsigned())
4568       dbgs() << "[unsigned] ";
4569     dbgs() << *LHS << "   "
4570          << Instruction::getOpcodeName(Instruction::ICmp)
4571          << "   " << *RHS << "\n";
4572 #endif
4573     break;
4574   }
4575   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
4576 }
4577
4578 static ConstantInt *
4579 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
4580                                 ScalarEvolution &SE) {
4581   const SCEV *InVal = SE.getConstant(C);
4582   const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
4583   assert(isa<SCEVConstant>(Val) &&
4584          "Evaluation of SCEV at constant didn't fold correctly?");
4585   return cast<SCEVConstant>(Val)->getValue();
4586 }
4587
4588 /// GetAddressedElementFromGlobal - Given a global variable with an initializer
4589 /// and a GEP expression (missing the pointer index) indexing into it, return
4590 /// the addressed element of the initializer or null if the index expression is
4591 /// invalid.
4592 static Constant *
4593 GetAddressedElementFromGlobal(GlobalVariable *GV,
4594                               const std::vector<ConstantInt*> &Indices) {
4595   Constant *Init = GV->getInitializer();
4596   for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
4597     uint64_t Idx = Indices[i]->getZExtValue();
4598     if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
4599       assert(Idx < CS->getNumOperands() && "Bad struct index!");
4600       Init = cast<Constant>(CS->getOperand(Idx));
4601     } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
4602       if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
4603       Init = cast<Constant>(CA->getOperand(Idx));
4604     } else if (isa<ConstantAggregateZero>(Init)) {
4605       if (StructType *STy = dyn_cast<StructType>(Init->getType())) {
4606         assert(Idx < STy->getNumElements() && "Bad struct index!");
4607         Init = Constant::getNullValue(STy->getElementType(Idx));
4608       } else if (ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
4609         if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
4610         Init = Constant::getNullValue(ATy->getElementType());
4611       } else {
4612         llvm_unreachable("Unknown constant aggregate type!");
4613       }
4614       return 0;
4615     } else {
4616       return 0; // Unknown initializer type
4617     }
4618   }
4619   return Init;
4620 }
4621
4622 /// ComputeLoadConstantCompareExitLimit - Given an exit condition of
4623 /// 'icmp op load X, cst', try to see if we can compute the backedge
4624 /// execution count.
4625 ScalarEvolution::ExitLimit
4626 ScalarEvolution::ComputeLoadConstantCompareExitLimit(
4627   LoadInst *LI,
4628   Constant *RHS,
4629   const Loop *L,
4630   ICmpInst::Predicate predicate) {
4631
4632   if (LI->isVolatile()) return getCouldNotCompute();
4633
4634   // Check to see if the loaded pointer is a getelementptr of a global.
4635   // TODO: Use SCEV instead of manually grubbing with GEPs.
4636   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
4637   if (!GEP) return getCouldNotCompute();
4638
4639   // Make sure that it is really a constant global we are gepping, with an
4640   // initializer, and make sure the first IDX is really 0.
4641   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
4642   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
4643       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
4644       !cast<Constant>(GEP->getOperand(1))->isNullValue())
4645     return getCouldNotCompute();
4646
4647   // Okay, we allow one non-constant index into the GEP instruction.
4648   Value *VarIdx = 0;
4649   std::vector<ConstantInt*> Indexes;
4650   unsigned VarIdxNum = 0;
4651   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
4652     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
4653       Indexes.push_back(CI);
4654     } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
4655       if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
4656       VarIdx = GEP->getOperand(i);
4657       VarIdxNum = i-2;
4658       Indexes.push_back(0);
4659     }
4660
4661   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
4662   // Check to see if X is a loop variant variable value now.
4663   const SCEV *Idx = getSCEV(VarIdx);
4664   Idx = getSCEVAtScope(Idx, L);
4665
4666   // We can only recognize very limited forms of loop index expressions, in
4667   // particular, only affine AddRec's like {C1,+,C2}.
4668   const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
4669   if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
4670       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
4671       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
4672     return getCouldNotCompute();
4673
4674   unsigned MaxSteps = MaxBruteForceIterations;
4675   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
4676     ConstantInt *ItCst = ConstantInt::get(
4677                            cast<IntegerType>(IdxExpr->getType()), IterationNum);
4678     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
4679
4680     // Form the GEP offset.
4681     Indexes[VarIdxNum] = Val;
4682
4683     Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
4684     if (Result == 0) break;  // Cannot compute!
4685
4686     // Evaluate the condition for this iteration.
4687     Result = ConstantExpr::getICmp(predicate, Result, RHS);
4688     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
4689     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
4690 #if 0
4691       dbgs() << "\n***\n*** Computed loop count " << *ItCst
4692              << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
4693              << "***\n";
4694 #endif
4695       ++NumArrayLenItCounts;
4696       return getConstant(ItCst);   // Found terminating iteration!
4697     }
4698   }
4699   return getCouldNotCompute();
4700 }
4701
4702
4703 /// CanConstantFold - Return true if we can constant fold an instruction of the
4704 /// specified type, assuming that all operands were constants.
4705 static bool CanConstantFold(const Instruction *I) {
4706   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
4707       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
4708       isa<LoadInst>(I))
4709     return true;
4710
4711   if (const CallInst *CI = dyn_cast<CallInst>(I))
4712     if (const Function *F = CI->getCalledFunction())
4713       return canConstantFoldCallTo(F);
4714   return false;
4715 }
4716
4717 /// Determine whether this instruction can constant evolve within this loop
4718 /// assuming its operands can all constant evolve.
4719 static bool canConstantEvolve(Instruction *I, const Loop *L) {
4720   // An instruction outside of the loop can't be derived from a loop PHI.
4721   if (!L->contains(I)) return false;
4722
4723   if (isa<PHINode>(I)) {
4724     if (L->getHeader() == I->getParent())
4725       return true;
4726     else
4727       // We don't currently keep track of the control flow needed to evaluate
4728       // PHIs, so we cannot handle PHIs inside of loops.
4729       return false;
4730   }
4731
4732   // If we won't be able to constant fold this expression even if the operands
4733   // are constants, bail early.
4734   return CanConstantFold(I);
4735 }
4736
4737 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
4738 /// recursing through each instruction operand until reaching a loop header phi.
4739 static PHINode *
4740 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
4741                                DenseMap<Instruction *, PHINode *> &PHIMap) {
4742
4743   // Otherwise, we can evaluate this instruction if all of its operands are
4744   // constant or derived from a PHI node themselves.
4745   PHINode *PHI = 0;
4746   for (Instruction::op_iterator OpI = UseInst->op_begin(),
4747          OpE = UseInst->op_end(); OpI != OpE; ++OpI) {
4748
4749     if (isa<Constant>(*OpI)) continue;
4750
4751     Instruction *OpInst = dyn_cast<Instruction>(*OpI);
4752     if (!OpInst || !canConstantEvolve(OpInst, L)) return 0;
4753
4754     PHINode *P = dyn_cast<PHINode>(OpInst);
4755     if (!P)
4756       // If this operand is already visited, reuse the prior result.
4757       // We may have P != PHI if this is the deepest point at which the
4758       // inconsistent paths meet.
4759       P = PHIMap.lookup(OpInst);
4760     if (!P) {
4761       // Recurse and memoize the results, whether a phi is found or not.
4762       // This recursive call invalidates pointers into PHIMap.
4763       P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap);
4764       PHIMap[OpInst] = P;
4765     }
4766     if (P == 0) return 0;        // Not evolving from PHI
4767     if (PHI && PHI != P) return 0;  // Evolving from multiple different PHIs.
4768     PHI = P;
4769   }
4770   // This is a expression evolving from a constant PHI!
4771   return PHI;
4772 }
4773
4774 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
4775 /// in the loop that V is derived from.  We allow arbitrary operations along the
4776 /// way, but the operands of an operation must either be constants or a value
4777 /// derived from a constant PHI.  If this expression does not fit with these
4778 /// constraints, return null.
4779 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
4780   Instruction *I = dyn_cast<Instruction>(V);
4781   if (I == 0 || !canConstantEvolve(I, L)) return 0;
4782
4783   if (PHINode *PN = dyn_cast<PHINode>(I)) {
4784     return PN;
4785   }
4786
4787   // Record non-constant instructions contained by the loop.
4788   DenseMap<Instruction *, PHINode *> PHIMap;
4789   return getConstantEvolvingPHIOperands(I, L, PHIMap);
4790 }
4791
4792 /// EvaluateExpression - Given an expression that passes the
4793 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
4794 /// in the loop has the value PHIVal.  If we can't fold this expression for some
4795 /// reason, return null.
4796 static Constant *EvaluateExpression(Value *V, const Loop *L,
4797                                     DenseMap<Instruction *, Constant *> &Vals,
4798                                     const TargetData *TD,
4799                                     const TargetLibraryInfo *TLI) {
4800   // Convenient constant check, but redundant for recursive calls.
4801   if (Constant *C = dyn_cast<Constant>(V)) return C;
4802   Instruction *I = dyn_cast<Instruction>(V);
4803   if (!I) return 0;
4804
4805   if (Constant *C = Vals.lookup(I)) return C;
4806
4807   // An instruction inside the loop depends on a value outside the loop that we
4808   // weren't given a mapping for, or a value such as a call inside the loop.
4809   if (!canConstantEvolve(I, L)) return 0;
4810
4811   // An unmapped PHI can be due to a branch or another loop inside this loop,
4812   // or due to this not being the initial iteration through a loop where we
4813   // couldn't compute the evolution of this particular PHI last time.
4814   if (isa<PHINode>(I)) return 0;
4815
4816   std::vector<Constant*> Operands(I->getNumOperands());
4817
4818   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4819     Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
4820     if (!Operand) {
4821       Operands[i] = dyn_cast<Constant>(I->getOperand(i));
4822       if (!Operands[i]) return 0;
4823       continue;
4824     }
4825     Constant *C = EvaluateExpression(Operand, L, Vals, TD, TLI);
4826     Vals[Operand] = C;
4827     if (!C) return 0;
4828     Operands[i] = C;
4829   }
4830
4831   if (CmpInst *CI = dyn_cast<CmpInst>(I))
4832     return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
4833                                            Operands[1], TD, TLI);
4834   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
4835     if (!LI->isVolatile())
4836       return ConstantFoldLoadFromConstPtr(Operands[0], TD);
4837   }
4838   return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, TD,
4839                                   TLI);
4840 }
4841
4842 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
4843 /// in the header of its containing loop, we know the loop executes a
4844 /// constant number of times, and the PHI node is just a recurrence
4845 /// involving constants, fold it.
4846 Constant *
4847 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
4848                                                    const APInt &BEs,
4849                                                    const Loop *L) {
4850   DenseMap<PHINode*, Constant*>::const_iterator I =
4851     ConstantEvolutionLoopExitValue.find(PN);
4852   if (I != ConstantEvolutionLoopExitValue.end())
4853     return I->second;
4854
4855   if (BEs.ugt(MaxBruteForceIterations))
4856     return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
4857
4858   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
4859
4860   DenseMap<Instruction *, Constant *> CurrentIterVals;
4861   BasicBlock *Header = L->getHeader();
4862   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
4863
4864   // Since the loop is canonicalized, the PHI node must have two entries.  One
4865   // entry must be a constant (coming in from outside of the loop), and the
4866   // second must be derived from the same PHI.
4867   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4868   PHINode *PHI = 0;
4869   for (BasicBlock::iterator I = Header->begin();
4870        (PHI = dyn_cast<PHINode>(I)); ++I) {
4871     Constant *StartCST =
4872       dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
4873     if (StartCST == 0) continue;
4874     CurrentIterVals[PHI] = StartCST;
4875   }
4876   if (!CurrentIterVals.count(PN))
4877     return RetVal = 0;
4878
4879   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4880
4881   // Execute the loop symbolically to determine the exit value.
4882   if (BEs.getActiveBits() >= 32)
4883     return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
4884
4885   unsigned NumIterations = BEs.getZExtValue(); // must be in range
4886   unsigned IterationNum = 0;
4887   for (; ; ++IterationNum) {
4888     if (IterationNum == NumIterations)
4889       return RetVal = CurrentIterVals[PN];  // Got exit value!
4890
4891     // Compute the value of the PHIs for the next iteration.
4892     // EvaluateExpression adds non-phi values to the CurrentIterVals map.
4893     DenseMap<Instruction *, Constant *> NextIterVals;
4894     Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD,
4895                                            TLI);
4896     if (NextPHI == 0)
4897       return 0;        // Couldn't evaluate!
4898     NextIterVals[PN] = NextPHI;
4899
4900     bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
4901
4902     // Also evaluate the other PHI nodes.  However, we don't get to stop if we
4903     // cease to be able to evaluate one of them or if they stop evolving,
4904     // because that doesn't necessarily prevent us from computing PN.
4905     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
4906     for (DenseMap<Instruction *, Constant *>::const_iterator
4907            I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
4908       PHINode *PHI = dyn_cast<PHINode>(I->first);
4909       if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
4910       PHIsToCompute.push_back(std::make_pair(PHI, I->second));
4911     }
4912     // We use two distinct loops because EvaluateExpression may invalidate any
4913     // iterators into CurrentIterVals.
4914     for (SmallVectorImpl<std::pair<PHINode *, Constant*> >::const_iterator
4915              I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) {
4916       PHINode *PHI = I->first;
4917       Constant *&NextPHI = NextIterVals[PHI];
4918       if (!NextPHI) {   // Not already computed.
4919         Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
4920         NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD, TLI);
4921       }
4922       if (NextPHI != I->second)
4923         StoppedEvolving = false;
4924     }
4925
4926     // If all entries in CurrentIterVals == NextIterVals then we can stop
4927     // iterating, the loop can't continue to change.
4928     if (StoppedEvolving)
4929       return RetVal = CurrentIterVals[PN];
4930
4931     CurrentIterVals.swap(NextIterVals);
4932   }
4933 }
4934
4935 /// ComputeExitCountExhaustively - If the loop is known to execute a
4936 /// constant number of times (the condition evolves only from constants),
4937 /// try to evaluate a few iterations of the loop until we get the exit
4938 /// condition gets a value of ExitWhen (true or false).  If we cannot
4939 /// evaluate the trip count of the loop, return getCouldNotCompute().
4940 const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
4941                                                           Value *Cond,
4942                                                           bool ExitWhen) {
4943   PHINode *PN = getConstantEvolvingPHI(Cond, L);
4944   if (PN == 0) return getCouldNotCompute();
4945
4946   // If the loop is canonicalized, the PHI will have exactly two entries.
4947   // That's the only form we support here.
4948   if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
4949
4950   DenseMap<Instruction *, Constant *> CurrentIterVals;
4951   BasicBlock *Header = L->getHeader();
4952   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
4953
4954   // One entry must be a constant (coming in from outside of the loop), and the
4955   // second must be derived from the same PHI.
4956   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4957   PHINode *PHI = 0;
4958   for (BasicBlock::iterator I = Header->begin();
4959        (PHI = dyn_cast<PHINode>(I)); ++I) {
4960     Constant *StartCST =
4961       dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
4962     if (StartCST == 0) continue;
4963     CurrentIterVals[PHI] = StartCST;
4964   }
4965   if (!CurrentIterVals.count(PN))
4966     return getCouldNotCompute();
4967
4968   // Okay, we find a PHI node that defines the trip count of this loop.  Execute
4969   // the loop symbolically to determine when the condition gets a value of
4970   // "ExitWhen".
4971
4972   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
4973   for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
4974     ConstantInt *CondVal =
4975       dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, L, CurrentIterVals,
4976                                                        TD, TLI));
4977
4978     // Couldn't symbolically evaluate.
4979     if (!CondVal) return getCouldNotCompute();
4980
4981     if (CondVal->getValue() == uint64_t(ExitWhen)) {
4982       ++NumBruteForceTripCountsComputed;
4983       return getConstant(Type::getInt32Ty(getContext()), IterationNum);
4984     }
4985
4986     // Update all the PHI nodes for the next iteration.
4987     DenseMap<Instruction *, Constant *> NextIterVals;
4988
4989     // Create a list of which PHIs we need to compute. We want to do this before
4990     // calling EvaluateExpression on them because that may invalidate iterators
4991     // into CurrentIterVals.
4992     SmallVector<PHINode *, 8> PHIsToCompute;
4993     for (DenseMap<Instruction *, Constant *>::const_iterator
4994            I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
4995       PHINode *PHI = dyn_cast<PHINode>(I->first);
4996       if (!PHI || PHI->getParent() != Header) continue;
4997       PHIsToCompute.push_back(PHI);
4998     }
4999     for (SmallVectorImpl<PHINode *>::const_iterator I = PHIsToCompute.begin(),
5000              E = PHIsToCompute.end(); I != E; ++I) {
5001       PHINode *PHI = *I;
5002       Constant *&NextPHI = NextIterVals[PHI];
5003       if (NextPHI) continue;    // Already computed!
5004
5005       Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
5006       NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD, TLI);
5007     }
5008     CurrentIterVals.swap(NextIterVals);
5009   }
5010
5011   // Too many iterations were needed to evaluate.
5012   return getCouldNotCompute();
5013 }
5014
5015 /// getSCEVAtScope - Return a SCEV expression for the specified value
5016 /// at the specified scope in the program.  The L value specifies a loop
5017 /// nest to evaluate the expression at, where null is the top-level or a
5018 /// specified loop is immediately inside of the loop.
5019 ///
5020 /// This method can be used to compute the exit value for a variable defined
5021 /// in a loop by querying what the value will hold in the parent loop.
5022 ///
5023 /// In the case that a relevant loop exit value cannot be computed, the
5024 /// original value V is returned.
5025 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
5026   // Check to see if we've folded this expression at this loop before.
5027   std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
5028   std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
5029     Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
5030   if (!Pair.second)
5031     return Pair.first->second ? Pair.first->second : V;
5032
5033   // Otherwise compute it.
5034   const SCEV *C = computeSCEVAtScope(V, L);
5035   ValuesAtScopes[V][L] = C;
5036   return C;
5037 }
5038
5039 /// This builds up a Constant using the ConstantExpr interface.  That way, we
5040 /// will return Constants for objects which aren't represented by a
5041 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
5042 /// Returns NULL if the SCEV isn't representable as a Constant.
5043 static Constant *BuildConstantFromSCEV(const SCEV *V) {
5044   switch (V->getSCEVType()) {
5045     default:  // TODO: smax, umax.
5046     case scCouldNotCompute:
5047     case scAddRecExpr:
5048       break;
5049     case scConstant:
5050       return cast<SCEVConstant>(V)->getValue();
5051     case scUnknown:
5052       return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
5053     case scSignExtend: {
5054       const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
5055       if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
5056         return ConstantExpr::getSExt(CastOp, SS->getType());
5057       break;
5058     }
5059     case scZeroExtend: {
5060       const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
5061       if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
5062         return ConstantExpr::getZExt(CastOp, SZ->getType());
5063       break;
5064     }
5065     case scTruncate: {
5066       const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
5067       if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
5068         return ConstantExpr::getTrunc(CastOp, ST->getType());
5069       break;
5070     }
5071     case scAddExpr: {
5072       const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
5073       if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
5074         if (C->getType()->isPointerTy())
5075           C = ConstantExpr::getBitCast(C, Type::getInt8PtrTy(C->getContext()));
5076         for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
5077           Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
5078           if (!C2) return 0;
5079
5080           // First pointer!
5081           if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
5082             std::swap(C, C2);
5083             // The offsets have been converted to bytes.  We can add bytes to an
5084             // i8* by GEP with the byte count in the first index.
5085             C = ConstantExpr::getBitCast(C,Type::getInt8PtrTy(C->getContext()));
5086           }
5087
5088           // Don't bother trying to sum two pointers. We probably can't
5089           // statically compute a load that results from it anyway.
5090           if (C2->getType()->isPointerTy())
5091             return 0;
5092
5093           if (C->getType()->isPointerTy()) {
5094             if (cast<PointerType>(C->getType())->getElementType()->isStructTy())
5095               C2 = ConstantExpr::getIntegerCast(
5096                   C2, Type::getInt32Ty(C->getContext()), true);
5097             C = ConstantExpr::getGetElementPtr(C, C2);
5098           } else
5099             C = ConstantExpr::getAdd(C, C2);
5100         }
5101         return C;
5102       }
5103       break;
5104     }
5105     case scMulExpr: {
5106       const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
5107       if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
5108         // Don't bother with pointers at all.
5109         if (C->getType()->isPointerTy()) return 0;
5110         for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
5111           Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
5112           if (!C2 || C2->getType()->isPointerTy()) return 0;
5113           C = ConstantExpr::getMul(C, C2);
5114         }
5115         return C;
5116       }
5117       break;
5118     }
5119     case scUDivExpr: {
5120       const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
5121       if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
5122         if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
5123           if (LHS->getType() == RHS->getType())
5124             return ConstantExpr::getUDiv(LHS, RHS);
5125       break;
5126     }
5127   }
5128   return 0;
5129 }
5130
5131 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
5132   if (isa<SCEVConstant>(V)) return V;
5133
5134   // If this instruction is evolved from a constant-evolving PHI, compute the
5135   // exit value from the loop without using SCEVs.
5136   if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
5137     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
5138       const Loop *LI = (*this->LI)[I->getParent()];
5139       if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
5140         if (PHINode *PN = dyn_cast<PHINode>(I))
5141           if (PN->getParent() == LI->getHeader()) {
5142             // Okay, there is no closed form solution for the PHI node.  Check
5143             // to see if the loop that contains it has a known backedge-taken
5144             // count.  If so, we may be able to force computation of the exit
5145             // value.
5146             const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
5147             if (const SCEVConstant *BTCC =
5148                   dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
5149               // Okay, we know how many times the containing loop executes.  If
5150               // this is a constant evolving PHI node, get the final value at
5151               // the specified iteration number.
5152               Constant *RV = getConstantEvolutionLoopExitValue(PN,
5153                                                    BTCC->getValue()->getValue(),
5154                                                                LI);
5155               if (RV) return getSCEV(RV);
5156             }
5157           }
5158
5159       // Okay, this is an expression that we cannot symbolically evaluate
5160       // into a SCEV.  Check to see if it's possible to symbolically evaluate
5161       // the arguments into constants, and if so, try to constant propagate the
5162       // result.  This is particularly useful for computing loop exit values.
5163       if (CanConstantFold(I)) {
5164         SmallVector<Constant *, 4> Operands;
5165         bool MadeImprovement = false;
5166         for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
5167           Value *Op = I->getOperand(i);
5168           if (Constant *C = dyn_cast<Constant>(Op)) {
5169             Operands.push_back(C);
5170             continue;
5171           }
5172
5173           // If any of the operands is non-constant and if they are
5174           // non-integer and non-pointer, don't even try to analyze them
5175           // with scev techniques.
5176           if (!isSCEVable(Op->getType()))
5177             return V;
5178
5179           const SCEV *OrigV = getSCEV(Op);
5180           const SCEV *OpV = getSCEVAtScope(OrigV, L);
5181           MadeImprovement |= OrigV != OpV;
5182
5183           Constant *C = BuildConstantFromSCEV(OpV);
5184           if (!C) return V;
5185           if (C->getType() != Op->getType())
5186             C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
5187                                                               Op->getType(),
5188                                                               false),
5189                                       C, Op->getType());
5190           Operands.push_back(C);
5191         }
5192
5193         // Check to see if getSCEVAtScope actually made an improvement.
5194         if (MadeImprovement) {
5195           Constant *C = 0;
5196           if (const CmpInst *CI = dyn_cast<CmpInst>(I))
5197             C = ConstantFoldCompareInstOperands(CI->getPredicate(),
5198                                                 Operands[0], Operands[1], TD,
5199                                                 TLI);
5200           else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
5201             if (!LI->isVolatile())
5202               C = ConstantFoldLoadFromConstPtr(Operands[0], TD);
5203           } else
5204             C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
5205                                          Operands, TD, TLI);
5206           if (!C) return V;
5207           return getSCEV(C);
5208         }
5209       }
5210     }
5211
5212     // This is some other type of SCEVUnknown, just return it.
5213     return V;
5214   }
5215
5216   if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
5217     // Avoid performing the look-up in the common case where the specified
5218     // expression has no loop-variant portions.
5219     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
5220       const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5221       if (OpAtScope != Comm->getOperand(i)) {
5222         // Okay, at least one of these operands is loop variant but might be
5223         // foldable.  Build a new instance of the folded commutative expression.
5224         SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
5225                                             Comm->op_begin()+i);
5226         NewOps.push_back(OpAtScope);
5227
5228         for (++i; i != e; ++i) {
5229           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5230           NewOps.push_back(OpAtScope);
5231         }
5232         if (isa<SCEVAddExpr>(Comm))
5233           return getAddExpr(NewOps);
5234         if (isa<SCEVMulExpr>(Comm))
5235           return getMulExpr(NewOps);
5236         if (isa<SCEVSMaxExpr>(Comm))
5237           return getSMaxExpr(NewOps);
5238         if (isa<SCEVUMaxExpr>(Comm))
5239           return getUMaxExpr(NewOps);
5240         llvm_unreachable("Unknown commutative SCEV type!");
5241       }
5242     }
5243     // If we got here, all operands are loop invariant.
5244     return Comm;
5245   }
5246
5247   if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
5248     const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
5249     const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
5250     if (LHS == Div->getLHS() && RHS == Div->getRHS())
5251       return Div;   // must be loop invariant
5252     return getUDivExpr(LHS, RHS);
5253   }
5254
5255   // If this is a loop recurrence for a loop that does not contain L, then we
5256   // are dealing with the final value computed by the loop.
5257   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5258     // First, attempt to evaluate each operand.
5259     // Avoid performing the look-up in the common case where the specified
5260     // expression has no loop-variant portions.
5261     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
5262       const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
5263       if (OpAtScope == AddRec->getOperand(i))
5264         continue;
5265
5266       // Okay, at least one of these operands is loop variant but might be
5267       // foldable.  Build a new instance of the folded commutative expression.
5268       SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
5269                                           AddRec->op_begin()+i);
5270       NewOps.push_back(OpAtScope);
5271       for (++i; i != e; ++i)
5272         NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
5273
5274       const SCEV *FoldedRec =
5275         getAddRecExpr(NewOps, AddRec->getLoop(),
5276                       AddRec->getNoWrapFlags(SCEV::FlagNW));
5277       AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
5278       // The addrec may be folded to a nonrecurrence, for example, if the
5279       // induction variable is multiplied by zero after constant folding. Go
5280       // ahead and return the folded value.
5281       if (!AddRec)
5282         return FoldedRec;
5283       break;
5284     }
5285
5286     // If the scope is outside the addrec's loop, evaluate it by using the
5287     // loop exit value of the addrec.
5288     if (!AddRec->getLoop()->contains(L)) {
5289       // To evaluate this recurrence, we need to know how many times the AddRec
5290       // loop iterates.  Compute this now.
5291       const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
5292       if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
5293
5294       // Then, evaluate the AddRec.
5295       return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
5296     }
5297
5298     return AddRec;
5299   }
5300
5301   if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
5302     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5303     if (Op == Cast->getOperand())
5304       return Cast;  // must be loop invariant
5305     return getZeroExtendExpr(Op, Cast->getType());
5306   }
5307
5308   if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
5309     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5310     if (Op == Cast->getOperand())
5311       return Cast;  // must be loop invariant
5312     return getSignExtendExpr(Op, Cast->getType());
5313   }
5314
5315   if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
5316     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5317     if (Op == Cast->getOperand())
5318       return Cast;  // must be loop invariant
5319     return getTruncateExpr(Op, Cast->getType());
5320   }
5321
5322   llvm_unreachable("Unknown SCEV type!");
5323   return 0;
5324 }
5325
5326 /// getSCEVAtScope - This is a convenience function which does
5327 /// getSCEVAtScope(getSCEV(V), L).
5328 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
5329   return getSCEVAtScope(getSCEV(V), L);
5330 }
5331
5332 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
5333 /// following equation:
5334 ///
5335 ///     A * X = B (mod N)
5336 ///
5337 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
5338 /// A and B isn't important.
5339 ///
5340 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
5341 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
5342                                                ScalarEvolution &SE) {
5343   uint32_t BW = A.getBitWidth();
5344   assert(BW == B.getBitWidth() && "Bit widths must be the same.");
5345   assert(A != 0 && "A must be non-zero.");
5346
5347   // 1. D = gcd(A, N)
5348   //
5349   // The gcd of A and N may have only one prime factor: 2. The number of
5350   // trailing zeros in A is its multiplicity
5351   uint32_t Mult2 = A.countTrailingZeros();
5352   // D = 2^Mult2
5353
5354   // 2. Check if B is divisible by D.
5355   //
5356   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
5357   // is not less than multiplicity of this prime factor for D.
5358   if (B.countTrailingZeros() < Mult2)
5359     return SE.getCouldNotCompute();
5360
5361   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
5362   // modulo (N / D).
5363   //
5364   // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
5365   // bit width during computations.
5366   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
5367   APInt Mod(BW + 1, 0);
5368   Mod.setBit(BW - Mult2);  // Mod = N / D
5369   APInt I = AD.multiplicativeInverse(Mod);
5370
5371   // 4. Compute the minimum unsigned root of the equation:
5372   // I * (B / D) mod (N / D)
5373   APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
5374
5375   // The result is guaranteed to be less than 2^BW so we may truncate it to BW
5376   // bits.
5377   return SE.getConstant(Result.trunc(BW));
5378 }
5379
5380 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the
5381 /// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
5382 /// might be the same) or two SCEVCouldNotCompute objects.
5383 ///
5384 static std::pair<const SCEV *,const SCEV *>
5385 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
5386   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
5387   const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
5388   const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
5389   const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
5390
5391   // We currently can only solve this if the coefficients are constants.
5392   if (!LC || !MC || !NC) {
5393     const SCEV *CNC = SE.getCouldNotCompute();
5394     return std::make_pair(CNC, CNC);
5395   }
5396
5397   uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
5398   const APInt &L = LC->getValue()->getValue();
5399   const APInt &M = MC->getValue()->getValue();
5400   const APInt &N = NC->getValue()->getValue();
5401   APInt Two(BitWidth, 2);
5402   APInt Four(BitWidth, 4);
5403
5404   {
5405     using namespace APIntOps;
5406     const APInt& C = L;
5407     // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
5408     // The B coefficient is M-N/2
5409     APInt B(M);
5410     B -= sdiv(N,Two);
5411
5412     // The A coefficient is N/2
5413     APInt A(N.sdiv(Two));
5414
5415     // Compute the B^2-4ac term.
5416     APInt SqrtTerm(B);
5417     SqrtTerm *= B;
5418     SqrtTerm -= Four * (A * C);
5419
5420     // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
5421     // integer value or else APInt::sqrt() will assert.
5422     APInt SqrtVal(SqrtTerm.sqrt());
5423
5424     // Compute the two solutions for the quadratic formula.
5425     // The divisions must be performed as signed divisions.
5426     APInt NegB(-B);
5427     APInt TwoA(A << 1);
5428     if (TwoA.isMinValue()) {
5429       const SCEV *CNC = SE.getCouldNotCompute();
5430       return std::make_pair(CNC, CNC);
5431     }
5432
5433     LLVMContext &Context = SE.getContext();
5434
5435     ConstantInt *Solution1 =
5436       ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
5437     ConstantInt *Solution2 =
5438       ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
5439
5440     return std::make_pair(SE.getConstant(Solution1),
5441                           SE.getConstant(Solution2));
5442   } // end APIntOps namespace
5443 }
5444
5445 /// HowFarToZero - Return the number of times a backedge comparing the specified
5446 /// value to zero will execute.  If not computable, return CouldNotCompute.
5447 ///
5448 /// This is only used for loops with a "x != y" exit test. The exit condition is
5449 /// now expressed as a single expression, V = x-y. So the exit test is
5450 /// effectively V != 0.  We know and take advantage of the fact that this
5451 /// expression only being used in a comparison by zero context.
5452 ScalarEvolution::ExitLimit
5453 ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
5454   // If the value is a constant
5455   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5456     // If the value is already zero, the branch will execute zero times.
5457     if (C->getValue()->isZero()) return C;
5458     return getCouldNotCompute();  // Otherwise it will loop infinitely.
5459   }
5460
5461   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
5462   if (!AddRec || AddRec->getLoop() != L)
5463     return getCouldNotCompute();
5464
5465   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
5466   // the quadratic equation to solve it.
5467   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
5468     std::pair<const SCEV *,const SCEV *> Roots =
5469       SolveQuadraticEquation(AddRec, *this);
5470     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
5471     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
5472     if (R1 && R2) {
5473 #if 0
5474       dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
5475              << "  sol#2: " << *R2 << "\n";
5476 #endif
5477       // Pick the smallest positive root value.
5478       if (ConstantInt *CB =
5479           dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
5480                                                       R1->getValue(),
5481                                                       R2->getValue()))) {
5482         if (CB->getZExtValue() == false)
5483           std::swap(R1, R2);   // R1 is the minimum root now.
5484
5485         // We can only use this value if the chrec ends up with an exact zero
5486         // value at this index.  When solving for "X*X != 5", for example, we
5487         // should not accept a root of 2.
5488         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
5489         if (Val->isZero())
5490           return R1;  // We found a quadratic root!
5491       }
5492     }
5493     return getCouldNotCompute();
5494   }
5495
5496   // Otherwise we can only handle this if it is affine.
5497   if (!AddRec->isAffine())
5498     return getCouldNotCompute();
5499
5500   // If this is an affine expression, the execution count of this branch is
5501   // the minimum unsigned root of the following equation:
5502   //
5503   //     Start + Step*N = 0 (mod 2^BW)
5504   //
5505   // equivalent to:
5506   //
5507   //             Step*N = -Start (mod 2^BW)
5508   //
5509   // where BW is the common bit width of Start and Step.
5510
5511   // Get the initial value for the loop.
5512   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
5513   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
5514
5515   // For now we handle only constant steps.
5516   //
5517   // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
5518   // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
5519   // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
5520   // We have not yet seen any such cases.
5521   const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
5522   if (StepC == 0)
5523     return getCouldNotCompute();
5524
5525   // For positive steps (counting up until unsigned overflow):
5526   //   N = -Start/Step (as unsigned)
5527   // For negative steps (counting down to zero):
5528   //   N = Start/-Step
5529   // First compute the unsigned distance from zero in the direction of Step.
5530   bool CountDown = StepC->getValue()->getValue().isNegative();
5531   const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
5532
5533   // Handle unitary steps, which cannot wraparound.
5534   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
5535   //   N = Distance (as unsigned)
5536   if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
5537     ConstantRange CR = getUnsignedRange(Start);
5538     const SCEV *MaxBECount;
5539     if (!CountDown && CR.getUnsignedMin().isMinValue())
5540       // When counting up, the worst starting value is 1, not 0.
5541       MaxBECount = CR.getUnsignedMax().isMinValue()
5542         ? getConstant(APInt::getMinValue(CR.getBitWidth()))
5543         : getConstant(APInt::getMaxValue(CR.getBitWidth()));
5544     else
5545       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
5546                                          : -CR.getUnsignedMin());
5547     return ExitLimit(Distance, MaxBECount);
5548   }
5549
5550   // If the recurrence is known not to wraparound, unsigned divide computes the
5551   // back edge count. We know that the value will either become zero (and thus
5552   // the loop terminates), that the loop will terminate through some other exit
5553   // condition first, or that the loop has undefined behavior.  This means
5554   // we can't "miss" the exit value, even with nonunit stride.
5555   //
5556   // FIXME: Prove that loops always exhibits *acceptable* undefined
5557   // behavior. Loops must exhibit defined behavior until a wrapped value is
5558   // actually used. So the trip count computed by udiv could be smaller than the
5559   // number of well-defined iterations.
5560   if (AddRec->getNoWrapFlags(SCEV::FlagNW)) {
5561     // FIXME: We really want an "isexact" bit for udiv.
5562     return getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
5563   }
5564   // Then, try to solve the above equation provided that Start is constant.
5565   if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
5566     return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
5567                                         -StartC->getValue()->getValue(),
5568                                         *this);
5569   return getCouldNotCompute();
5570 }
5571
5572 /// HowFarToNonZero - Return the number of times a backedge checking the
5573 /// specified value for nonzero will execute.  If not computable, return
5574 /// CouldNotCompute
5575 ScalarEvolution::ExitLimit
5576 ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
5577   // Loops that look like: while (X == 0) are very strange indeed.  We don't
5578   // handle them yet except for the trivial case.  This could be expanded in the
5579   // future as needed.
5580
5581   // If the value is a constant, check to see if it is known to be non-zero
5582   // already.  If so, the backedge will execute zero times.
5583   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5584     if (!C->getValue()->isNullValue())
5585       return getConstant(C->getType(), 0);
5586     return getCouldNotCompute();  // Otherwise it will loop infinitely.
5587   }
5588
5589   // We could implement others, but I really doubt anyone writes loops like
5590   // this, and if they did, they would already be constant folded.
5591   return getCouldNotCompute();
5592 }
5593
5594 /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
5595 /// (which may not be an immediate predecessor) which has exactly one
5596 /// successor from which BB is reachable, or null if no such block is
5597 /// found.
5598 ///
5599 std::pair<BasicBlock *, BasicBlock *>
5600 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
5601   // If the block has a unique predecessor, then there is no path from the
5602   // predecessor to the block that does not go through the direct edge
5603   // from the predecessor to the block.
5604   if (BasicBlock *Pred = BB->getSinglePredecessor())
5605     return std::make_pair(Pred, BB);
5606
5607   // A loop's header is defined to be a block that dominates the loop.
5608   // If the header has a unique predecessor outside the loop, it must be
5609   // a block that has exactly one successor that can reach the loop.
5610   if (Loop *L = LI->getLoopFor(BB))
5611     return std::make_pair(L->getLoopPredecessor(), L->getHeader());
5612
5613   return std::pair<BasicBlock *, BasicBlock *>();
5614 }
5615
5616 /// HasSameValue - SCEV structural equivalence is usually sufficient for
5617 /// testing whether two expressions are equal, however for the purposes of
5618 /// looking for a condition guarding a loop, it can be useful to be a little
5619 /// more general, since a front-end may have replicated the controlling
5620 /// expression.
5621 ///
5622 static bool HasSameValue(const SCEV *A, const SCEV *B) {
5623   // Quick check to see if they are the same SCEV.
5624   if (A == B) return true;
5625
5626   // Otherwise, if they're both SCEVUnknown, it's possible that they hold
5627   // two different instructions with the same value. Check for this case.
5628   if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
5629     if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
5630       if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
5631         if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
5632           if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
5633             return true;
5634
5635   // Otherwise assume they may have a different value.
5636   return false;
5637 }
5638
5639 /// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
5640 /// predicate Pred. Return true iff any changes were made.
5641 ///
5642 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
5643                                            const SCEV *&LHS, const SCEV *&RHS) {
5644   bool Changed = false;
5645
5646   // Canonicalize a constant to the right side.
5647   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
5648     // Check for both operands constant.
5649     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
5650       if (ConstantExpr::getICmp(Pred,
5651                                 LHSC->getValue(),
5652                                 RHSC->getValue())->isNullValue())
5653         goto trivially_false;
5654       else
5655         goto trivially_true;
5656     }
5657     // Otherwise swap the operands to put the constant on the right.
5658     std::swap(LHS, RHS);
5659     Pred = ICmpInst::getSwappedPredicate(Pred);
5660     Changed = true;
5661   }
5662
5663   // If we're comparing an addrec with a value which is loop-invariant in the
5664   // addrec's loop, put the addrec on the left. Also make a dominance check,
5665   // as both operands could be addrecs loop-invariant in each other's loop.
5666   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
5667     const Loop *L = AR->getLoop();
5668     if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
5669       std::swap(LHS, RHS);
5670       Pred = ICmpInst::getSwappedPredicate(Pred);
5671       Changed = true;
5672     }
5673   }
5674
5675   // If there's a constant operand, canonicalize comparisons with boundary
5676   // cases, and canonicalize *-or-equal comparisons to regular comparisons.
5677   if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
5678     const APInt &RA = RC->getValue()->getValue();
5679     switch (Pred) {
5680     default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5681     case ICmpInst::ICMP_EQ:
5682     case ICmpInst::ICMP_NE:
5683       break;
5684     case ICmpInst::ICMP_UGE:
5685       if ((RA - 1).isMinValue()) {
5686         Pred = ICmpInst::ICMP_NE;
5687         RHS = getConstant(RA - 1);
5688         Changed = true;
5689         break;
5690       }
5691       if (RA.isMaxValue()) {
5692         Pred = ICmpInst::ICMP_EQ;
5693         Changed = true;
5694         break;
5695       }
5696       if (RA.isMinValue()) goto trivially_true;
5697
5698       Pred = ICmpInst::ICMP_UGT;
5699       RHS = getConstant(RA - 1);
5700       Changed = true;
5701       break;
5702     case ICmpInst::ICMP_ULE:
5703       if ((RA + 1).isMaxValue()) {
5704         Pred = ICmpInst::ICMP_NE;
5705         RHS = getConstant(RA + 1);
5706         Changed = true;
5707         break;
5708       }
5709       if (RA.isMinValue()) {
5710         Pred = ICmpInst::ICMP_EQ;
5711         Changed = true;
5712         break;
5713       }
5714       if (RA.isMaxValue()) goto trivially_true;
5715
5716       Pred = ICmpInst::ICMP_ULT;
5717       RHS = getConstant(RA + 1);
5718       Changed = true;
5719       break;
5720     case ICmpInst::ICMP_SGE:
5721       if ((RA - 1).isMinSignedValue()) {
5722         Pred = ICmpInst::ICMP_NE;
5723         RHS = getConstant(RA - 1);
5724         Changed = true;
5725         break;
5726       }
5727       if (RA.isMaxSignedValue()) {
5728         Pred = ICmpInst::ICMP_EQ;
5729         Changed = true;
5730         break;
5731       }
5732       if (RA.isMinSignedValue()) goto trivially_true;
5733
5734       Pred = ICmpInst::ICMP_SGT;
5735       RHS = getConstant(RA - 1);
5736       Changed = true;
5737       break;
5738     case ICmpInst::ICMP_SLE:
5739       if ((RA + 1).isMaxSignedValue()) {
5740         Pred = ICmpInst::ICMP_NE;
5741         RHS = getConstant(RA + 1);
5742         Changed = true;
5743         break;
5744       }
5745       if (RA.isMinSignedValue()) {
5746         Pred = ICmpInst::ICMP_EQ;
5747         Changed = true;
5748         break;
5749       }
5750       if (RA.isMaxSignedValue()) goto trivially_true;
5751
5752       Pred = ICmpInst::ICMP_SLT;
5753       RHS = getConstant(RA + 1);
5754       Changed = true;
5755       break;
5756     case ICmpInst::ICMP_UGT:
5757       if (RA.isMinValue()) {
5758         Pred = ICmpInst::ICMP_NE;
5759         Changed = true;
5760         break;
5761       }
5762       if ((RA + 1).isMaxValue()) {
5763         Pred = ICmpInst::ICMP_EQ;
5764         RHS = getConstant(RA + 1);
5765         Changed = true;
5766         break;
5767       }
5768       if (RA.isMaxValue()) goto trivially_false;
5769       break;
5770     case ICmpInst::ICMP_ULT:
5771       if (RA.isMaxValue()) {
5772         Pred = ICmpInst::ICMP_NE;
5773         Changed = true;
5774         break;
5775       }
5776       if ((RA - 1).isMinValue()) {
5777         Pred = ICmpInst::ICMP_EQ;
5778         RHS = getConstant(RA - 1);
5779         Changed = true;
5780         break;
5781       }
5782       if (RA.isMinValue()) goto trivially_false;
5783       break;
5784     case ICmpInst::ICMP_SGT:
5785       if (RA.isMinSignedValue()) {
5786         Pred = ICmpInst::ICMP_NE;
5787         Changed = true;
5788         break;
5789       }
5790       if ((RA + 1).isMaxSignedValue()) {
5791         Pred = ICmpInst::ICMP_EQ;
5792         RHS = getConstant(RA + 1);
5793         Changed = true;
5794         break;
5795       }
5796       if (RA.isMaxSignedValue()) goto trivially_false;
5797       break;
5798     case ICmpInst::ICMP_SLT:
5799       if (RA.isMaxSignedValue()) {
5800         Pred = ICmpInst::ICMP_NE;
5801         Changed = true;
5802         break;
5803       }
5804       if ((RA - 1).isMinSignedValue()) {
5805        Pred = ICmpInst::ICMP_EQ;
5806        RHS = getConstant(RA - 1);
5807         Changed = true;
5808        break;
5809       }
5810       if (RA.isMinSignedValue()) goto trivially_false;
5811       break;
5812     }
5813   }
5814
5815   // Check for obvious equality.
5816   if (HasSameValue(LHS, RHS)) {
5817     if (ICmpInst::isTrueWhenEqual(Pred))
5818       goto trivially_true;
5819     if (ICmpInst::isFalseWhenEqual(Pred))
5820       goto trivially_false;
5821   }
5822
5823   // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
5824   // adding or subtracting 1 from one of the operands.
5825   switch (Pred) {
5826   case ICmpInst::ICMP_SLE:
5827     if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
5828       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5829                        SCEV::FlagNSW);
5830       Pred = ICmpInst::ICMP_SLT;
5831       Changed = true;
5832     } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
5833       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5834                        SCEV::FlagNSW);
5835       Pred = ICmpInst::ICMP_SLT;
5836       Changed = true;
5837     }
5838     break;
5839   case ICmpInst::ICMP_SGE:
5840     if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
5841       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5842                        SCEV::FlagNSW);
5843       Pred = ICmpInst::ICMP_SGT;
5844       Changed = true;
5845     } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
5846       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5847                        SCEV::FlagNSW);
5848       Pred = ICmpInst::ICMP_SGT;
5849       Changed = true;
5850     }
5851     break;
5852   case ICmpInst::ICMP_ULE:
5853     if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
5854       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5855                        SCEV::FlagNUW);
5856       Pred = ICmpInst::ICMP_ULT;
5857       Changed = true;
5858     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
5859       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5860                        SCEV::FlagNUW);
5861       Pred = ICmpInst::ICMP_ULT;
5862       Changed = true;
5863     }
5864     break;
5865   case ICmpInst::ICMP_UGE:
5866     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
5867       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5868                        SCEV::FlagNUW);
5869       Pred = ICmpInst::ICMP_UGT;
5870       Changed = true;
5871     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
5872       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5873                        SCEV::FlagNUW);
5874       Pred = ICmpInst::ICMP_UGT;
5875       Changed = true;
5876     }
5877     break;
5878   default:
5879     break;
5880   }
5881
5882   // TODO: More simplifications are possible here.
5883
5884   return Changed;
5885
5886 trivially_true:
5887   // Return 0 == 0.
5888   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5889   Pred = ICmpInst::ICMP_EQ;
5890   return true;
5891
5892 trivially_false:
5893   // Return 0 != 0.
5894   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5895   Pred = ICmpInst::ICMP_NE;
5896   return true;
5897 }
5898
5899 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
5900   return getSignedRange(S).getSignedMax().isNegative();
5901 }
5902
5903 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
5904   return getSignedRange(S).getSignedMin().isStrictlyPositive();
5905 }
5906
5907 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
5908   return !getSignedRange(S).getSignedMin().isNegative();
5909 }
5910
5911 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
5912   return !getSignedRange(S).getSignedMax().isStrictlyPositive();
5913 }
5914
5915 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
5916   return isKnownNegative(S) || isKnownPositive(S);
5917 }
5918
5919 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
5920                                        const SCEV *LHS, const SCEV *RHS) {
5921   // Canonicalize the inputs first.
5922   (void)SimplifyICmpOperands(Pred, LHS, RHS);
5923
5924   // If LHS or RHS is an addrec, check to see if the condition is true in
5925   // every iteration of the loop.
5926   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
5927     if (isLoopEntryGuardedByCond(
5928           AR->getLoop(), Pred, AR->getStart(), RHS) &&
5929         isLoopBackedgeGuardedByCond(
5930           AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS))
5931       return true;
5932   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
5933     if (isLoopEntryGuardedByCond(
5934           AR->getLoop(), Pred, LHS, AR->getStart()) &&
5935         isLoopBackedgeGuardedByCond(
5936           AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this)))
5937       return true;
5938
5939   // Otherwise see what can be done with known constant ranges.
5940   return isKnownPredicateWithRanges(Pred, LHS, RHS);
5941 }
5942
5943 bool
5944 ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
5945                                             const SCEV *LHS, const SCEV *RHS) {
5946   if (HasSameValue(LHS, RHS))
5947     return ICmpInst::isTrueWhenEqual(Pred);
5948
5949   // This code is split out from isKnownPredicate because it is called from
5950   // within isLoopEntryGuardedByCond.
5951   switch (Pred) {
5952   default:
5953     llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5954     break;
5955   case ICmpInst::ICMP_SGT:
5956     Pred = ICmpInst::ICMP_SLT;
5957     std::swap(LHS, RHS);
5958   case ICmpInst::ICMP_SLT: {
5959     ConstantRange LHSRange = getSignedRange(LHS);
5960     ConstantRange RHSRange = getSignedRange(RHS);
5961     if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
5962       return true;
5963     if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
5964       return false;
5965     break;
5966   }
5967   case ICmpInst::ICMP_SGE:
5968     Pred = ICmpInst::ICMP_SLE;
5969     std::swap(LHS, RHS);
5970   case ICmpInst::ICMP_SLE: {
5971     ConstantRange LHSRange = getSignedRange(LHS);
5972     ConstantRange RHSRange = getSignedRange(RHS);
5973     if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
5974       return true;
5975     if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
5976       return false;
5977     break;
5978   }
5979   case ICmpInst::ICMP_UGT:
5980     Pred = ICmpInst::ICMP_ULT;
5981     std::swap(LHS, RHS);
5982   case ICmpInst::ICMP_ULT: {
5983     ConstantRange LHSRange = getUnsignedRange(LHS);
5984     ConstantRange RHSRange = getUnsignedRange(RHS);
5985     if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
5986       return true;
5987     if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
5988       return false;
5989     break;
5990   }
5991   case ICmpInst::ICMP_UGE:
5992     Pred = ICmpInst::ICMP_ULE;
5993     std::swap(LHS, RHS);
5994   case ICmpInst::ICMP_ULE: {
5995     ConstantRange LHSRange = getUnsignedRange(LHS);
5996     ConstantRange RHSRange = getUnsignedRange(RHS);
5997     if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
5998       return true;
5999     if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
6000       return false;
6001     break;
6002   }
6003   case ICmpInst::ICMP_NE: {
6004     if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
6005       return true;
6006     if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
6007       return true;
6008
6009     const SCEV *Diff = getMinusSCEV(LHS, RHS);
6010     if (isKnownNonZero(Diff))
6011       return true;
6012     break;
6013   }
6014   case ICmpInst::ICMP_EQ:
6015     // The check at the top of the function catches the case where
6016     // the values are known to be equal.
6017     break;
6018   }
6019   return false;
6020 }
6021
6022 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
6023 /// protected by a conditional between LHS and RHS.  This is used to
6024 /// to eliminate casts.
6025 bool
6026 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
6027                                              ICmpInst::Predicate Pred,
6028                                              const SCEV *LHS, const SCEV *RHS) {
6029   // Interpret a null as meaning no loop, where there is obviously no guard
6030   // (interprocedural conditions notwithstanding).
6031   if (!L) return true;
6032
6033   BasicBlock *Latch = L->getLoopLatch();
6034   if (!Latch)
6035     return false;
6036
6037   BranchInst *LoopContinuePredicate =
6038     dyn_cast<BranchInst>(Latch->getTerminator());
6039   if (!LoopContinuePredicate ||
6040       LoopContinuePredicate->isUnconditional())
6041     return false;
6042
6043   return isImpliedCond(Pred, LHS, RHS,
6044                        LoopContinuePredicate->getCondition(),
6045                        LoopContinuePredicate->getSuccessor(0) != L->getHeader());
6046 }
6047
6048 /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
6049 /// by a conditional between LHS and RHS.  This is used to help avoid max
6050 /// expressions in loop trip counts, and to eliminate casts.
6051 bool
6052 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
6053                                           ICmpInst::Predicate Pred,
6054                                           const SCEV *LHS, const SCEV *RHS) {
6055   // Interpret a null as meaning no loop, where there is obviously no guard
6056   // (interprocedural conditions notwithstanding).
6057   if (!L) return false;
6058
6059   // Starting at the loop predecessor, climb up the predecessor chain, as long
6060   // as there are predecessors that can be found that have unique successors
6061   // leading to the original header.
6062   for (std::pair<BasicBlock *, BasicBlock *>
6063          Pair(L->getLoopPredecessor(), L->getHeader());
6064        Pair.first;
6065        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
6066
6067     BranchInst *LoopEntryPredicate =
6068       dyn_cast<BranchInst>(Pair.first->getTerminator());
6069     if (!LoopEntryPredicate ||
6070         LoopEntryPredicate->isUnconditional())
6071       continue;
6072
6073     if (isImpliedCond(Pred, LHS, RHS,
6074                       LoopEntryPredicate->getCondition(),
6075                       LoopEntryPredicate->getSuccessor(0) != Pair.second))
6076       return true;
6077   }
6078
6079   return false;
6080 }
6081
6082 /// isImpliedCond - Test whether the condition described by Pred, LHS,
6083 /// and RHS is true whenever the given Cond value evaluates to true.
6084 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
6085                                     const SCEV *LHS, const SCEV *RHS,
6086                                     Value *FoundCondValue,
6087                                     bool Inverse) {
6088   // Recursively handle And and Or conditions.
6089   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
6090     if (BO->getOpcode() == Instruction::And) {
6091       if (!Inverse)
6092         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6093                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6094     } else if (BO->getOpcode() == Instruction::Or) {
6095       if (Inverse)
6096         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6097                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6098     }
6099   }
6100
6101   ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
6102   if (!ICI) return false;
6103
6104   // Bail if the ICmp's operands' types are wider than the needed type
6105   // before attempting to call getSCEV on them. This avoids infinite
6106   // recursion, since the analysis of widening casts can require loop
6107   // exit condition information for overflow checking, which would
6108   // lead back here.
6109   if (getTypeSizeInBits(LHS->getType()) <
6110       getTypeSizeInBits(ICI->getOperand(0)->getType()))
6111     return false;
6112
6113   // Now that we found a conditional branch that dominates the loop, check to
6114   // see if it is the comparison we are looking for.
6115   ICmpInst::Predicate FoundPred;
6116   if (Inverse)
6117     FoundPred = ICI->getInversePredicate();
6118   else
6119     FoundPred = ICI->getPredicate();
6120
6121   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
6122   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
6123
6124   // Balance the types. The case where FoundLHS' type is wider than
6125   // LHS' type is checked for above.
6126   if (getTypeSizeInBits(LHS->getType()) >
6127       getTypeSizeInBits(FoundLHS->getType())) {
6128     if (CmpInst::isSigned(Pred)) {
6129       FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
6130       FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
6131     } else {
6132       FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
6133       FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
6134     }
6135   }
6136
6137   // Canonicalize the query to match the way instcombine will have
6138   // canonicalized the comparison.
6139   if (SimplifyICmpOperands(Pred, LHS, RHS))
6140     if (LHS == RHS)
6141       return CmpInst::isTrueWhenEqual(Pred);
6142   if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
6143     if (FoundLHS == FoundRHS)
6144       return CmpInst::isFalseWhenEqual(Pred);
6145
6146   // Check to see if we can make the LHS or RHS match.
6147   if (LHS == FoundRHS || RHS == FoundLHS) {
6148     if (isa<SCEVConstant>(RHS)) {
6149       std::swap(FoundLHS, FoundRHS);
6150       FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
6151     } else {
6152       std::swap(LHS, RHS);
6153       Pred = ICmpInst::getSwappedPredicate(Pred);
6154     }
6155   }
6156
6157   // Check whether the found predicate is the same as the desired predicate.
6158   if (FoundPred == Pred)
6159     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
6160
6161   // Check whether swapping the found predicate makes it the same as the
6162   // desired predicate.
6163   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
6164     if (isa<SCEVConstant>(RHS))
6165       return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
6166     else
6167       return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
6168                                    RHS, LHS, FoundLHS, FoundRHS);
6169   }
6170
6171   // Check whether the actual condition is beyond sufficient.
6172   if (FoundPred == ICmpInst::ICMP_EQ)
6173     if (ICmpInst::isTrueWhenEqual(Pred))
6174       if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
6175         return true;
6176   if (Pred == ICmpInst::ICMP_NE)
6177     if (!ICmpInst::isTrueWhenEqual(FoundPred))
6178       if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
6179         return true;
6180
6181   // Otherwise assume the worst.
6182   return false;
6183 }
6184
6185 /// isImpliedCondOperands - Test whether the condition described by Pred,
6186 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
6187 /// and FoundRHS is true.
6188 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
6189                                             const SCEV *LHS, const SCEV *RHS,
6190                                             const SCEV *FoundLHS,
6191                                             const SCEV *FoundRHS) {
6192   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
6193                                      FoundLHS, FoundRHS) ||
6194          // ~x < ~y --> x > y
6195          isImpliedCondOperandsHelper(Pred, LHS, RHS,
6196                                      getNotSCEV(FoundRHS),
6197                                      getNotSCEV(FoundLHS));
6198 }
6199
6200 /// isImpliedCondOperandsHelper - Test whether the condition described by
6201 /// Pred, LHS, and RHS is true whenever the condition described by Pred,
6202 /// FoundLHS, and FoundRHS is true.
6203 bool
6204 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
6205                                              const SCEV *LHS, const SCEV *RHS,
6206                                              const SCEV *FoundLHS,
6207                                              const SCEV *FoundRHS) {
6208   switch (Pred) {
6209   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
6210   case ICmpInst::ICMP_EQ:
6211   case ICmpInst::ICMP_NE:
6212     if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
6213       return true;
6214     break;
6215   case ICmpInst::ICMP_SLT:
6216   case ICmpInst::ICMP_SLE:
6217     if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
6218         isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
6219       return true;
6220     break;
6221   case ICmpInst::ICMP_SGT:
6222   case ICmpInst::ICMP_SGE:
6223     if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
6224         isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
6225       return true;
6226     break;
6227   case ICmpInst::ICMP_ULT:
6228   case ICmpInst::ICMP_ULE:
6229     if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
6230         isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
6231       return true;
6232     break;
6233   case ICmpInst::ICMP_UGT:
6234   case ICmpInst::ICMP_UGE:
6235     if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
6236         isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
6237       return true;
6238     break;
6239   }
6240
6241   return false;
6242 }
6243
6244 /// getBECount - Subtract the end and start values and divide by the step,
6245 /// rounding up, to get the number of times the backedge is executed. Return
6246 /// CouldNotCompute if an intermediate computation overflows.
6247 const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
6248                                         const SCEV *End,
6249                                         const SCEV *Step,
6250                                         bool NoWrap) {
6251   assert(!isKnownNegative(Step) &&
6252          "This code doesn't handle negative strides yet!");
6253
6254   Type *Ty = Start->getType();
6255
6256   // When Start == End, we have an exact BECount == 0. Short-circuit this case
6257   // here because SCEV may not be able to determine that the unsigned division
6258   // after rounding is zero.
6259   if (Start == End)
6260     return getConstant(Ty, 0);
6261
6262   const SCEV *NegOne = getConstant(Ty, (uint64_t)-1);
6263   const SCEV *Diff = getMinusSCEV(End, Start);
6264   const SCEV *RoundUp = getAddExpr(Step, NegOne);
6265
6266   // Add an adjustment to the difference between End and Start so that
6267   // the division will effectively round up.
6268   const SCEV *Add = getAddExpr(Diff, RoundUp);
6269
6270   if (!NoWrap) {
6271     // Check Add for unsigned overflow.
6272     // TODO: More sophisticated things could be done here.
6273     Type *WideTy = IntegerType::get(getContext(),
6274                                           getTypeSizeInBits(Ty) + 1);
6275     const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
6276     const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
6277     const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
6278     if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
6279       return getCouldNotCompute();
6280   }
6281
6282   return getUDivExpr(Add, Step);
6283 }
6284
6285 /// HowManyLessThans - Return the number of times a backedge containing the
6286 /// specified less-than comparison will execute.  If not computable, return
6287 /// CouldNotCompute.
6288 ScalarEvolution::ExitLimit
6289 ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
6290                                   const Loop *L, bool isSigned) {
6291   // Only handle:  "ADDREC < LoopInvariant".
6292   if (!isLoopInvariant(RHS, L)) return getCouldNotCompute();
6293
6294   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
6295   if (!AddRec || AddRec->getLoop() != L)
6296     return getCouldNotCompute();
6297
6298   // Check to see if we have a flag which makes analysis easy.
6299   bool NoWrap = isSigned ?
6300     AddRec->getNoWrapFlags((SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNW)) :
6301     AddRec->getNoWrapFlags((SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNW));
6302
6303   if (AddRec->isAffine()) {
6304     unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
6305     const SCEV *Step = AddRec->getStepRecurrence(*this);
6306
6307     if (Step->isZero())
6308       return getCouldNotCompute();
6309     if (Step->isOne()) {
6310       // With unit stride, the iteration never steps past the limit value.
6311     } else if (isKnownPositive(Step)) {
6312       // Test whether a positive iteration can step past the limit
6313       // value and past the maximum value for its type in a single step.
6314       // Note that it's not sufficient to check NoWrap here, because even
6315       // though the value after a wrap is undefined, it's not undefined
6316       // behavior, so if wrap does occur, the loop could either terminate or
6317       // loop infinitely, but in either case, the loop is guaranteed to
6318       // iterate at least until the iteration where the wrapping occurs.
6319       const SCEV *One = getConstant(Step->getType(), 1);
6320       if (isSigned) {
6321         APInt Max = APInt::getSignedMaxValue(BitWidth);
6322         if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax())
6323               .slt(getSignedRange(RHS).getSignedMax()))
6324           return getCouldNotCompute();
6325       } else {
6326         APInt Max = APInt::getMaxValue(BitWidth);
6327         if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax())
6328               .ult(getUnsignedRange(RHS).getUnsignedMax()))
6329           return getCouldNotCompute();
6330       }
6331     } else
6332       // TODO: Handle negative strides here and below.
6333       return getCouldNotCompute();
6334
6335     // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
6336     // m.  So, we count the number of iterations in which {n,+,s} < m is true.
6337     // Note that we cannot simply return max(m-n,0)/s because it's not safe to
6338     // treat m-n as signed nor unsigned due to overflow possibility.
6339
6340     // First, we get the value of the LHS in the first iteration: n
6341     const SCEV *Start = AddRec->getOperand(0);
6342
6343     // Determine the minimum constant start value.
6344     const SCEV *MinStart = getConstant(isSigned ?
6345       getSignedRange(Start).getSignedMin() :
6346       getUnsignedRange(Start).getUnsignedMin());
6347
6348     // If we know that the condition is true in order to enter the loop,
6349     // then we know that it will run exactly (m-n)/s times. Otherwise, we
6350     // only know that it will execute (max(m,n)-n)/s times. In both cases,
6351     // the division must round up.
6352     const SCEV *End = RHS;
6353     if (!isLoopEntryGuardedByCond(L,
6354                                   isSigned ? ICmpInst::ICMP_SLT :
6355                                              ICmpInst::ICMP_ULT,
6356                                   getMinusSCEV(Start, Step), RHS))
6357       End = isSigned ? getSMaxExpr(RHS, Start)
6358                      : getUMaxExpr(RHS, Start);
6359
6360     // Determine the maximum constant end value.
6361     const SCEV *MaxEnd = getConstant(isSigned ?
6362       getSignedRange(End).getSignedMax() :
6363       getUnsignedRange(End).getUnsignedMax());
6364
6365     // If MaxEnd is within a step of the maximum integer value in its type,
6366     // adjust it down to the minimum value which would produce the same effect.
6367     // This allows the subsequent ceiling division of (N+(step-1))/step to
6368     // compute the correct value.
6369     const SCEV *StepMinusOne = getMinusSCEV(Step,
6370                                             getConstant(Step->getType(), 1));
6371     MaxEnd = isSigned ?
6372       getSMinExpr(MaxEnd,
6373                   getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)),
6374                                StepMinusOne)) :
6375       getUMinExpr(MaxEnd,
6376                   getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)),
6377                                StepMinusOne));
6378
6379     // Finally, we subtract these two values and divide, rounding up, to get
6380     // the number of times the backedge is executed.
6381     const SCEV *BECount = getBECount(Start, End, Step, NoWrap);
6382
6383     // The maximum backedge count is similar, except using the minimum start
6384     // value and the maximum end value.
6385     // If we already have an exact constant BECount, use it instead.
6386     const SCEV *MaxBECount = isa<SCEVConstant>(BECount) ? BECount
6387       : getBECount(MinStart, MaxEnd, Step, NoWrap);
6388
6389     // If the stride is nonconstant, and NoWrap == true, then
6390     // getBECount(MinStart, MaxEnd) may not compute. This would result in an
6391     // exact BECount and invalid MaxBECount, which should be avoided to catch
6392     // more optimization opportunities.
6393     if (isa<SCEVCouldNotCompute>(MaxBECount))
6394       MaxBECount = BECount;
6395
6396     return ExitLimit(BECount, MaxBECount);
6397   }
6398
6399   return getCouldNotCompute();
6400 }
6401
6402 /// getNumIterationsInRange - Return the number of iterations of this loop that
6403 /// produce values in the specified constant range.  Another way of looking at
6404 /// this is that it returns the first iteration number where the value is not in
6405 /// the condition, thus computing the exit count. If the iteration count can't
6406 /// be computed, an instance of SCEVCouldNotCompute is returned.
6407 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
6408                                                     ScalarEvolution &SE) const {
6409   if (Range.isFullSet())  // Infinite loop.
6410     return SE.getCouldNotCompute();
6411
6412   // If the start is a non-zero constant, shift the range to simplify things.
6413   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
6414     if (!SC->getValue()->isZero()) {
6415       SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
6416       Operands[0] = SE.getConstant(SC->getType(), 0);
6417       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
6418                                              getNoWrapFlags(FlagNW));
6419       if (const SCEVAddRecExpr *ShiftedAddRec =
6420             dyn_cast<SCEVAddRecExpr>(Shifted))
6421         return ShiftedAddRec->getNumIterationsInRange(
6422                            Range.subtract(SC->getValue()->getValue()), SE);
6423       // This is strange and shouldn't happen.
6424       return SE.getCouldNotCompute();
6425     }
6426
6427   // The only time we can solve this is when we have all constant indices.
6428   // Otherwise, we cannot determine the overflow conditions.
6429   for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
6430     if (!isa<SCEVConstant>(getOperand(i)))
6431       return SE.getCouldNotCompute();
6432
6433
6434   // Okay at this point we know that all elements of the chrec are constants and
6435   // that the start element is zero.
6436
6437   // First check to see if the range contains zero.  If not, the first
6438   // iteration exits.
6439   unsigned BitWidth = SE.getTypeSizeInBits(getType());
6440   if (!Range.contains(APInt(BitWidth, 0)))
6441     return SE.getConstant(getType(), 0);
6442
6443   if (isAffine()) {
6444     // If this is an affine expression then we have this situation:
6445     //   Solve {0,+,A} in Range  ===  Ax in Range
6446
6447     // We know that zero is in the range.  If A is positive then we know that
6448     // the upper value of the range must be the first possible exit value.
6449     // If A is negative then the lower of the range is the last possible loop
6450     // value.  Also note that we already checked for a full range.
6451     APInt One(BitWidth,1);
6452     APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
6453     APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
6454
6455     // The exit value should be (End+A)/A.
6456     APInt ExitVal = (End + A).udiv(A);
6457     ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
6458
6459     // Evaluate at the exit value.  If we really did fall out of the valid
6460     // range, then we computed our trip count, otherwise wrap around or other
6461     // things must have happened.
6462     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
6463     if (Range.contains(Val->getValue()))
6464       return SE.getCouldNotCompute();  // Something strange happened
6465
6466     // Ensure that the previous value is in the range.  This is a sanity check.
6467     assert(Range.contains(
6468            EvaluateConstantChrecAtConstant(this,
6469            ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
6470            "Linear scev computation is off in a bad way!");
6471     return SE.getConstant(ExitValue);
6472   } else if (isQuadratic()) {
6473     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
6474     // quadratic equation to solve it.  To do this, we must frame our problem in
6475     // terms of figuring out when zero is crossed, instead of when
6476     // Range.getUpper() is crossed.
6477     SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
6478     NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
6479     const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
6480                                              // getNoWrapFlags(FlagNW)
6481                                              FlagAnyWrap);
6482
6483     // Next, solve the constructed addrec
6484     std::pair<const SCEV *,const SCEV *> Roots =
6485       SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
6486     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
6487     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
6488     if (R1) {
6489       // Pick the smallest positive root value.
6490       if (ConstantInt *CB =
6491           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
6492                          R1->getValue(), R2->getValue()))) {
6493         if (CB->getZExtValue() == false)
6494           std::swap(R1, R2);   // R1 is the minimum root now.
6495
6496         // Make sure the root is not off by one.  The returned iteration should
6497         // not be in the range, but the previous one should be.  When solving
6498         // for "X*X < 5", for example, we should not return a root of 2.
6499         ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
6500                                                              R1->getValue(),
6501                                                              SE);
6502         if (Range.contains(R1Val->getValue())) {
6503           // The next iteration must be out of the range...
6504           ConstantInt *NextVal =
6505                 ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
6506
6507           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6508           if (!Range.contains(R1Val->getValue()))
6509             return SE.getConstant(NextVal);
6510           return SE.getCouldNotCompute();  // Something strange happened
6511         }
6512
6513         // If R1 was not in the range, then it is a good return value.  Make
6514         // sure that R1-1 WAS in the range though, just in case.
6515         ConstantInt *NextVal =
6516                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
6517         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6518         if (Range.contains(R1Val->getValue()))
6519           return R1;
6520         return SE.getCouldNotCompute();  // Something strange happened
6521       }
6522     }
6523   }
6524
6525   return SE.getCouldNotCompute();
6526 }
6527
6528
6529
6530 //===----------------------------------------------------------------------===//
6531 //                   SCEVCallbackVH Class Implementation
6532 //===----------------------------------------------------------------------===//
6533
6534 void ScalarEvolution::SCEVCallbackVH::deleted() {
6535   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6536   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
6537     SE->ConstantEvolutionLoopExitValue.erase(PN);
6538   SE->ValueExprMap.erase(getValPtr());
6539   // this now dangles!
6540 }
6541
6542 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
6543   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6544
6545   // Forget all the expressions associated with users of the old value,
6546   // so that future queries will recompute the expressions using the new
6547   // value.
6548   Value *Old = getValPtr();
6549   SmallVector<User *, 16> Worklist;
6550   SmallPtrSet<User *, 8> Visited;
6551   for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
6552        UI != UE; ++UI)
6553     Worklist.push_back(*UI);
6554   while (!Worklist.empty()) {
6555     User *U = Worklist.pop_back_val();
6556     // Deleting the Old value will cause this to dangle. Postpone
6557     // that until everything else is done.
6558     if (U == Old)
6559       continue;
6560     if (!Visited.insert(U))
6561       continue;
6562     if (PHINode *PN = dyn_cast<PHINode>(U))
6563       SE->ConstantEvolutionLoopExitValue.erase(PN);
6564     SE->ValueExprMap.erase(U);
6565     for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
6566          UI != UE; ++UI)
6567       Worklist.push_back(*UI);
6568   }
6569   // Delete the Old value.
6570   if (PHINode *PN = dyn_cast<PHINode>(Old))
6571     SE->ConstantEvolutionLoopExitValue.erase(PN);
6572   SE->ValueExprMap.erase(Old);
6573   // this now dangles!
6574 }
6575
6576 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
6577   : CallbackVH(V), SE(se) {}
6578
6579 //===----------------------------------------------------------------------===//
6580 //                   ScalarEvolution Class Implementation
6581 //===----------------------------------------------------------------------===//
6582
6583 ScalarEvolution::ScalarEvolution()
6584   : FunctionPass(ID), FirstUnknown(0) {
6585   initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
6586 }
6587
6588 bool ScalarEvolution::runOnFunction(Function &F) {
6589   this->F = &F;
6590   LI = &getAnalysis<LoopInfo>();
6591   TD = getAnalysisIfAvailable<TargetData>();
6592   TLI = &getAnalysis<TargetLibraryInfo>();
6593   DT = &getAnalysis<DominatorTree>();
6594   return false;
6595 }
6596
6597 void ScalarEvolution::releaseMemory() {
6598   // Iterate through all the SCEVUnknown instances and call their
6599   // destructors, so that they release their references to their values.
6600   for (SCEVUnknown *U = FirstUnknown; U; U = U->Next)
6601     U->~SCEVUnknown();
6602   FirstUnknown = 0;
6603
6604   ValueExprMap.clear();
6605
6606   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
6607   // that a loop had multiple computable exits.
6608   for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
6609          BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end();
6610        I != E; ++I) {
6611     I->second.clear();
6612   }
6613
6614   BackedgeTakenCounts.clear();
6615   ConstantEvolutionLoopExitValue.clear();
6616   ValuesAtScopes.clear();
6617   LoopDispositions.clear();
6618   BlockDispositions.clear();
6619   UnsignedRanges.clear();
6620   SignedRanges.clear();
6621   UniqueSCEVs.clear();
6622   SCEVAllocator.Reset();
6623 }
6624
6625 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
6626   AU.setPreservesAll();
6627   AU.addRequiredTransitive<LoopInfo>();
6628   AU.addRequiredTransitive<DominatorTree>();
6629   AU.addRequired<TargetLibraryInfo>();
6630 }
6631
6632 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
6633   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
6634 }
6635
6636 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
6637                           const Loop *L) {
6638   // Print all inner loops first
6639   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
6640     PrintLoopInfo(OS, SE, *I);
6641
6642   OS << "Loop ";
6643   WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6644   OS << ": ";
6645
6646   SmallVector<BasicBlock *, 8> ExitBlocks;
6647   L->getExitBlocks(ExitBlocks);
6648   if (ExitBlocks.size() != 1)
6649     OS << "<multiple exits> ";
6650
6651   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
6652     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
6653   } else {
6654     OS << "Unpredictable backedge-taken count. ";
6655   }
6656
6657   OS << "\n"
6658         "Loop ";
6659   WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6660   OS << ": ";
6661
6662   if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
6663     OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
6664   } else {
6665     OS << "Unpredictable max backedge-taken count. ";
6666   }
6667
6668   OS << "\n";
6669 }
6670
6671 void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
6672   // ScalarEvolution's implementation of the print method is to print
6673   // out SCEV values of all instructions that are interesting. Doing
6674   // this potentially causes it to create new SCEV objects though,
6675   // which technically conflicts with the const qualifier. This isn't
6676   // observable from outside the class though, so casting away the
6677   // const isn't dangerous.
6678   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
6679
6680   OS << "Classifying expressions for: ";
6681   WriteAsOperand(OS, F, /*PrintType=*/false);
6682   OS << "\n";
6683   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
6684     if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
6685       OS << *I << '\n';
6686       OS << "  -->  ";
6687       const SCEV *SV = SE.getSCEV(&*I);
6688       SV->print(OS);
6689
6690       const Loop *L = LI->getLoopFor((*I).getParent());
6691
6692       const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
6693       if (AtUse != SV) {
6694         OS << "  -->  ";
6695         AtUse->print(OS);
6696       }
6697
6698       if (L) {
6699         OS << "\t\t" "Exits: ";
6700         const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
6701         if (!SE.isLoopInvariant(ExitValue, L)) {
6702           OS << "<<Unknown>>";
6703         } else {
6704           OS << *ExitValue;
6705         }
6706       }
6707
6708       OS << "\n";
6709     }
6710
6711   OS << "Determining loop execution counts for: ";
6712   WriteAsOperand(OS, F, /*PrintType=*/false);
6713   OS << "\n";
6714   for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
6715     PrintLoopInfo(OS, &SE, *I);
6716 }
6717
6718 ScalarEvolution::LoopDisposition
6719 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
6720   std::map<const Loop *, LoopDisposition> &Values = LoopDispositions[S];
6721   std::pair<std::map<const Loop *, LoopDisposition>::iterator, bool> Pair =
6722     Values.insert(std::make_pair(L, LoopVariant));
6723   if (!Pair.second)
6724     return Pair.first->second;
6725
6726   LoopDisposition D = computeLoopDisposition(S, L);
6727   return LoopDispositions[S][L] = D;
6728 }
6729
6730 ScalarEvolution::LoopDisposition
6731 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
6732   switch (S->getSCEVType()) {
6733   case scConstant:
6734     return LoopInvariant;
6735   case scTruncate:
6736   case scZeroExtend:
6737   case scSignExtend:
6738     return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
6739   case scAddRecExpr: {
6740     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6741
6742     // If L is the addrec's loop, it's computable.
6743     if (AR->getLoop() == L)
6744       return LoopComputable;
6745
6746     // Add recurrences are never invariant in the function-body (null loop).
6747     if (!L)
6748       return LoopVariant;
6749
6750     // This recurrence is variant w.r.t. L if L contains AR's loop.
6751     if (L->contains(AR->getLoop()))
6752       return LoopVariant;
6753
6754     // This recurrence is invariant w.r.t. L if AR's loop contains L.
6755     if (AR->getLoop()->contains(L))
6756       return LoopInvariant;
6757
6758     // This recurrence is variant w.r.t. L if any of its operands
6759     // are variant.
6760     for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
6761          I != E; ++I)
6762       if (!isLoopInvariant(*I, L))
6763         return LoopVariant;
6764
6765     // Otherwise it's loop-invariant.
6766     return LoopInvariant;
6767   }
6768   case scAddExpr:
6769   case scMulExpr:
6770   case scUMaxExpr:
6771   case scSMaxExpr: {
6772     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6773     bool HasVarying = false;
6774     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6775          I != E; ++I) {
6776       LoopDisposition D = getLoopDisposition(*I, L);
6777       if (D == LoopVariant)
6778         return LoopVariant;
6779       if (D == LoopComputable)
6780         HasVarying = true;
6781     }
6782     return HasVarying ? LoopComputable : LoopInvariant;
6783   }
6784   case scUDivExpr: {
6785     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6786     LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
6787     if (LD == LoopVariant)
6788       return LoopVariant;
6789     LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
6790     if (RD == LoopVariant)
6791       return LoopVariant;
6792     return (LD == LoopInvariant && RD == LoopInvariant) ?
6793            LoopInvariant : LoopComputable;
6794   }
6795   case scUnknown:
6796     // All non-instruction values are loop invariant.  All instructions are loop
6797     // invariant if they are not contained in the specified loop.
6798     // Instructions are never considered invariant in the function body
6799     // (null loop) because they are defined within the "loop".
6800     if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
6801       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
6802     return LoopInvariant;
6803   case scCouldNotCompute:
6804     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6805     return LoopVariant;
6806   default: break;
6807   }
6808   llvm_unreachable("Unknown SCEV kind!");
6809   return LoopVariant;
6810 }
6811
6812 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
6813   return getLoopDisposition(S, L) == LoopInvariant;
6814 }
6815
6816 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
6817   return getLoopDisposition(S, L) == LoopComputable;
6818 }
6819
6820 ScalarEvolution::BlockDisposition
6821 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6822   std::map<const BasicBlock *, BlockDisposition> &Values = BlockDispositions[S];
6823   std::pair<std::map<const BasicBlock *, BlockDisposition>::iterator, bool>
6824     Pair = Values.insert(std::make_pair(BB, DoesNotDominateBlock));
6825   if (!Pair.second)
6826     return Pair.first->second;
6827
6828   BlockDisposition D = computeBlockDisposition(S, BB);
6829   return BlockDispositions[S][BB] = D;
6830 }
6831
6832 ScalarEvolution::BlockDisposition
6833 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6834   switch (S->getSCEVType()) {
6835   case scConstant:
6836     return ProperlyDominatesBlock;
6837   case scTruncate:
6838   case scZeroExtend:
6839   case scSignExtend:
6840     return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
6841   case scAddRecExpr: {
6842     // This uses a "dominates" query instead of "properly dominates" query
6843     // to test for proper dominance too, because the instruction which
6844     // produces the addrec's value is a PHI, and a PHI effectively properly
6845     // dominates its entire containing block.
6846     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6847     if (!DT->dominates(AR->getLoop()->getHeader(), BB))
6848       return DoesNotDominateBlock;
6849   }
6850   // FALL THROUGH into SCEVNAryExpr handling.
6851   case scAddExpr:
6852   case scMulExpr:
6853   case scUMaxExpr:
6854   case scSMaxExpr: {
6855     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6856     bool Proper = true;
6857     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6858          I != E; ++I) {
6859       BlockDisposition D = getBlockDisposition(*I, BB);
6860       if (D == DoesNotDominateBlock)
6861         return DoesNotDominateBlock;
6862       if (D == DominatesBlock)
6863         Proper = false;
6864     }
6865     return Proper ? ProperlyDominatesBlock : DominatesBlock;
6866   }
6867   case scUDivExpr: {
6868     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6869     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6870     BlockDisposition LD = getBlockDisposition(LHS, BB);
6871     if (LD == DoesNotDominateBlock)
6872       return DoesNotDominateBlock;
6873     BlockDisposition RD = getBlockDisposition(RHS, BB);
6874     if (RD == DoesNotDominateBlock)
6875       return DoesNotDominateBlock;
6876     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
6877       ProperlyDominatesBlock : DominatesBlock;
6878   }
6879   case scUnknown:
6880     if (Instruction *I =
6881           dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
6882       if (I->getParent() == BB)
6883         return DominatesBlock;
6884       if (DT->properlyDominates(I->getParent(), BB))
6885         return ProperlyDominatesBlock;
6886       return DoesNotDominateBlock;
6887     }
6888     return ProperlyDominatesBlock;
6889   case scCouldNotCompute:
6890     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6891     return DoesNotDominateBlock;
6892   default: break;
6893   }
6894   llvm_unreachable("Unknown SCEV kind!");
6895   return DoesNotDominateBlock;
6896 }
6897
6898 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
6899   return getBlockDisposition(S, BB) >= DominatesBlock;
6900 }
6901
6902 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
6903   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
6904 }
6905
6906 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
6907   switch (S->getSCEVType()) {
6908   case scConstant:
6909     return false;
6910   case scTruncate:
6911   case scZeroExtend:
6912   case scSignExtend: {
6913     const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6914     const SCEV *CastOp = Cast->getOperand();
6915     return Op == CastOp || hasOperand(CastOp, Op);
6916   }
6917   case scAddRecExpr:
6918   case scAddExpr:
6919   case scMulExpr:
6920   case scUMaxExpr:
6921   case scSMaxExpr: {
6922     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6923     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6924          I != E; ++I) {
6925       const SCEV *NAryOp = *I;
6926       if (NAryOp == Op || hasOperand(NAryOp, Op))
6927         return true;
6928     }
6929     return false;
6930   }
6931   case scUDivExpr: {
6932     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6933     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6934     return LHS == Op || hasOperand(LHS, Op) ||
6935            RHS == Op || hasOperand(RHS, Op);
6936   }
6937   case scUnknown:
6938     return false;
6939   case scCouldNotCompute:
6940     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6941     return false;
6942   default: break;
6943   }
6944   llvm_unreachable("Unknown SCEV kind!");
6945   return false;
6946 }
6947
6948 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
6949   ValuesAtScopes.erase(S);
6950   LoopDispositions.erase(S);
6951   BlockDispositions.erase(S);
6952   UnsignedRanges.erase(S);
6953   SignedRanges.erase(S);
6954 }