Expose isNonConstantNegative to users of ScalarEvolution.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the scalar evolution analysis
11 // engine, which is used primarily to analyze expressions involving induction
12 // variables in loops.
13 //
14 // There are several aspects to this library.  First is the representation of
15 // scalar expressions, which are represented as subclasses of the SCEV class.
16 // These classes are used to represent certain types of subexpressions that we
17 // can handle. We only create one SCEV of a particular shape, so
18 // pointer-comparisons for equality are legal.
19 //
20 // One important aspect of the SCEV objects is that they are never cyclic, even
21 // if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
22 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
23 // recurrence) then we represent it directly as a recurrence node, otherwise we
24 // represent it as a SCEVUnknown node.
25 //
26 // In addition to being able to represent expressions of various types, we also
27 // have folders that are used to build the *canonical* representation for a
28 // particular expression.  These folders are capable of using a variety of
29 // rewrite rules to simplify the expressions.
30 //
31 // Once the folders are defined, we can implement the more interesting
32 // higher-level code, such as the code that recognizes PHI nodes of various
33 // types, computes the execution count of a loop, etc.
34 //
35 // TODO: We should use these routines and value representations to implement
36 // dependence analysis!
37 //
38 //===----------------------------------------------------------------------===//
39 //
40 // There are several good references for the techniques used in this analysis.
41 //
42 //  Chains of recurrences -- a method to expedite the evaluation
43 //  of closed-form functions
44 //  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
45 //
46 //  On computational properties of chains of recurrences
47 //  Eugene V. Zima
48 //
49 //  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
50 //  Robert A. van Engelen
51 //
52 //  Efficient Symbolic Analysis for Optimizing Compilers
53 //  Robert A. van Engelen
54 //
55 //  Using the chains of recurrences algebra for data dependence testing and
56 //  induction variable substitution
57 //  MS Thesis, Johnie Birch
58 //
59 //===----------------------------------------------------------------------===//
60
61 #define DEBUG_TYPE "scalar-evolution"
62 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
63 #include "llvm/Constants.h"
64 #include "llvm/DerivedTypes.h"
65 #include "llvm/GlobalVariable.h"
66 #include "llvm/GlobalAlias.h"
67 #include "llvm/Instructions.h"
68 #include "llvm/LLVMContext.h"
69 #include "llvm/Operator.h"
70 #include "llvm/Analysis/ConstantFolding.h"
71 #include "llvm/Analysis/Dominators.h"
72 #include "llvm/Analysis/InstructionSimplify.h"
73 #include "llvm/Analysis/LoopInfo.h"
74 #include "llvm/Analysis/ValueTracking.h"
75 #include "llvm/Assembly/Writer.h"
76 #include "llvm/Target/TargetData.h"
77 #include "llvm/Target/TargetLibraryInfo.h"
78 #include "llvm/Support/CommandLine.h"
79 #include "llvm/Support/ConstantRange.h"
80 #include "llvm/Support/Debug.h"
81 #include "llvm/Support/ErrorHandling.h"
82 #include "llvm/Support/GetElementPtrTypeIterator.h"
83 #include "llvm/Support/InstIterator.h"
84 #include "llvm/Support/MathExtras.h"
85 #include "llvm/Support/raw_ostream.h"
86 #include "llvm/ADT/Statistic.h"
87 #include "llvm/ADT/STLExtras.h"
88 #include "llvm/ADT/SmallPtrSet.h"
89 #include <algorithm>
90 using namespace llvm;
91
92 STATISTIC(NumArrayLenItCounts,
93           "Number of trip counts computed with array length");
94 STATISTIC(NumTripCountsComputed,
95           "Number of loops with predictable loop counts");
96 STATISTIC(NumTripCountsNotComputed,
97           "Number of loops without predictable loop counts");
98 STATISTIC(NumBruteForceTripCountsComputed,
99           "Number of loops with trip counts computed by force");
100
101 static cl::opt<unsigned>
102 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
103                         cl::desc("Maximum number of iterations SCEV will "
104                                  "symbolically execute a constant "
105                                  "derived loop"),
106                         cl::init(100));
107
108 INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
109                 "Scalar Evolution Analysis", false, true)
110 INITIALIZE_PASS_DEPENDENCY(LoopInfo)
111 INITIALIZE_PASS_DEPENDENCY(DominatorTree)
112 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
113 INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
114                 "Scalar Evolution Analysis", false, true)
115 char ScalarEvolution::ID = 0;
116
117 //===----------------------------------------------------------------------===//
118 //                           SCEV class definitions
119 //===----------------------------------------------------------------------===//
120
121 //===----------------------------------------------------------------------===//
122 // Implementation of the SCEV class.
123 //
124
125 void SCEV::dump() const {
126   print(dbgs());
127   dbgs() << '\n';
128 }
129
130 void SCEV::print(raw_ostream &OS) const {
131   switch (getSCEVType()) {
132   case scConstant:
133     WriteAsOperand(OS, cast<SCEVConstant>(this)->getValue(), false);
134     return;
135   case scTruncate: {
136     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
137     const SCEV *Op = Trunc->getOperand();
138     OS << "(trunc " << *Op->getType() << " " << *Op << " to "
139        << *Trunc->getType() << ")";
140     return;
141   }
142   case scZeroExtend: {
143     const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
144     const SCEV *Op = ZExt->getOperand();
145     OS << "(zext " << *Op->getType() << " " << *Op << " to "
146        << *ZExt->getType() << ")";
147     return;
148   }
149   case scSignExtend: {
150     const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
151     const SCEV *Op = SExt->getOperand();
152     OS << "(sext " << *Op->getType() << " " << *Op << " to "
153        << *SExt->getType() << ")";
154     return;
155   }
156   case scAddRecExpr: {
157     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
158     OS << "{" << *AR->getOperand(0);
159     for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
160       OS << ",+," << *AR->getOperand(i);
161     OS << "}<";
162     if (AR->getNoWrapFlags(FlagNUW))
163       OS << "nuw><";
164     if (AR->getNoWrapFlags(FlagNSW))
165       OS << "nsw><";
166     if (AR->getNoWrapFlags(FlagNW) &&
167         !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
168       OS << "nw><";
169     WriteAsOperand(OS, AR->getLoop()->getHeader(), /*PrintType=*/false);
170     OS << ">";
171     return;
172   }
173   case scAddExpr:
174   case scMulExpr:
175   case scUMaxExpr:
176   case scSMaxExpr: {
177     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
178     const char *OpStr = 0;
179     switch (NAry->getSCEVType()) {
180     case scAddExpr: OpStr = " + "; break;
181     case scMulExpr: OpStr = " * "; break;
182     case scUMaxExpr: OpStr = " umax "; break;
183     case scSMaxExpr: OpStr = " smax "; break;
184     }
185     OS << "(";
186     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
187          I != E; ++I) {
188       OS << **I;
189       if (llvm::next(I) != E)
190         OS << OpStr;
191     }
192     OS << ")";
193     switch (NAry->getSCEVType()) {
194     case scAddExpr:
195     case scMulExpr:
196       if (NAry->getNoWrapFlags(FlagNUW))
197         OS << "<nuw>";
198       if (NAry->getNoWrapFlags(FlagNSW))
199         OS << "<nsw>";
200     }
201     return;
202   }
203   case scUDivExpr: {
204     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
205     OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
206     return;
207   }
208   case scUnknown: {
209     const SCEVUnknown *U = cast<SCEVUnknown>(this);
210     Type *AllocTy;
211     if (U->isSizeOf(AllocTy)) {
212       OS << "sizeof(" << *AllocTy << ")";
213       return;
214     }
215     if (U->isAlignOf(AllocTy)) {
216       OS << "alignof(" << *AllocTy << ")";
217       return;
218     }
219
220     Type *CTy;
221     Constant *FieldNo;
222     if (U->isOffsetOf(CTy, FieldNo)) {
223       OS << "offsetof(" << *CTy << ", ";
224       WriteAsOperand(OS, FieldNo, false);
225       OS << ")";
226       return;
227     }
228
229     // Otherwise just print it normally.
230     WriteAsOperand(OS, U->getValue(), false);
231     return;
232   }
233   case scCouldNotCompute:
234     OS << "***COULDNOTCOMPUTE***";
235     return;
236   default: break;
237   }
238   llvm_unreachable("Unknown SCEV kind!");
239 }
240
241 Type *SCEV::getType() const {
242   switch (getSCEVType()) {
243   case scConstant:
244     return cast<SCEVConstant>(this)->getType();
245   case scTruncate:
246   case scZeroExtend:
247   case scSignExtend:
248     return cast<SCEVCastExpr>(this)->getType();
249   case scAddRecExpr:
250   case scMulExpr:
251   case scUMaxExpr:
252   case scSMaxExpr:
253     return cast<SCEVNAryExpr>(this)->getType();
254   case scAddExpr:
255     return cast<SCEVAddExpr>(this)->getType();
256   case scUDivExpr:
257     return cast<SCEVUDivExpr>(this)->getType();
258   case scUnknown:
259     return cast<SCEVUnknown>(this)->getType();
260   case scCouldNotCompute:
261     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
262     return 0;
263   default: break;
264   }
265   llvm_unreachable("Unknown SCEV kind!");
266   return 0;
267 }
268
269 bool SCEV::isZero() const {
270   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
271     return SC->getValue()->isZero();
272   return false;
273 }
274
275 bool SCEV::isOne() const {
276   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
277     return SC->getValue()->isOne();
278   return false;
279 }
280
281 bool SCEV::isAllOnesValue() const {
282   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
283     return SC->getValue()->isAllOnesValue();
284   return false;
285 }
286
287 /// isNonConstantNegative - Return true if the specified scev is negated, but
288 /// not a constant.
289 bool SCEV::isNonConstantNegative() const {
290   const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
291   if (!Mul) return false;
292
293   // If there is a constant factor, it will be first.
294   const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
295   if (!SC) return false;
296
297   // Return true if the value is negative, this matches things like (-42 * V).
298   return SC->getValue()->getValue().isNegative();
299 }
300
301 SCEVCouldNotCompute::SCEVCouldNotCompute() :
302   SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}
303
304 bool SCEVCouldNotCompute::classof(const SCEV *S) {
305   return S->getSCEVType() == scCouldNotCompute;
306 }
307
308 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
309   FoldingSetNodeID ID;
310   ID.AddInteger(scConstant);
311   ID.AddPointer(V);
312   void *IP = 0;
313   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
314   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
315   UniqueSCEVs.InsertNode(S, IP);
316   return S;
317 }
318
319 const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
320   return getConstant(ConstantInt::get(getContext(), Val));
321 }
322
323 const SCEV *
324 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
325   IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
326   return getConstant(ConstantInt::get(ITy, V, isSigned));
327 }
328
329 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
330                            unsigned SCEVTy, const SCEV *op, Type *ty)
331   : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}
332
333 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
334                                    const SCEV *op, Type *ty)
335   : SCEVCastExpr(ID, scTruncate, op, ty) {
336   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
337          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
338          "Cannot truncate non-integer value!");
339 }
340
341 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
342                                        const SCEV *op, Type *ty)
343   : SCEVCastExpr(ID, scZeroExtend, op, ty) {
344   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
345          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
346          "Cannot zero extend non-integer value!");
347 }
348
349 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
350                                        const SCEV *op, Type *ty)
351   : SCEVCastExpr(ID, scSignExtend, op, ty) {
352   assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
353          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
354          "Cannot sign extend non-integer value!");
355 }
356
357 void SCEVUnknown::deleted() {
358   // Clear this SCEVUnknown from various maps.
359   SE->forgetMemoizedResults(this);
360
361   // Remove this SCEVUnknown from the uniquing map.
362   SE->UniqueSCEVs.RemoveNode(this);
363
364   // Release the value.
365   setValPtr(0);
366 }
367
368 void SCEVUnknown::allUsesReplacedWith(Value *New) {
369   // Clear this SCEVUnknown from various maps.
370   SE->forgetMemoizedResults(this);
371
372   // Remove this SCEVUnknown from the uniquing map.
373   SE->UniqueSCEVs.RemoveNode(this);
374
375   // Update this SCEVUnknown to point to the new value. This is needed
376   // because there may still be outstanding SCEVs which still point to
377   // this SCEVUnknown.
378   setValPtr(New);
379 }
380
381 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
382   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
383     if (VCE->getOpcode() == Instruction::PtrToInt)
384       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
385         if (CE->getOpcode() == Instruction::GetElementPtr &&
386             CE->getOperand(0)->isNullValue() &&
387             CE->getNumOperands() == 2)
388           if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
389             if (CI->isOne()) {
390               AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
391                                  ->getElementType();
392               return true;
393             }
394
395   return false;
396 }
397
398 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
399   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
400     if (VCE->getOpcode() == Instruction::PtrToInt)
401       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
402         if (CE->getOpcode() == Instruction::GetElementPtr &&
403             CE->getOperand(0)->isNullValue()) {
404           Type *Ty =
405             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
406           if (StructType *STy = dyn_cast<StructType>(Ty))
407             if (!STy->isPacked() &&
408                 CE->getNumOperands() == 3 &&
409                 CE->getOperand(1)->isNullValue()) {
410               if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
411                 if (CI->isOne() &&
412                     STy->getNumElements() == 2 &&
413                     STy->getElementType(0)->isIntegerTy(1)) {
414                   AllocTy = STy->getElementType(1);
415                   return true;
416                 }
417             }
418         }
419
420   return false;
421 }
422
423 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
424   if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
425     if (VCE->getOpcode() == Instruction::PtrToInt)
426       if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
427         if (CE->getOpcode() == Instruction::GetElementPtr &&
428             CE->getNumOperands() == 3 &&
429             CE->getOperand(0)->isNullValue() &&
430             CE->getOperand(1)->isNullValue()) {
431           Type *Ty =
432             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
433           // Ignore vector types here so that ScalarEvolutionExpander doesn't
434           // emit getelementptrs that index into vectors.
435           if (Ty->isStructTy() || Ty->isArrayTy()) {
436             CTy = Ty;
437             FieldNo = CE->getOperand(2);
438             return true;
439           }
440         }
441
442   return false;
443 }
444
445 //===----------------------------------------------------------------------===//
446 //                               SCEV Utilities
447 //===----------------------------------------------------------------------===//
448
449 namespace {
450   /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
451   /// than the complexity of the RHS.  This comparator is used to canonicalize
452   /// expressions.
453   class SCEVComplexityCompare {
454     const LoopInfo *const LI;
455   public:
456     explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}
457
458     // Return true or false if LHS is less than, or at least RHS, respectively.
459     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
460       return compare(LHS, RHS) < 0;
461     }
462
463     // Return negative, zero, or positive, if LHS is less than, equal to, or
464     // greater than RHS, respectively. A three-way result allows recursive
465     // comparisons to be more efficient.
466     int compare(const SCEV *LHS, const SCEV *RHS) const {
467       // Fast-path: SCEVs are uniqued so we can do a quick equality check.
468       if (LHS == RHS)
469         return 0;
470
471       // Primarily, sort the SCEVs by their getSCEVType().
472       unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
473       if (LType != RType)
474         return (int)LType - (int)RType;
475
476       // Aside from the getSCEVType() ordering, the particular ordering
477       // isn't very important except that it's beneficial to be consistent,
478       // so that (a + b) and (b + a) don't end up as different expressions.
479       switch (LType) {
480       case scUnknown: {
481         const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
482         const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
483
484         // Sort SCEVUnknown values with some loose heuristics. TODO: This is
485         // not as complete as it could be.
486         const Value *LV = LU->getValue(), *RV = RU->getValue();
487
488         // Order pointer values after integer values. This helps SCEVExpander
489         // form GEPs.
490         bool LIsPointer = LV->getType()->isPointerTy(),
491              RIsPointer = RV->getType()->isPointerTy();
492         if (LIsPointer != RIsPointer)
493           return (int)LIsPointer - (int)RIsPointer;
494
495         // Compare getValueID values.
496         unsigned LID = LV->getValueID(),
497                  RID = RV->getValueID();
498         if (LID != RID)
499           return (int)LID - (int)RID;
500
501         // Sort arguments by their position.
502         if (const Argument *LA = dyn_cast<Argument>(LV)) {
503           const Argument *RA = cast<Argument>(RV);
504           unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
505           return (int)LArgNo - (int)RArgNo;
506         }
507
508         // For instructions, compare their loop depth, and their operand
509         // count.  This is pretty loose.
510         if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
511           const Instruction *RInst = cast<Instruction>(RV);
512
513           // Compare loop depths.
514           const BasicBlock *LParent = LInst->getParent(),
515                            *RParent = RInst->getParent();
516           if (LParent != RParent) {
517             unsigned LDepth = LI->getLoopDepth(LParent),
518                      RDepth = LI->getLoopDepth(RParent);
519             if (LDepth != RDepth)
520               return (int)LDepth - (int)RDepth;
521           }
522
523           // Compare the number of operands.
524           unsigned LNumOps = LInst->getNumOperands(),
525                    RNumOps = RInst->getNumOperands();
526           return (int)LNumOps - (int)RNumOps;
527         }
528
529         return 0;
530       }
531
532       case scConstant: {
533         const SCEVConstant *LC = cast<SCEVConstant>(LHS);
534         const SCEVConstant *RC = cast<SCEVConstant>(RHS);
535
536         // Compare constant values.
537         const APInt &LA = LC->getValue()->getValue();
538         const APInt &RA = RC->getValue()->getValue();
539         unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
540         if (LBitWidth != RBitWidth)
541           return (int)LBitWidth - (int)RBitWidth;
542         return LA.ult(RA) ? -1 : 1;
543       }
544
545       case scAddRecExpr: {
546         const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
547         const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
548
549         // Compare addrec loop depths.
550         const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
551         if (LLoop != RLoop) {
552           unsigned LDepth = LLoop->getLoopDepth(),
553                    RDepth = RLoop->getLoopDepth();
554           if (LDepth != RDepth)
555             return (int)LDepth - (int)RDepth;
556         }
557
558         // Addrec complexity grows with operand count.
559         unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
560         if (LNumOps != RNumOps)
561           return (int)LNumOps - (int)RNumOps;
562
563         // Lexicographically compare.
564         for (unsigned i = 0; i != LNumOps; ++i) {
565           long X = compare(LA->getOperand(i), RA->getOperand(i));
566           if (X != 0)
567             return X;
568         }
569
570         return 0;
571       }
572
573       case scAddExpr:
574       case scMulExpr:
575       case scSMaxExpr:
576       case scUMaxExpr: {
577         const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
578         const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
579
580         // Lexicographically compare n-ary expressions.
581         unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
582         for (unsigned i = 0; i != LNumOps; ++i) {
583           if (i >= RNumOps)
584             return 1;
585           long X = compare(LC->getOperand(i), RC->getOperand(i));
586           if (X != 0)
587             return X;
588         }
589         return (int)LNumOps - (int)RNumOps;
590       }
591
592       case scUDivExpr: {
593         const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
594         const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
595
596         // Lexicographically compare udiv expressions.
597         long X = compare(LC->getLHS(), RC->getLHS());
598         if (X != 0)
599           return X;
600         return compare(LC->getRHS(), RC->getRHS());
601       }
602
603       case scTruncate:
604       case scZeroExtend:
605       case scSignExtend: {
606         const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
607         const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
608
609         // Compare cast expressions by operand.
610         return compare(LC->getOperand(), RC->getOperand());
611       }
612
613       default:
614         break;
615       }
616
617       llvm_unreachable("Unknown SCEV kind!");
618       return 0;
619     }
620   };
621 }
622
623 /// GroupByComplexity - Given a list of SCEV objects, order them by their
624 /// complexity, and group objects of the same complexity together by value.
625 /// When this routine is finished, we know that any duplicates in the vector are
626 /// consecutive and that complexity is monotonically increasing.
627 ///
628 /// Note that we go take special precautions to ensure that we get deterministic
629 /// results from this routine.  In other words, we don't want the results of
630 /// this to depend on where the addresses of various SCEV objects happened to
631 /// land in memory.
632 ///
633 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
634                               LoopInfo *LI) {
635   if (Ops.size() < 2) return;  // Noop
636   if (Ops.size() == 2) {
637     // This is the common case, which also happens to be trivially simple.
638     // Special case it.
639     const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
640     if (SCEVComplexityCompare(LI)(RHS, LHS))
641       std::swap(LHS, RHS);
642     return;
643   }
644
645   // Do the rough sort by complexity.
646   std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
647
648   // Now that we are sorted by complexity, group elements of the same
649   // complexity.  Note that this is, at worst, N^2, but the vector is likely to
650   // be extremely short in practice.  Note that we take this approach because we
651   // do not want to depend on the addresses of the objects we are grouping.
652   for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
653     const SCEV *S = Ops[i];
654     unsigned Complexity = S->getSCEVType();
655
656     // If there are any objects of the same complexity and same value as this
657     // one, group them.
658     for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
659       if (Ops[j] == S) { // Found a duplicate.
660         // Move it to immediately after i'th element.
661         std::swap(Ops[i+1], Ops[j]);
662         ++i;   // no need to rescan it.
663         if (i == e-2) return;  // Done!
664       }
665     }
666   }
667 }
668
669
670
671 //===----------------------------------------------------------------------===//
672 //                      Simple SCEV method implementations
673 //===----------------------------------------------------------------------===//
674
675 /// BinomialCoefficient - Compute BC(It, K).  The result has width W.
676 /// Assume, K > 0.
677 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
678                                        ScalarEvolution &SE,
679                                        Type *ResultTy) {
680   // Handle the simplest case efficiently.
681   if (K == 1)
682     return SE.getTruncateOrZeroExtend(It, ResultTy);
683
684   // We are using the following formula for BC(It, K):
685   //
686   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
687   //
688   // Suppose, W is the bitwidth of the return value.  We must be prepared for
689   // overflow.  Hence, we must assure that the result of our computation is
690   // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
691   // safe in modular arithmetic.
692   //
693   // However, this code doesn't use exactly that formula; the formula it uses
694   // is something like the following, where T is the number of factors of 2 in
695   // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
696   // exponentiation:
697   //
698   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
699   //
700   // This formula is trivially equivalent to the previous formula.  However,
701   // this formula can be implemented much more efficiently.  The trick is that
702   // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
703   // arithmetic.  To do exact division in modular arithmetic, all we have
704   // to do is multiply by the inverse.  Therefore, this step can be done at
705   // width W.
706   //
707   // The next issue is how to safely do the division by 2^T.  The way this
708   // is done is by doing the multiplication step at a width of at least W + T
709   // bits.  This way, the bottom W+T bits of the product are accurate. Then,
710   // when we perform the division by 2^T (which is equivalent to a right shift
711   // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
712   // truncated out after the division by 2^T.
713   //
714   // In comparison to just directly using the first formula, this technique
715   // is much more efficient; using the first formula requires W * K bits,
716   // but this formula less than W + K bits. Also, the first formula requires
717   // a division step, whereas this formula only requires multiplies and shifts.
718   //
719   // It doesn't matter whether the subtraction step is done in the calculation
720   // width or the input iteration count's width; if the subtraction overflows,
721   // the result must be zero anyway.  We prefer here to do it in the width of
722   // the induction variable because it helps a lot for certain cases; CodeGen
723   // isn't smart enough to ignore the overflow, which leads to much less
724   // efficient code if the width of the subtraction is wider than the native
725   // register width.
726   //
727   // (It's possible to not widen at all by pulling out factors of 2 before
728   // the multiplication; for example, K=2 can be calculated as
729   // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
730   // extra arithmetic, so it's not an obvious win, and it gets
731   // much more complicated for K > 3.)
732
733   // Protection from insane SCEVs; this bound is conservative,
734   // but it probably doesn't matter.
735   if (K > 1000)
736     return SE.getCouldNotCompute();
737
738   unsigned W = SE.getTypeSizeInBits(ResultTy);
739
740   // Calculate K! / 2^T and T; we divide out the factors of two before
741   // multiplying for calculating K! / 2^T to avoid overflow.
742   // Other overflow doesn't matter because we only care about the bottom
743   // W bits of the result.
744   APInt OddFactorial(W, 1);
745   unsigned T = 1;
746   for (unsigned i = 3; i <= K; ++i) {
747     APInt Mult(W, i);
748     unsigned TwoFactors = Mult.countTrailingZeros();
749     T += TwoFactors;
750     Mult = Mult.lshr(TwoFactors);
751     OddFactorial *= Mult;
752   }
753
754   // We need at least W + T bits for the multiplication step
755   unsigned CalculationBits = W + T;
756
757   // Calculate 2^T, at width T+W.
758   APInt DivFactor = APInt(CalculationBits, 1).shl(T);
759
760   // Calculate the multiplicative inverse of K! / 2^T;
761   // this multiplication factor will perform the exact division by
762   // K! / 2^T.
763   APInt Mod = APInt::getSignedMinValue(W+1);
764   APInt MultiplyFactor = OddFactorial.zext(W+1);
765   MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
766   MultiplyFactor = MultiplyFactor.trunc(W);
767
768   // Calculate the product, at width T+W
769   IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
770                                                       CalculationBits);
771   const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
772   for (unsigned i = 1; i != K; ++i) {
773     const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
774     Dividend = SE.getMulExpr(Dividend,
775                              SE.getTruncateOrZeroExtend(S, CalculationTy));
776   }
777
778   // Divide by 2^T
779   const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
780
781   // Truncate the result, and divide by K! / 2^T.
782
783   return SE.getMulExpr(SE.getConstant(MultiplyFactor),
784                        SE.getTruncateOrZeroExtend(DivResult, ResultTy));
785 }
786
787 /// evaluateAtIteration - Return the value of this chain of recurrences at
788 /// the specified iteration number.  We can evaluate this recurrence by
789 /// multiplying each element in the chain by the binomial coefficient
790 /// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
791 ///
792 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
793 ///
794 /// where BC(It, k) stands for binomial coefficient.
795 ///
796 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
797                                                 ScalarEvolution &SE) const {
798   const SCEV *Result = getStart();
799   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
800     // The computation is correct in the face of overflow provided that the
801     // multiplication is performed _after_ the evaluation of the binomial
802     // coefficient.
803     const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
804     if (isa<SCEVCouldNotCompute>(Coeff))
805       return Coeff;
806
807     Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
808   }
809   return Result;
810 }
811
812 //===----------------------------------------------------------------------===//
813 //                    SCEV Expression folder implementations
814 //===----------------------------------------------------------------------===//
815
816 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
817                                              Type *Ty) {
818   assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
819          "This is not a truncating conversion!");
820   assert(isSCEVable(Ty) &&
821          "This is not a conversion to a SCEVable type!");
822   Ty = getEffectiveSCEVType(Ty);
823
824   FoldingSetNodeID ID;
825   ID.AddInteger(scTruncate);
826   ID.AddPointer(Op);
827   ID.AddPointer(Ty);
828   void *IP = 0;
829   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
830
831   // Fold if the operand is constant.
832   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
833     return getConstant(
834       cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(),
835                                                getEffectiveSCEVType(Ty))));
836
837   // trunc(trunc(x)) --> trunc(x)
838   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
839     return getTruncateExpr(ST->getOperand(), Ty);
840
841   // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
842   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
843     return getTruncateOrSignExtend(SS->getOperand(), Ty);
844
845   // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
846   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
847     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
848
849   // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
850   // eliminate all the truncates.
851   if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
852     SmallVector<const SCEV *, 4> Operands;
853     bool hasTrunc = false;
854     for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
855       const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
856       hasTrunc = isa<SCEVTruncateExpr>(S);
857       Operands.push_back(S);
858     }
859     if (!hasTrunc)
860       return getAddExpr(Operands);
861     UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
862   }
863
864   // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
865   // eliminate all the truncates.
866   if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
867     SmallVector<const SCEV *, 4> Operands;
868     bool hasTrunc = false;
869     for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
870       const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
871       hasTrunc = isa<SCEVTruncateExpr>(S);
872       Operands.push_back(S);
873     }
874     if (!hasTrunc)
875       return getMulExpr(Operands);
876     UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
877   }
878
879   // If the input value is a chrec scev, truncate the chrec's operands.
880   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
881     SmallVector<const SCEV *, 4> Operands;
882     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
883       Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
884     return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
885   }
886
887   // As a special case, fold trunc(undef) to undef. We don't want to
888   // know too much about SCEVUnknowns, but this special case is handy
889   // and harmless.
890   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
891     if (isa<UndefValue>(U->getValue()))
892       return getSCEV(UndefValue::get(Ty));
893
894   // The cast wasn't folded; create an explicit cast node. We can reuse
895   // the existing insert position since if we get here, we won't have
896   // made any changes which would invalidate it.
897   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
898                                                  Op, Ty);
899   UniqueSCEVs.InsertNode(S, IP);
900   return S;
901 }
902
903 const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
904                                                Type *Ty) {
905   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
906          "This is not an extending conversion!");
907   assert(isSCEVable(Ty) &&
908          "This is not a conversion to a SCEVable type!");
909   Ty = getEffectiveSCEVType(Ty);
910
911   // Fold if the operand is constant.
912   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
913     return getConstant(
914       cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(),
915                                               getEffectiveSCEVType(Ty))));
916
917   // zext(zext(x)) --> zext(x)
918   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
919     return getZeroExtendExpr(SZ->getOperand(), Ty);
920
921   // Before doing any expensive analysis, check to see if we've already
922   // computed a SCEV for this Op and Ty.
923   FoldingSetNodeID ID;
924   ID.AddInteger(scZeroExtend);
925   ID.AddPointer(Op);
926   ID.AddPointer(Ty);
927   void *IP = 0;
928   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
929
930   // zext(trunc(x)) --> zext(x) or x or trunc(x)
931   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
932     // It's possible the bits taken off by the truncate were all zero bits. If
933     // so, we should be able to simplify this further.
934     const SCEV *X = ST->getOperand();
935     ConstantRange CR = getUnsignedRange(X);
936     unsigned TruncBits = getTypeSizeInBits(ST->getType());
937     unsigned NewBits = getTypeSizeInBits(Ty);
938     if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
939             CR.zextOrTrunc(NewBits)))
940       return getTruncateOrZeroExtend(X, Ty);
941   }
942
943   // If the input value is a chrec scev, and we can prove that the value
944   // did not overflow the old, smaller, value, we can zero extend all of the
945   // operands (often constants).  This allows analysis of something like
946   // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
947   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
948     if (AR->isAffine()) {
949       const SCEV *Start = AR->getStart();
950       const SCEV *Step = AR->getStepRecurrence(*this);
951       unsigned BitWidth = getTypeSizeInBits(AR->getType());
952       const Loop *L = AR->getLoop();
953
954       // If we have special knowledge that this addrec won't overflow,
955       // we don't need to do any further analysis.
956       if (AR->getNoWrapFlags(SCEV::FlagNUW))
957         return getAddRecExpr(getZeroExtendExpr(Start, Ty),
958                              getZeroExtendExpr(Step, Ty),
959                              L, AR->getNoWrapFlags());
960
961       // Check whether the backedge-taken count is SCEVCouldNotCompute.
962       // Note that this serves two purposes: It filters out loops that are
963       // simply not analyzable, and it covers the case where this code is
964       // being called from within backedge-taken count analysis, such that
965       // attempting to ask for the backedge-taken count would likely result
966       // in infinite recursion. In the later case, the analysis code will
967       // cope with a conservative value, and it will take care to purge
968       // that value once it has finished.
969       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
970       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
971         // Manually compute the final value for AR, checking for
972         // overflow.
973
974         // Check whether the backedge-taken count can be losslessly casted to
975         // the addrec's type. The count is always unsigned.
976         const SCEV *CastedMaxBECount =
977           getTruncateOrZeroExtend(MaxBECount, Start->getType());
978         const SCEV *RecastedMaxBECount =
979           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
980         if (MaxBECount == RecastedMaxBECount) {
981           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
982           // Check whether Start+Step*MaxBECount has no unsigned overflow.
983           const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
984           const SCEV *Add = getAddExpr(Start, ZMul);
985           const SCEV *OperandExtendedAdd =
986             getAddExpr(getZeroExtendExpr(Start, WideTy),
987                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
988                                   getZeroExtendExpr(Step, WideTy)));
989           if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
990             // Cache knowledge of AR NUW, which is propagated to this AddRec.
991             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
992             // Return the expression with the addrec on the outside.
993             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
994                                  getZeroExtendExpr(Step, Ty),
995                                  L, AR->getNoWrapFlags());
996           }
997           // Similar to above, only this time treat the step value as signed.
998           // This covers loops that count down.
999           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1000           Add = getAddExpr(Start, SMul);
1001           OperandExtendedAdd =
1002             getAddExpr(getZeroExtendExpr(Start, WideTy),
1003                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1004                                   getSignExtendExpr(Step, WideTy)));
1005           if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1006             // Cache knowledge of AR NW, which is propagated to this AddRec.
1007             // Negative step causes unsigned wrap, but it still can't self-wrap.
1008             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1009             // Return the expression with the addrec on the outside.
1010             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1011                                  getSignExtendExpr(Step, Ty),
1012                                  L, AR->getNoWrapFlags());
1013           }
1014         }
1015
1016         // If the backedge is guarded by a comparison with the pre-inc value
1017         // the addrec is safe. Also, if the entry is guarded by a comparison
1018         // with the start value and the backedge is guarded by a comparison
1019         // with the post-inc value, the addrec is safe.
1020         if (isKnownPositive(Step)) {
1021           const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
1022                                       getUnsignedRange(Step).getUnsignedMax());
1023           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
1024               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
1025                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
1026                                            AR->getPostIncExpr(*this), N))) {
1027             // Cache knowledge of AR NUW, which is propagated to this AddRec.
1028             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
1029             // Return the expression with the addrec on the outside.
1030             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1031                                  getZeroExtendExpr(Step, Ty),
1032                                  L, AR->getNoWrapFlags());
1033           }
1034         } else if (isKnownNegative(Step)) {
1035           const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1036                                       getSignedRange(Step).getSignedMin());
1037           if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1038               (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
1039                isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
1040                                            AR->getPostIncExpr(*this), N))) {
1041             // Cache knowledge of AR NW, which is propagated to this AddRec.
1042             // Negative step causes unsigned wrap, but it still can't self-wrap.
1043             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
1044             // Return the expression with the addrec on the outside.
1045             return getAddRecExpr(getZeroExtendExpr(Start, Ty),
1046                                  getSignExtendExpr(Step, Ty),
1047                                  L, AR->getNoWrapFlags());
1048           }
1049         }
1050       }
1051     }
1052
1053   // The cast wasn't folded; create an explicit cast node.
1054   // Recompute the insert position, as it may have been invalidated.
1055   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1056   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1057                                                    Op, Ty);
1058   UniqueSCEVs.InsertNode(S, IP);
1059   return S;
1060 }
1061
1062 // Get the limit of a recurrence such that incrementing by Step cannot cause
1063 // signed overflow as long as the value of the recurrence within the loop does
1064 // not exceed this limit before incrementing.
1065 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1066                                            ICmpInst::Predicate *Pred,
1067                                            ScalarEvolution *SE) {
1068   unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1069   if (SE->isKnownPositive(Step)) {
1070     *Pred = ICmpInst::ICMP_SLT;
1071     return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1072                            SE->getSignedRange(Step).getSignedMax());
1073   }
1074   if (SE->isKnownNegative(Step)) {
1075     *Pred = ICmpInst::ICMP_SGT;
1076     return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1077                        SE->getSignedRange(Step).getSignedMin());
1078   }
1079   return 0;
1080 }
1081
1082 // The recurrence AR has been shown to have no signed wrap. Typically, if we can
1083 // prove NSW for AR, then we can just as easily prove NSW for its preincrement
1084 // or postincrement sibling. This allows normalizing a sign extended AddRec as
1085 // such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a
1086 // result, the expression "Step + sext(PreIncAR)" is congruent with
1087 // "sext(PostIncAR)"
1088 static const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR,
1089                                             Type *Ty,
1090                                             ScalarEvolution *SE) {
1091   const Loop *L = AR->getLoop();
1092   const SCEV *Start = AR->getStart();
1093   const SCEV *Step = AR->getStepRecurrence(*SE);
1094
1095   // Check for a simple looking step prior to loop entry.
1096   const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1097   if (!SA)
1098     return 0;
1099
1100   // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1101   // subtraction is expensive. For this purpose, perform a quick and dirty
1102   // difference, by checking for Step in the operand list.
1103   SmallVector<const SCEV *, 4> DiffOps;
1104   for (SCEVAddExpr::op_iterator I = SA->op_begin(), E = SA->op_end();
1105        I != E; ++I) {
1106     if (*I != Step)
1107       DiffOps.push_back(*I);
1108   }
1109   if (DiffOps.size() == SA->getNumOperands())
1110     return 0;
1111
1112   // This is a postinc AR. Check for overflow on the preinc recurrence using the
1113   // same three conditions that getSignExtendedExpr checks.
1114
1115   // 1. NSW flags on the step increment.
1116   const SCEV *PreStart = SE->getAddExpr(DiffOps, SA->getNoWrapFlags());
1117   const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1118     SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1119
1120   if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
1121     return PreStart;
1122
1123   // 2. Direct overflow check on the step operation's expression.
1124   unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1125   Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1126   const SCEV *OperandExtendedStart =
1127     SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy),
1128                    SE->getSignExtendExpr(Step, WideTy));
1129   if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) {
1130     // Cache knowledge of PreAR NSW.
1131     if (PreAR)
1132       const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(SCEV::FlagNSW);
1133     // FIXME: this optimization needs a unit test
1134     DEBUG(dbgs() << "SCEV: untested prestart overflow check\n");
1135     return PreStart;
1136   }
1137
1138   // 3. Loop precondition.
1139   ICmpInst::Predicate Pred;
1140   const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE);
1141
1142   if (OverflowLimit &&
1143       SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
1144     return PreStart;
1145   }
1146   return 0;
1147 }
1148
1149 // Get the normalized sign-extended expression for this AddRec's Start.
1150 static const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR,
1151                                             Type *Ty,
1152                                             ScalarEvolution *SE) {
1153   const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE);
1154   if (!PreStart)
1155     return SE->getSignExtendExpr(AR->getStart(), Ty);
1156
1157   return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty),
1158                         SE->getSignExtendExpr(PreStart, Ty));
1159 }
1160
1161 const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
1162                                                Type *Ty) {
1163   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1164          "This is not an extending conversion!");
1165   assert(isSCEVable(Ty) &&
1166          "This is not a conversion to a SCEVable type!");
1167   Ty = getEffectiveSCEVType(Ty);
1168
1169   // Fold if the operand is constant.
1170   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1171     return getConstant(
1172       cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(),
1173                                               getEffectiveSCEVType(Ty))));
1174
1175   // sext(sext(x)) --> sext(x)
1176   if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1177     return getSignExtendExpr(SS->getOperand(), Ty);
1178
1179   // sext(zext(x)) --> zext(x)
1180   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1181     return getZeroExtendExpr(SZ->getOperand(), Ty);
1182
1183   // Before doing any expensive analysis, check to see if we've already
1184   // computed a SCEV for this Op and Ty.
1185   FoldingSetNodeID ID;
1186   ID.AddInteger(scSignExtend);
1187   ID.AddPointer(Op);
1188   ID.AddPointer(Ty);
1189   void *IP = 0;
1190   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1191
1192   // If the input value is provably positive, build a zext instead.
1193   if (isKnownNonNegative(Op))
1194     return getZeroExtendExpr(Op, Ty);
1195
1196   // sext(trunc(x)) --> sext(x) or x or trunc(x)
1197   if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1198     // It's possible the bits taken off by the truncate were all sign bits. If
1199     // so, we should be able to simplify this further.
1200     const SCEV *X = ST->getOperand();
1201     ConstantRange CR = getSignedRange(X);
1202     unsigned TruncBits = getTypeSizeInBits(ST->getType());
1203     unsigned NewBits = getTypeSizeInBits(Ty);
1204     if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1205             CR.sextOrTrunc(NewBits)))
1206       return getTruncateOrSignExtend(X, Ty);
1207   }
1208
1209   // If the input value is a chrec scev, and we can prove that the value
1210   // did not overflow the old, smaller, value, we can sign extend all of the
1211   // operands (often constants).  This allows analysis of something like
1212   // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1213   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1214     if (AR->isAffine()) {
1215       const SCEV *Start = AR->getStart();
1216       const SCEV *Step = AR->getStepRecurrence(*this);
1217       unsigned BitWidth = getTypeSizeInBits(AR->getType());
1218       const Loop *L = AR->getLoop();
1219
1220       // If we have special knowledge that this addrec won't overflow,
1221       // we don't need to do any further analysis.
1222       if (AR->getNoWrapFlags(SCEV::FlagNSW))
1223         return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1224                              getSignExtendExpr(Step, Ty),
1225                              L, SCEV::FlagNSW);
1226
1227       // Check whether the backedge-taken count is SCEVCouldNotCompute.
1228       // Note that this serves two purposes: It filters out loops that are
1229       // simply not analyzable, and it covers the case where this code is
1230       // being called from within backedge-taken count analysis, such that
1231       // attempting to ask for the backedge-taken count would likely result
1232       // in infinite recursion. In the later case, the analysis code will
1233       // cope with a conservative value, and it will take care to purge
1234       // that value once it has finished.
1235       const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1236       if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1237         // Manually compute the final value for AR, checking for
1238         // overflow.
1239
1240         // Check whether the backedge-taken count can be losslessly casted to
1241         // the addrec's type. The count is always unsigned.
1242         const SCEV *CastedMaxBECount =
1243           getTruncateOrZeroExtend(MaxBECount, Start->getType());
1244         const SCEV *RecastedMaxBECount =
1245           getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1246         if (MaxBECount == RecastedMaxBECount) {
1247           Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1248           // Check whether Start+Step*MaxBECount has no signed overflow.
1249           const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1250           const SCEV *Add = getAddExpr(Start, SMul);
1251           const SCEV *OperandExtendedAdd =
1252             getAddExpr(getSignExtendExpr(Start, WideTy),
1253                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1254                                   getSignExtendExpr(Step, WideTy)));
1255           if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1256             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1257             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1258             // Return the expression with the addrec on the outside.
1259             return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1260                                  getSignExtendExpr(Step, Ty),
1261                                  L, AR->getNoWrapFlags());
1262           }
1263           // Similar to above, only this time treat the step value as unsigned.
1264           // This covers loops that count up with an unsigned step.
1265           const SCEV *UMul = getMulExpr(CastedMaxBECount, Step);
1266           Add = getAddExpr(Start, UMul);
1267           OperandExtendedAdd =
1268             getAddExpr(getSignExtendExpr(Start, WideTy),
1269                        getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1270                                   getZeroExtendExpr(Step, WideTy)));
1271           if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1272             // Cache knowledge of AR NSW, which is propagated to this AddRec.
1273             const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1274             // Return the expression with the addrec on the outside.
1275             return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1276                                  getZeroExtendExpr(Step, Ty),
1277                                  L, AR->getNoWrapFlags());
1278           }
1279         }
1280
1281         // If the backedge is guarded by a comparison with the pre-inc value
1282         // the addrec is safe. Also, if the entry is guarded by a comparison
1283         // with the start value and the backedge is guarded by a comparison
1284         // with the post-inc value, the addrec is safe.
1285         ICmpInst::Predicate Pred;
1286         const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this);
1287         if (OverflowLimit &&
1288             (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1289              (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1290               isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1291                                           OverflowLimit)))) {
1292           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1293           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1294           return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1295                                getSignExtendExpr(Step, Ty),
1296                                L, AR->getNoWrapFlags());
1297         }
1298       }
1299     }
1300
1301   // The cast wasn't folded; create an explicit cast node.
1302   // Recompute the insert position, as it may have been invalidated.
1303   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1304   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1305                                                    Op, Ty);
1306   UniqueSCEVs.InsertNode(S, IP);
1307   return S;
1308 }
1309
1310 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
1311 /// unspecified bits out to the given type.
1312 ///
1313 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1314                                               Type *Ty) {
1315   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1316          "This is not an extending conversion!");
1317   assert(isSCEVable(Ty) &&
1318          "This is not a conversion to a SCEVable type!");
1319   Ty = getEffectiveSCEVType(Ty);
1320
1321   // Sign-extend negative constants.
1322   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1323     if (SC->getValue()->getValue().isNegative())
1324       return getSignExtendExpr(Op, Ty);
1325
1326   // Peel off a truncate cast.
1327   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
1328     const SCEV *NewOp = T->getOperand();
1329     if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
1330       return getAnyExtendExpr(NewOp, Ty);
1331     return getTruncateOrNoop(NewOp, Ty);
1332   }
1333
1334   // Next try a zext cast. If the cast is folded, use it.
1335   const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
1336   if (!isa<SCEVZeroExtendExpr>(ZExt))
1337     return ZExt;
1338
1339   // Next try a sext cast. If the cast is folded, use it.
1340   const SCEV *SExt = getSignExtendExpr(Op, Ty);
1341   if (!isa<SCEVSignExtendExpr>(SExt))
1342     return SExt;
1343
1344   // Force the cast to be folded into the operands of an addrec.
1345   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
1346     SmallVector<const SCEV *, 4> Ops;
1347     for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
1348          I != E; ++I)
1349       Ops.push_back(getAnyExtendExpr(*I, Ty));
1350     return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
1351   }
1352
1353   // As a special case, fold anyext(undef) to undef. We don't want to
1354   // know too much about SCEVUnknowns, but this special case is handy
1355   // and harmless.
1356   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
1357     if (isa<UndefValue>(U->getValue()))
1358       return getSCEV(UndefValue::get(Ty));
1359
1360   // If the expression is obviously signed, use the sext cast value.
1361   if (isa<SCEVSMaxExpr>(Op))
1362     return SExt;
1363
1364   // Absent any other information, use the zext cast value.
1365   return ZExt;
1366 }
1367
1368 /// CollectAddOperandsWithScales - Process the given Ops list, which is
1369 /// a list of operands to be added under the given scale, update the given
1370 /// map. This is a helper function for getAddRecExpr. As an example of
1371 /// what it does, given a sequence of operands that would form an add
1372 /// expression like this:
1373 ///
1374 ///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
1375 ///
1376 /// where A and B are constants, update the map with these values:
1377 ///
1378 ///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
1379 ///
1380 /// and add 13 + A*B*29 to AccumulatedConstant.
1381 /// This will allow getAddRecExpr to produce this:
1382 ///
1383 ///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
1384 ///
1385 /// This form often exposes folding opportunities that are hidden in
1386 /// the original operand list.
1387 ///
1388 /// Return true iff it appears that any interesting folding opportunities
1389 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
1390 /// the common case where no interesting opportunities are present, and
1391 /// is also used as a check to avoid infinite recursion.
1392 ///
1393 static bool
1394 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1395                              SmallVector<const SCEV *, 8> &NewOps,
1396                              APInt &AccumulatedConstant,
1397                              const SCEV *const *Ops, size_t NumOperands,
1398                              const APInt &Scale,
1399                              ScalarEvolution &SE) {
1400   bool Interesting = false;
1401
1402   // Iterate over the add operands. They are sorted, with constants first.
1403   unsigned i = 0;
1404   while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1405     ++i;
1406     // Pull a buried constant out to the outside.
1407     if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1408       Interesting = true;
1409     AccumulatedConstant += Scale * C->getValue()->getValue();
1410   }
1411
1412   // Next comes everything else. We're especially interested in multiplies
1413   // here, but they're in the middle, so just visit the rest with one loop.
1414   for (; i != NumOperands; ++i) {
1415     const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1416     if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1417       APInt NewScale =
1418         Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
1419       if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1420         // A multiplication of a constant with another add; recurse.
1421         const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1422         Interesting |=
1423           CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1424                                        Add->op_begin(), Add->getNumOperands(),
1425                                        NewScale, SE);
1426       } else {
1427         // A multiplication of a constant with some other value. Update
1428         // the map.
1429         SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1430         const SCEV *Key = SE.getMulExpr(MulOps);
1431         std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1432           M.insert(std::make_pair(Key, NewScale));
1433         if (Pair.second) {
1434           NewOps.push_back(Pair.first->first);
1435         } else {
1436           Pair.first->second += NewScale;
1437           // The map already had an entry for this value, which may indicate
1438           // a folding opportunity.
1439           Interesting = true;
1440         }
1441       }
1442     } else {
1443       // An ordinary operand. Update the map.
1444       std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1445         M.insert(std::make_pair(Ops[i], Scale));
1446       if (Pair.second) {
1447         NewOps.push_back(Pair.first->first);
1448       } else {
1449         Pair.first->second += Scale;
1450         // The map already had an entry for this value, which may indicate
1451         // a folding opportunity.
1452         Interesting = true;
1453       }
1454     }
1455   }
1456
1457   return Interesting;
1458 }
1459
1460 namespace {
1461   struct APIntCompare {
1462     bool operator()(const APInt &LHS, const APInt &RHS) const {
1463       return LHS.ult(RHS);
1464     }
1465   };
1466 }
1467
1468 /// getAddExpr - Get a canonical add expression, or something simpler if
1469 /// possible.
1470 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
1471                                         SCEV::NoWrapFlags Flags) {
1472   assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
1473          "only nuw or nsw allowed");
1474   assert(!Ops.empty() && "Cannot get empty add!");
1475   if (Ops.size() == 1) return Ops[0];
1476 #ifndef NDEBUG
1477   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1478   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1479     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1480            "SCEVAddExpr operand types don't match!");
1481 #endif
1482
1483   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1484   // And vice-versa.
1485   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1486   SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1487   if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1488     bool All = true;
1489     for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1490          E = Ops.end(); I != E; ++I)
1491       if (!isKnownNonNegative(*I)) {
1492         All = false;
1493         break;
1494       }
1495     if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1496   }
1497
1498   // Sort by complexity, this groups all similar expression types together.
1499   GroupByComplexity(Ops, LI);
1500
1501   // If there are any constants, fold them together.
1502   unsigned Idx = 0;
1503   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1504     ++Idx;
1505     assert(Idx < Ops.size());
1506     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1507       // We found two constants, fold them together!
1508       Ops[0] = getConstant(LHSC->getValue()->getValue() +
1509                            RHSC->getValue()->getValue());
1510       if (Ops.size() == 2) return Ops[0];
1511       Ops.erase(Ops.begin()+1);  // Erase the folded element
1512       LHSC = cast<SCEVConstant>(Ops[0]);
1513     }
1514
1515     // If we are left with a constant zero being added, strip it off.
1516     if (LHSC->getValue()->isZero()) {
1517       Ops.erase(Ops.begin());
1518       --Idx;
1519     }
1520
1521     if (Ops.size() == 1) return Ops[0];
1522   }
1523
1524   // Okay, check to see if the same value occurs in the operand list more than
1525   // once.  If so, merge them together into an multiply expression.  Since we
1526   // sorted the list, these values are required to be adjacent.
1527   Type *Ty = Ops[0]->getType();
1528   bool FoundMatch = false;
1529   for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
1530     if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
1531       // Scan ahead to count how many equal operands there are.
1532       unsigned Count = 2;
1533       while (i+Count != e && Ops[i+Count] == Ops[i])
1534         ++Count;
1535       // Merge the values into a multiply.
1536       const SCEV *Scale = getConstant(Ty, Count);
1537       const SCEV *Mul = getMulExpr(Scale, Ops[i]);
1538       if (Ops.size() == Count)
1539         return Mul;
1540       Ops[i] = Mul;
1541       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
1542       --i; e -= Count - 1;
1543       FoundMatch = true;
1544     }
1545   if (FoundMatch)
1546     return getAddExpr(Ops, Flags);
1547
1548   // Check for truncates. If all the operands are truncated from the same
1549   // type, see if factoring out the truncate would permit the result to be
1550   // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
1551   // if the contents of the resulting outer trunc fold to something simple.
1552   for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
1553     const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
1554     Type *DstType = Trunc->getType();
1555     Type *SrcType = Trunc->getOperand()->getType();
1556     SmallVector<const SCEV *, 8> LargeOps;
1557     bool Ok = true;
1558     // Check all the operands to see if they can be represented in the
1559     // source type of the truncate.
1560     for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1561       if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1562         if (T->getOperand()->getType() != SrcType) {
1563           Ok = false;
1564           break;
1565         }
1566         LargeOps.push_back(T->getOperand());
1567       } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1568         LargeOps.push_back(getAnyExtendExpr(C, SrcType));
1569       } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1570         SmallVector<const SCEV *, 8> LargeMulOps;
1571         for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1572           if (const SCEVTruncateExpr *T =
1573                 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1574             if (T->getOperand()->getType() != SrcType) {
1575               Ok = false;
1576               break;
1577             }
1578             LargeMulOps.push_back(T->getOperand());
1579           } else if (const SCEVConstant *C =
1580                        dyn_cast<SCEVConstant>(M->getOperand(j))) {
1581             LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
1582           } else {
1583             Ok = false;
1584             break;
1585           }
1586         }
1587         if (Ok)
1588           LargeOps.push_back(getMulExpr(LargeMulOps));
1589       } else {
1590         Ok = false;
1591         break;
1592       }
1593     }
1594     if (Ok) {
1595       // Evaluate the expression in the larger type.
1596       const SCEV *Fold = getAddExpr(LargeOps, Flags);
1597       // If it folds to something simple, use it. Otherwise, don't.
1598       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1599         return getTruncateExpr(Fold, DstType);
1600     }
1601   }
1602
1603   // Skip past any other cast SCEVs.
1604   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1605     ++Idx;
1606
1607   // If there are add operands they would be next.
1608   if (Idx < Ops.size()) {
1609     bool DeletedAdd = false;
1610     while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1611       // If we have an add, expand the add operands onto the end of the operands
1612       // list.
1613       Ops.erase(Ops.begin()+Idx);
1614       Ops.append(Add->op_begin(), Add->op_end());
1615       DeletedAdd = true;
1616     }
1617
1618     // If we deleted at least one add, we added operands to the end of the list,
1619     // and they are not necessarily sorted.  Recurse to resort and resimplify
1620     // any operands we just acquired.
1621     if (DeletedAdd)
1622       return getAddExpr(Ops);
1623   }
1624
1625   // Skip over the add expression until we get to a multiply.
1626   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1627     ++Idx;
1628
1629   // Check to see if there are any folding opportunities present with
1630   // operands multiplied by constant values.
1631   if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
1632     uint64_t BitWidth = getTypeSizeInBits(Ty);
1633     DenseMap<const SCEV *, APInt> M;
1634     SmallVector<const SCEV *, 8> NewOps;
1635     APInt AccumulatedConstant(BitWidth, 0);
1636     if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1637                                      Ops.data(), Ops.size(),
1638                                      APInt(BitWidth, 1), *this)) {
1639       // Some interesting folding opportunity is present, so its worthwhile to
1640       // re-generate the operands list. Group the operands by constant scale,
1641       // to avoid multiplying by the same constant scale multiple times.
1642       std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
1643       for (SmallVector<const SCEV *, 8>::const_iterator I = NewOps.begin(),
1644            E = NewOps.end(); I != E; ++I)
1645         MulOpLists[M.find(*I)->second].push_back(*I);
1646       // Re-generate the operands list.
1647       Ops.clear();
1648       if (AccumulatedConstant != 0)
1649         Ops.push_back(getConstant(AccumulatedConstant));
1650       for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1651            I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1652         if (I->first != 0)
1653           Ops.push_back(getMulExpr(getConstant(I->first),
1654                                    getAddExpr(I->second)));
1655       if (Ops.empty())
1656         return getConstant(Ty, 0);
1657       if (Ops.size() == 1)
1658         return Ops[0];
1659       return getAddExpr(Ops);
1660     }
1661   }
1662
1663   // If we are adding something to a multiply expression, make sure the
1664   // something is not already an operand of the multiply.  If so, merge it into
1665   // the multiply.
1666   for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1667     const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1668     for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1669       const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1670       if (isa<SCEVConstant>(MulOpSCEV))
1671         continue;
1672       for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1673         if (MulOpSCEV == Ops[AddOp]) {
1674           // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1675           const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
1676           if (Mul->getNumOperands() != 2) {
1677             // If the multiply has more than two operands, we must get the
1678             // Y*Z term.
1679             SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1680                                                 Mul->op_begin()+MulOp);
1681             MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1682             InnerMul = getMulExpr(MulOps);
1683           }
1684           const SCEV *One = getConstant(Ty, 1);
1685           const SCEV *AddOne = getAddExpr(One, InnerMul);
1686           const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
1687           if (Ops.size() == 2) return OuterMul;
1688           if (AddOp < Idx) {
1689             Ops.erase(Ops.begin()+AddOp);
1690             Ops.erase(Ops.begin()+Idx-1);
1691           } else {
1692             Ops.erase(Ops.begin()+Idx);
1693             Ops.erase(Ops.begin()+AddOp-1);
1694           }
1695           Ops.push_back(OuterMul);
1696           return getAddExpr(Ops);
1697         }
1698
1699       // Check this multiply against other multiplies being added together.
1700       for (unsigned OtherMulIdx = Idx+1;
1701            OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1702            ++OtherMulIdx) {
1703         const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1704         // If MulOp occurs in OtherMul, we can fold the two multiplies
1705         // together.
1706         for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1707              OMulOp != e; ++OMulOp)
1708           if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1709             // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1710             const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
1711             if (Mul->getNumOperands() != 2) {
1712               SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1713                                                   Mul->op_begin()+MulOp);
1714               MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1715               InnerMul1 = getMulExpr(MulOps);
1716             }
1717             const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1718             if (OtherMul->getNumOperands() != 2) {
1719               SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1720                                                   OtherMul->op_begin()+OMulOp);
1721               MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
1722               InnerMul2 = getMulExpr(MulOps);
1723             }
1724             const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1725             const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1726             if (Ops.size() == 2) return OuterMul;
1727             Ops.erase(Ops.begin()+Idx);
1728             Ops.erase(Ops.begin()+OtherMulIdx-1);
1729             Ops.push_back(OuterMul);
1730             return getAddExpr(Ops);
1731           }
1732       }
1733     }
1734   }
1735
1736   // If there are any add recurrences in the operands list, see if any other
1737   // added values are loop invariant.  If so, we can fold them into the
1738   // recurrence.
1739   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1740     ++Idx;
1741
1742   // Scan over all recurrences, trying to fold loop invariants into them.
1743   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1744     // Scan all of the other operands to this add and add them to the vector if
1745     // they are loop invariant w.r.t. the recurrence.
1746     SmallVector<const SCEV *, 8> LIOps;
1747     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1748     const Loop *AddRecLoop = AddRec->getLoop();
1749     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1750       if (isLoopInvariant(Ops[i], AddRecLoop)) {
1751         LIOps.push_back(Ops[i]);
1752         Ops.erase(Ops.begin()+i);
1753         --i; --e;
1754       }
1755
1756     // If we found some loop invariants, fold them into the recurrence.
1757     if (!LIOps.empty()) {
1758       //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1759       LIOps.push_back(AddRec->getStart());
1760
1761       SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1762                                              AddRec->op_end());
1763       AddRecOps[0] = getAddExpr(LIOps);
1764
1765       // Build the new addrec. Propagate the NUW and NSW flags if both the
1766       // outer add and the inner addrec are guaranteed to have no overflow.
1767       // Always propagate NW.
1768       Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
1769       const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
1770
1771       // If all of the other operands were loop invariant, we are done.
1772       if (Ops.size() == 1) return NewRec;
1773
1774       // Otherwise, add the folded AddRec by the non-invariant parts.
1775       for (unsigned i = 0;; ++i)
1776         if (Ops[i] == AddRec) {
1777           Ops[i] = NewRec;
1778           break;
1779         }
1780       return getAddExpr(Ops);
1781     }
1782
1783     // Okay, if there weren't any loop invariants to be folded, check to see if
1784     // there are multiple AddRec's with the same loop induction variable being
1785     // added together.  If so, we can fold them.
1786     for (unsigned OtherIdx = Idx+1;
1787          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1788          ++OtherIdx)
1789       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
1790         // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
1791         SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1792                                                AddRec->op_end());
1793         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1794              ++OtherIdx)
1795           if (const SCEVAddRecExpr *OtherAddRec =
1796                 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
1797             if (OtherAddRec->getLoop() == AddRecLoop) {
1798               for (unsigned i = 0, e = OtherAddRec->getNumOperands();
1799                    i != e; ++i) {
1800                 if (i >= AddRecOps.size()) {
1801                   AddRecOps.append(OtherAddRec->op_begin()+i,
1802                                    OtherAddRec->op_end());
1803                   break;
1804                 }
1805                 AddRecOps[i] = getAddExpr(AddRecOps[i],
1806                                           OtherAddRec->getOperand(i));
1807               }
1808               Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
1809             }
1810         // Step size has changed, so we cannot guarantee no self-wraparound.
1811         Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
1812         return getAddExpr(Ops);
1813       }
1814
1815     // Otherwise couldn't fold anything into this recurrence.  Move onto the
1816     // next one.
1817   }
1818
1819   // Okay, it looks like we really DO need an add expr.  Check to see if we
1820   // already have one, otherwise create a new one.
1821   FoldingSetNodeID ID;
1822   ID.AddInteger(scAddExpr);
1823   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1824     ID.AddPointer(Ops[i]);
1825   void *IP = 0;
1826   SCEVAddExpr *S =
1827     static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1828   if (!S) {
1829     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
1830     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
1831     S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
1832                                         O, Ops.size());
1833     UniqueSCEVs.InsertNode(S, IP);
1834   }
1835   S->setNoWrapFlags(Flags);
1836   return S;
1837 }
1838
1839 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
1840   uint64_t k = i*j;
1841   if (j > 1 && k / j != i) Overflow = true;
1842   return k;
1843 }
1844
1845 /// Compute the result of "n choose k", the binomial coefficient.  If an
1846 /// intermediate computation overflows, Overflow will be set and the return will
1847 /// be garbage. Overflow is not cleared on absense of overflow.
1848 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
1849   // We use the multiplicative formula:
1850   //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
1851   // At each iteration, we take the n-th term of the numeral and divide by the
1852   // (k-n)th term of the denominator.  This division will always produce an
1853   // integral result, and helps reduce the chance of overflow in the
1854   // intermediate computations. However, we can still overflow even when the
1855   // final result would fit.
1856
1857   if (n == 0 || n == k) return 1;
1858   if (k > n) return 0;
1859
1860   if (k > n/2)
1861     k = n-k;
1862
1863   uint64_t r = 1;
1864   for (uint64_t i = 1; i <= k; ++i) {
1865     r = umul_ov(r, n-(i-1), Overflow);
1866     r /= i;
1867   }
1868   return r;
1869 }
1870
1871 /// getMulExpr - Get a canonical multiply expression, or something simpler if
1872 /// possible.
1873 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
1874                                         SCEV::NoWrapFlags Flags) {
1875   assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
1876          "only nuw or nsw allowed");
1877   assert(!Ops.empty() && "Cannot get empty mul!");
1878   if (Ops.size() == 1) return Ops[0];
1879 #ifndef NDEBUG
1880   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1881   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1882     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1883            "SCEVMulExpr operand types don't match!");
1884 #endif
1885
1886   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1887   // And vice-versa.
1888   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1889   SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1890   if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1891     bool All = true;
1892     for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1893          E = Ops.end(); I != E; ++I)
1894       if (!isKnownNonNegative(*I)) {
1895         All = false;
1896         break;
1897       }
1898     if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1899   }
1900
1901   // Sort by complexity, this groups all similar expression types together.
1902   GroupByComplexity(Ops, LI);
1903
1904   // If there are any constants, fold them together.
1905   unsigned Idx = 0;
1906   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1907
1908     // C1*(C2+V) -> C1*C2 + C1*V
1909     if (Ops.size() == 2)
1910       if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1911         if (Add->getNumOperands() == 2 &&
1912             isa<SCEVConstant>(Add->getOperand(0)))
1913           return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1914                             getMulExpr(LHSC, Add->getOperand(1)));
1915
1916     ++Idx;
1917     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1918       // We found two constants, fold them together!
1919       ConstantInt *Fold = ConstantInt::get(getContext(),
1920                                            LHSC->getValue()->getValue() *
1921                                            RHSC->getValue()->getValue());
1922       Ops[0] = getConstant(Fold);
1923       Ops.erase(Ops.begin()+1);  // Erase the folded element
1924       if (Ops.size() == 1) return Ops[0];
1925       LHSC = cast<SCEVConstant>(Ops[0]);
1926     }
1927
1928     // If we are left with a constant one being multiplied, strip it off.
1929     if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1930       Ops.erase(Ops.begin());
1931       --Idx;
1932     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1933       // If we have a multiply of zero, it will always be zero.
1934       return Ops[0];
1935     } else if (Ops[0]->isAllOnesValue()) {
1936       // If we have a mul by -1 of an add, try distributing the -1 among the
1937       // add operands.
1938       if (Ops.size() == 2) {
1939         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
1940           SmallVector<const SCEV *, 4> NewOps;
1941           bool AnyFolded = false;
1942           for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
1943                  E = Add->op_end(); I != E; ++I) {
1944             const SCEV *Mul = getMulExpr(Ops[0], *I);
1945             if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
1946             NewOps.push_back(Mul);
1947           }
1948           if (AnyFolded)
1949             return getAddExpr(NewOps);
1950         }
1951         else if (const SCEVAddRecExpr *
1952                  AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
1953           // Negation preserves a recurrence's no self-wrap property.
1954           SmallVector<const SCEV *, 4> Operands;
1955           for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
1956                  E = AddRec->op_end(); I != E; ++I) {
1957             Operands.push_back(getMulExpr(Ops[0], *I));
1958           }
1959           return getAddRecExpr(Operands, AddRec->getLoop(),
1960                                AddRec->getNoWrapFlags(SCEV::FlagNW));
1961         }
1962       }
1963     }
1964
1965     if (Ops.size() == 1)
1966       return Ops[0];
1967   }
1968
1969   // Skip over the add expression until we get to a multiply.
1970   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1971     ++Idx;
1972
1973   // If there are mul operands inline them all into this expression.
1974   if (Idx < Ops.size()) {
1975     bool DeletedMul = false;
1976     while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1977       // If we have an mul, expand the mul operands onto the end of the operands
1978       // list.
1979       Ops.erase(Ops.begin()+Idx);
1980       Ops.append(Mul->op_begin(), Mul->op_end());
1981       DeletedMul = true;
1982     }
1983
1984     // If we deleted at least one mul, we added operands to the end of the list,
1985     // and they are not necessarily sorted.  Recurse to resort and resimplify
1986     // any operands we just acquired.
1987     if (DeletedMul)
1988       return getMulExpr(Ops);
1989   }
1990
1991   // If there are any add recurrences in the operands list, see if any other
1992   // added values are loop invariant.  If so, we can fold them into the
1993   // recurrence.
1994   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1995     ++Idx;
1996
1997   // Scan over all recurrences, trying to fold loop invariants into them.
1998   for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1999     // Scan all of the other operands to this mul and add them to the vector if
2000     // they are loop invariant w.r.t. the recurrence.
2001     SmallVector<const SCEV *, 8> LIOps;
2002     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2003     const Loop *AddRecLoop = AddRec->getLoop();
2004     for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2005       if (isLoopInvariant(Ops[i], AddRecLoop)) {
2006         LIOps.push_back(Ops[i]);
2007         Ops.erase(Ops.begin()+i);
2008         --i; --e;
2009       }
2010
2011     // If we found some loop invariants, fold them into the recurrence.
2012     if (!LIOps.empty()) {
2013       //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
2014       SmallVector<const SCEV *, 4> NewOps;
2015       NewOps.reserve(AddRec->getNumOperands());
2016       const SCEV *Scale = getMulExpr(LIOps);
2017       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2018         NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
2019
2020       // Build the new addrec. Propagate the NUW and NSW flags if both the
2021       // outer mul and the inner addrec are guaranteed to have no overflow.
2022       //
2023       // No self-wrap cannot be guaranteed after changing the step size, but
2024       // will be inferred if either NUW or NSW is true.
2025       Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
2026       const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
2027
2028       // If all of the other operands were loop invariant, we are done.
2029       if (Ops.size() == 1) return NewRec;
2030
2031       // Otherwise, multiply the folded AddRec by the non-invariant parts.
2032       for (unsigned i = 0;; ++i)
2033         if (Ops[i] == AddRec) {
2034           Ops[i] = NewRec;
2035           break;
2036         }
2037       return getMulExpr(Ops);
2038     }
2039
2040     // Okay, if there weren't any loop invariants to be folded, check to see if
2041     // there are multiple AddRec's with the same loop induction variable being
2042     // multiplied together.  If so, we can fold them.
2043     for (unsigned OtherIdx = Idx+1;
2044          OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2045          ++OtherIdx) {
2046       if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2047         // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
2048         // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
2049         //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
2050         //   ]]],+,...up to x=2n}.
2051         // Note that the arguments to choose() are always integers with values
2052         // known at compile time, never SCEV objects.
2053         //
2054         // The implementation avoids pointless extra computations when the two
2055         // addrec's are of different length (mathematically, it's equivalent to
2056         // an infinite stream of zeros on the right).
2057         bool OpsModified = false;
2058         for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2059              ++OtherIdx)
2060           if (const SCEVAddRecExpr *OtherAddRec =
2061                 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
2062             if (OtherAddRec->getLoop() == AddRecLoop) {
2063               bool Overflow = false;
2064               Type *Ty = AddRec->getType();
2065               bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
2066               SmallVector<const SCEV*, 7> AddRecOps;
2067               for (int x = 0, xe = AddRec->getNumOperands() +
2068                      OtherAddRec->getNumOperands() - 1;
2069                    x != xe && !Overflow; ++x) {
2070                 const SCEV *Term = getConstant(Ty, 0);
2071                 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
2072                   uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
2073                   for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
2074                          ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
2075                        z < ze && !Overflow; ++z) {
2076                     uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
2077                     uint64_t Coeff;
2078                     if (LargerThan64Bits)
2079                       Coeff = umul_ov(Coeff1, Coeff2, Overflow);
2080                     else
2081                       Coeff = Coeff1*Coeff2;
2082                     const SCEV *CoeffTerm = getConstant(Ty, Coeff);
2083                     const SCEV *Term1 = AddRec->getOperand(y-z);
2084                     const SCEV *Term2 = OtherAddRec->getOperand(z);
2085                     Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2));
2086                   }
2087                 }
2088                 AddRecOps.push_back(Term);
2089               }
2090               if (!Overflow) {
2091                 const SCEV *NewAddRec = getAddRecExpr(AddRecOps,
2092                                                       AddRec->getLoop(),
2093                                                       SCEV::FlagAnyWrap);
2094                 if (Ops.size() == 2) return NewAddRec;
2095                 Ops[Idx] = AddRec = cast<SCEVAddRecExpr>(NewAddRec);
2096                 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2097                 OpsModified = true;
2098               }
2099             }
2100         if (OpsModified)
2101           return getMulExpr(Ops);
2102       }
2103     }
2104
2105     // Otherwise couldn't fold anything into this recurrence.  Move onto the
2106     // next one.
2107   }
2108
2109   // Okay, it looks like we really DO need an mul expr.  Check to see if we
2110   // already have one, otherwise create a new one.
2111   FoldingSetNodeID ID;
2112   ID.AddInteger(scMulExpr);
2113   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2114     ID.AddPointer(Ops[i]);
2115   void *IP = 0;
2116   SCEVMulExpr *S =
2117     static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2118   if (!S) {
2119     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2120     std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2121     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2122                                         O, Ops.size());
2123     UniqueSCEVs.InsertNode(S, IP);
2124   }
2125   S->setNoWrapFlags(Flags);
2126   return S;
2127 }
2128
2129 /// getUDivExpr - Get a canonical unsigned division expression, or something
2130 /// simpler if possible.
2131 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2132                                          const SCEV *RHS) {
2133   assert(getEffectiveSCEVType(LHS->getType()) ==
2134          getEffectiveSCEVType(RHS->getType()) &&
2135          "SCEVUDivExpr operand types don't match!");
2136
2137   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2138     if (RHSC->getValue()->equalsInt(1))
2139       return LHS;                               // X udiv 1 --> x
2140     // If the denominator is zero, the result of the udiv is undefined. Don't
2141     // try to analyze it, because the resolution chosen here may differ from
2142     // the resolution chosen in other parts of the compiler.
2143     if (!RHSC->getValue()->isZero()) {
2144       // Determine if the division can be folded into the operands of
2145       // its operands.
2146       // TODO: Generalize this to non-constants by using known-bits information.
2147       Type *Ty = LHS->getType();
2148       unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
2149       unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2150       // For non-power-of-two values, effectively round the value up to the
2151       // nearest power of two.
2152       if (!RHSC->getValue()->getValue().isPowerOf2())
2153         ++MaxShiftAmt;
2154       IntegerType *ExtTy =
2155         IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
2156       if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2157         if (const SCEVConstant *Step =
2158             dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
2159           // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
2160           const APInt &StepInt = Step->getValue()->getValue();
2161           const APInt &DivInt = RHSC->getValue()->getValue();
2162           if (!StepInt.urem(DivInt) &&
2163               getZeroExtendExpr(AR, ExtTy) ==
2164               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2165                             getZeroExtendExpr(Step, ExtTy),
2166                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2167             SmallVector<const SCEV *, 4> Operands;
2168             for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
2169               Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
2170             return getAddRecExpr(Operands, AR->getLoop(),
2171                                  SCEV::FlagNW);
2172           }
2173           /// Get a canonical UDivExpr for a recurrence.
2174           /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
2175           // We can currently only fold X%N if X is constant.
2176           const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
2177           if (StartC && !DivInt.urem(StepInt) &&
2178               getZeroExtendExpr(AR, ExtTy) ==
2179               getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2180                             getZeroExtendExpr(Step, ExtTy),
2181                             AR->getLoop(), SCEV::FlagAnyWrap)) {
2182             const APInt &StartInt = StartC->getValue()->getValue();
2183             const APInt &StartRem = StartInt.urem(StepInt);
2184             if (StartRem != 0)
2185               LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step,
2186                                   AR->getLoop(), SCEV::FlagNW);
2187           }
2188         }
2189       // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
2190       if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2191         SmallVector<const SCEV *, 4> Operands;
2192         for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
2193           Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
2194         if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2195           // Find an operand that's safely divisible.
2196           for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
2197             const SCEV *Op = M->getOperand(i);
2198             const SCEV *Div = getUDivExpr(Op, RHSC);
2199             if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
2200               Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
2201                                                       M->op_end());
2202               Operands[i] = Div;
2203               return getMulExpr(Operands);
2204             }
2205           }
2206       }
2207       // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
2208       if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
2209         SmallVector<const SCEV *, 4> Operands;
2210         for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
2211           Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
2212         if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
2213           Operands.clear();
2214           for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2215             const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
2216             if (isa<SCEVUDivExpr>(Op) ||
2217                 getMulExpr(Op, RHS) != A->getOperand(i))
2218               break;
2219             Operands.push_back(Op);
2220           }
2221           if (Operands.size() == A->getNumOperands())
2222             return getAddExpr(Operands);
2223         }
2224       }
2225
2226       // Fold if both operands are constant.
2227       if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
2228         Constant *LHSCV = LHSC->getValue();
2229         Constant *RHSCV = RHSC->getValue();
2230         return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
2231                                                                    RHSCV)));
2232       }
2233     }
2234   }
2235
2236   FoldingSetNodeID ID;
2237   ID.AddInteger(scUDivExpr);
2238   ID.AddPointer(LHS);
2239   ID.AddPointer(RHS);
2240   void *IP = 0;
2241   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2242   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
2243                                              LHS, RHS);
2244   UniqueSCEVs.InsertNode(S, IP);
2245   return S;
2246 }
2247
2248
2249 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
2250 /// Simplify the expression as much as possible.
2251 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
2252                                            const Loop *L,
2253                                            SCEV::NoWrapFlags Flags) {
2254   SmallVector<const SCEV *, 4> Operands;
2255   Operands.push_back(Start);
2256   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
2257     if (StepChrec->getLoop() == L) {
2258       Operands.append(StepChrec->op_begin(), StepChrec->op_end());
2259       return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
2260     }
2261
2262   Operands.push_back(Step);
2263   return getAddRecExpr(Operands, L, Flags);
2264 }
2265
2266 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
2267 /// Simplify the expression as much as possible.
2268 const SCEV *
2269 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
2270                                const Loop *L, SCEV::NoWrapFlags Flags) {
2271   if (Operands.size() == 1) return Operands[0];
2272 #ifndef NDEBUG
2273   Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
2274   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
2275     assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
2276            "SCEVAddRecExpr operand types don't match!");
2277   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2278     assert(isLoopInvariant(Operands[i], L) &&
2279            "SCEVAddRecExpr operand is not loop-invariant!");
2280 #endif
2281
2282   if (Operands.back()->isZero()) {
2283     Operands.pop_back();
2284     return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
2285   }
2286
2287   // It's tempting to want to call getMaxBackedgeTakenCount count here and
2288   // use that information to infer NUW and NSW flags. However, computing a
2289   // BE count requires calling getAddRecExpr, so we may not yet have a
2290   // meaningful BE count at this point (and if we don't, we'd be stuck
2291   // with a SCEVCouldNotCompute as the cached BE count).
2292
2293   // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2294   // And vice-versa.
2295   int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2296   SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
2297   if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
2298     bool All = true;
2299     for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
2300          E = Operands.end(); I != E; ++I)
2301       if (!isKnownNonNegative(*I)) {
2302         All = false;
2303         break;
2304       }
2305     if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2306   }
2307
2308   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
2309   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
2310     const Loop *NestedLoop = NestedAR->getLoop();
2311     if (L->contains(NestedLoop) ?
2312         (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
2313         (!NestedLoop->contains(L) &&
2314          DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
2315       SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
2316                                                   NestedAR->op_end());
2317       Operands[0] = NestedAR->getStart();
2318       // AddRecs require their operands be loop-invariant with respect to their
2319       // loops. Don't perform this transformation if it would break this
2320       // requirement.
2321       bool AllInvariant = true;
2322       for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2323         if (!isLoopInvariant(Operands[i], L)) {
2324           AllInvariant = false;
2325           break;
2326         }
2327       if (AllInvariant) {
2328         // Create a recurrence for the outer loop with the same step size.
2329         //
2330         // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
2331         // inner recurrence has the same property.
2332         SCEV::NoWrapFlags OuterFlags =
2333           maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
2334
2335         NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
2336         AllInvariant = true;
2337         for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
2338           if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
2339             AllInvariant = false;
2340             break;
2341           }
2342         if (AllInvariant) {
2343           // Ok, both add recurrences are valid after the transformation.
2344           //
2345           // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
2346           // the outer recurrence has the same property.
2347           SCEV::NoWrapFlags InnerFlags =
2348             maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
2349           return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
2350         }
2351       }
2352       // Reset Operands to its original state.
2353       Operands[0] = NestedAR;
2354     }
2355   }
2356
2357   // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2358   // already have one, otherwise create a new one.
2359   FoldingSetNodeID ID;
2360   ID.AddInteger(scAddRecExpr);
2361   for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2362     ID.AddPointer(Operands[i]);
2363   ID.AddPointer(L);
2364   void *IP = 0;
2365   SCEVAddRecExpr *S =
2366     static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2367   if (!S) {
2368     const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
2369     std::uninitialized_copy(Operands.begin(), Operands.end(), O);
2370     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
2371                                            O, Operands.size(), L);
2372     UniqueSCEVs.InsertNode(S, IP);
2373   }
2374   S->setNoWrapFlags(Flags);
2375   return S;
2376 }
2377
2378 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
2379                                          const SCEV *RHS) {
2380   SmallVector<const SCEV *, 2> Ops;
2381   Ops.push_back(LHS);
2382   Ops.push_back(RHS);
2383   return getSMaxExpr(Ops);
2384 }
2385
2386 const SCEV *
2387 ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2388   assert(!Ops.empty() && "Cannot get empty smax!");
2389   if (Ops.size() == 1) return Ops[0];
2390 #ifndef NDEBUG
2391   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2392   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2393     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2394            "SCEVSMaxExpr operand types don't match!");
2395 #endif
2396
2397   // Sort by complexity, this groups all similar expression types together.
2398   GroupByComplexity(Ops, LI);
2399
2400   // If there are any constants, fold them together.
2401   unsigned Idx = 0;
2402   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2403     ++Idx;
2404     assert(Idx < Ops.size());
2405     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2406       // We found two constants, fold them together!
2407       ConstantInt *Fold = ConstantInt::get(getContext(),
2408                               APIntOps::smax(LHSC->getValue()->getValue(),
2409                                              RHSC->getValue()->getValue()));
2410       Ops[0] = getConstant(Fold);
2411       Ops.erase(Ops.begin()+1);  // Erase the folded element
2412       if (Ops.size() == 1) return Ops[0];
2413       LHSC = cast<SCEVConstant>(Ops[0]);
2414     }
2415
2416     // If we are left with a constant minimum-int, strip it off.
2417     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
2418       Ops.erase(Ops.begin());
2419       --Idx;
2420     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
2421       // If we have an smax with a constant maximum-int, it will always be
2422       // maximum-int.
2423       return Ops[0];
2424     }
2425
2426     if (Ops.size() == 1) return Ops[0];
2427   }
2428
2429   // Find the first SMax
2430   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
2431     ++Idx;
2432
2433   // Check to see if one of the operands is an SMax. If so, expand its operands
2434   // onto our operand list, and recurse to simplify.
2435   if (Idx < Ops.size()) {
2436     bool DeletedSMax = false;
2437     while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
2438       Ops.erase(Ops.begin()+Idx);
2439       Ops.append(SMax->op_begin(), SMax->op_end());
2440       DeletedSMax = true;
2441     }
2442
2443     if (DeletedSMax)
2444       return getSMaxExpr(Ops);
2445   }
2446
2447   // Okay, check to see if the same value occurs in the operand list twice.  If
2448   // so, delete one.  Since we sorted the list, these values are required to
2449   // be adjacent.
2450   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2451     //  X smax Y smax Y  -->  X smax Y
2452     //  X smax Y         -->  X, if X is always greater than Y
2453     if (Ops[i] == Ops[i+1] ||
2454         isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
2455       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2456       --i; --e;
2457     } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
2458       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2459       --i; --e;
2460     }
2461
2462   if (Ops.size() == 1) return Ops[0];
2463
2464   assert(!Ops.empty() && "Reduced smax down to nothing!");
2465
2466   // Okay, it looks like we really DO need an smax expr.  Check to see if we
2467   // already have one, otherwise create a new one.
2468   FoldingSetNodeID ID;
2469   ID.AddInteger(scSMaxExpr);
2470   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2471     ID.AddPointer(Ops[i]);
2472   void *IP = 0;
2473   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2474   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2475   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2476   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
2477                                              O, Ops.size());
2478   UniqueSCEVs.InsertNode(S, IP);
2479   return S;
2480 }
2481
2482 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
2483                                          const SCEV *RHS) {
2484   SmallVector<const SCEV *, 2> Ops;
2485   Ops.push_back(LHS);
2486   Ops.push_back(RHS);
2487   return getUMaxExpr(Ops);
2488 }
2489
2490 const SCEV *
2491 ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2492   assert(!Ops.empty() && "Cannot get empty umax!");
2493   if (Ops.size() == 1) return Ops[0];
2494 #ifndef NDEBUG
2495   Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2496   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2497     assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2498            "SCEVUMaxExpr operand types don't match!");
2499 #endif
2500
2501   // Sort by complexity, this groups all similar expression types together.
2502   GroupByComplexity(Ops, LI);
2503
2504   // If there are any constants, fold them together.
2505   unsigned Idx = 0;
2506   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2507     ++Idx;
2508     assert(Idx < Ops.size());
2509     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2510       // We found two constants, fold them together!
2511       ConstantInt *Fold = ConstantInt::get(getContext(),
2512                               APIntOps::umax(LHSC->getValue()->getValue(),
2513                                              RHSC->getValue()->getValue()));
2514       Ops[0] = getConstant(Fold);
2515       Ops.erase(Ops.begin()+1);  // Erase the folded element
2516       if (Ops.size() == 1) return Ops[0];
2517       LHSC = cast<SCEVConstant>(Ops[0]);
2518     }
2519
2520     // If we are left with a constant minimum-int, strip it off.
2521     if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
2522       Ops.erase(Ops.begin());
2523       --Idx;
2524     } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
2525       // If we have an umax with a constant maximum-int, it will always be
2526       // maximum-int.
2527       return Ops[0];
2528     }
2529
2530     if (Ops.size() == 1) return Ops[0];
2531   }
2532
2533   // Find the first UMax
2534   while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
2535     ++Idx;
2536
2537   // Check to see if one of the operands is a UMax. If so, expand its operands
2538   // onto our operand list, and recurse to simplify.
2539   if (Idx < Ops.size()) {
2540     bool DeletedUMax = false;
2541     while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
2542       Ops.erase(Ops.begin()+Idx);
2543       Ops.append(UMax->op_begin(), UMax->op_end());
2544       DeletedUMax = true;
2545     }
2546
2547     if (DeletedUMax)
2548       return getUMaxExpr(Ops);
2549   }
2550
2551   // Okay, check to see if the same value occurs in the operand list twice.  If
2552   // so, delete one.  Since we sorted the list, these values are required to
2553   // be adjacent.
2554   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2555     //  X umax Y umax Y  -->  X umax Y
2556     //  X umax Y         -->  X, if X is always greater than Y
2557     if (Ops[i] == Ops[i+1] ||
2558         isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
2559       Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2560       --i; --e;
2561     } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
2562       Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2563       --i; --e;
2564     }
2565
2566   if (Ops.size() == 1) return Ops[0];
2567
2568   assert(!Ops.empty() && "Reduced umax down to nothing!");
2569
2570   // Okay, it looks like we really DO need a umax expr.  Check to see if we
2571   // already have one, otherwise create a new one.
2572   FoldingSetNodeID ID;
2573   ID.AddInteger(scUMaxExpr);
2574   for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2575     ID.AddPointer(Ops[i]);
2576   void *IP = 0;
2577   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2578   const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2579   std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2580   SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
2581                                              O, Ops.size());
2582   UniqueSCEVs.InsertNode(S, IP);
2583   return S;
2584 }
2585
2586 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
2587                                          const SCEV *RHS) {
2588   // ~smax(~x, ~y) == smin(x, y).
2589   return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2590 }
2591
2592 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
2593                                          const SCEV *RHS) {
2594   // ~umax(~x, ~y) == umin(x, y)
2595   return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2596 }
2597
2598 const SCEV *ScalarEvolution::getSizeOfExpr(Type *AllocTy) {
2599   // If we have TargetData, we can bypass creating a target-independent
2600   // constant expression and then folding it back into a ConstantInt.
2601   // This is just a compile-time optimization.
2602   if (TD)
2603     return getConstant(TD->getIntPtrType(getContext()),
2604                        TD->getTypeAllocSize(AllocTy));
2605
2606   Constant *C = ConstantExpr::getSizeOf(AllocTy);
2607   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2608     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2609       C = Folded;
2610   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2611   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2612 }
2613
2614 const SCEV *ScalarEvolution::getAlignOfExpr(Type *AllocTy) {
2615   Constant *C = ConstantExpr::getAlignOf(AllocTy);
2616   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2617     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2618       C = Folded;
2619   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2620   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2621 }
2622
2623 const SCEV *ScalarEvolution::getOffsetOfExpr(StructType *STy,
2624                                              unsigned FieldNo) {
2625   // If we have TargetData, we can bypass creating a target-independent
2626   // constant expression and then folding it back into a ConstantInt.
2627   // This is just a compile-time optimization.
2628   if (TD)
2629     return getConstant(TD->getIntPtrType(getContext()),
2630                        TD->getStructLayout(STy)->getElementOffset(FieldNo));
2631
2632   Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
2633   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2634     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2635       C = Folded;
2636   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
2637   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2638 }
2639
2640 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *CTy,
2641                                              Constant *FieldNo) {
2642   Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo);
2643   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2644     if (Constant *Folded = ConstantFoldConstantExpression(CE, TD, TLI))
2645       C = Folded;
2646   Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy));
2647   return getTruncateOrZeroExtend(getSCEV(C), Ty);
2648 }
2649
2650 const SCEV *ScalarEvolution::getUnknown(Value *V) {
2651   // Don't attempt to do anything other than create a SCEVUnknown object
2652   // here.  createSCEV only calls getUnknown after checking for all other
2653   // interesting possibilities, and any other code that calls getUnknown
2654   // is doing so in order to hide a value from SCEV canonicalization.
2655
2656   FoldingSetNodeID ID;
2657   ID.AddInteger(scUnknown);
2658   ID.AddPointer(V);
2659   void *IP = 0;
2660   if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
2661     assert(cast<SCEVUnknown>(S)->getValue() == V &&
2662            "Stale SCEVUnknown in uniquing map!");
2663     return S;
2664   }
2665   SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
2666                                             FirstUnknown);
2667   FirstUnknown = cast<SCEVUnknown>(S);
2668   UniqueSCEVs.InsertNode(S, IP);
2669   return S;
2670 }
2671
2672 //===----------------------------------------------------------------------===//
2673 //            Basic SCEV Analysis and PHI Idiom Recognition Code
2674 //
2675
2676 /// isSCEVable - Test if values of the given type are analyzable within
2677 /// the SCEV framework. This primarily includes integer types, and it
2678 /// can optionally include pointer types if the ScalarEvolution class
2679 /// has access to target-specific information.
2680 bool ScalarEvolution::isSCEVable(Type *Ty) const {
2681   // Integers and pointers are always SCEVable.
2682   return Ty->isIntegerTy() || Ty->isPointerTy();
2683 }
2684
2685 /// getTypeSizeInBits - Return the size in bits of the specified type,
2686 /// for which isSCEVable must return true.
2687 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
2688   assert(isSCEVable(Ty) && "Type is not SCEVable!");
2689
2690   // If we have a TargetData, use it!
2691   if (TD)
2692     return TD->getTypeSizeInBits(Ty);
2693
2694   // Integer types have fixed sizes.
2695   if (Ty->isIntegerTy())
2696     return Ty->getPrimitiveSizeInBits();
2697
2698   // The only other support type is pointer. Without TargetData, conservatively
2699   // assume pointers are 64-bit.
2700   assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
2701   return 64;
2702 }
2703
2704 /// getEffectiveSCEVType - Return a type with the same bitwidth as
2705 /// the given type and which represents how SCEV will treat the given
2706 /// type, for which isSCEVable must return true. For pointer types,
2707 /// this is the pointer-sized integer type.
2708 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
2709   assert(isSCEVable(Ty) && "Type is not SCEVable!");
2710
2711   if (Ty->isIntegerTy())
2712     return Ty;
2713
2714   // The only other support type is pointer.
2715   assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
2716   if (TD) return TD->getIntPtrType(getContext());
2717
2718   // Without TargetData, conservatively assume pointers are 64-bit.
2719   return Type::getInt64Ty(getContext());
2720 }
2721
2722 const SCEV *ScalarEvolution::getCouldNotCompute() {
2723   return &CouldNotCompute;
2724 }
2725
2726 /// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2727 /// expression and create a new one.
2728 const SCEV *ScalarEvolution::getSCEV(Value *V) {
2729   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2730
2731   ValueExprMapType::const_iterator I = ValueExprMap.find(V);
2732   if (I != ValueExprMap.end()) return I->second;
2733   const SCEV *S = createSCEV(V);
2734
2735   // The process of creating a SCEV for V may have caused other SCEVs
2736   // to have been created, so it's necessary to insert the new entry
2737   // from scratch, rather than trying to remember the insert position
2738   // above.
2739   ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2740   return S;
2741 }
2742
2743 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2744 ///
2745 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
2746   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2747     return getConstant(
2748                cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
2749
2750   Type *Ty = V->getType();
2751   Ty = getEffectiveSCEVType(Ty);
2752   return getMulExpr(V,
2753                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
2754 }
2755
2756 /// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2757 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
2758   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2759     return getConstant(
2760                 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
2761
2762   Type *Ty = V->getType();
2763   Ty = getEffectiveSCEVType(Ty);
2764   const SCEV *AllOnes =
2765                    getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
2766   return getMinusSCEV(AllOnes, V);
2767 }
2768
2769 /// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1.
2770 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
2771                                           SCEV::NoWrapFlags Flags) {
2772   assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");
2773
2774   // Fast path: X - X --> 0.
2775   if (LHS == RHS)
2776     return getConstant(LHS->getType(), 0);
2777
2778   // X - Y --> X + -Y
2779   return getAddExpr(LHS, getNegativeSCEV(RHS), Flags);
2780 }
2781
2782 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2783 /// input value to the specified type.  If the type must be extended, it is zero
2784 /// extended.
2785 const SCEV *
2786 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
2787   Type *SrcTy = V->getType();
2788   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2789          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2790          "Cannot truncate or zero extend with non-integer arguments!");
2791   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2792     return V;  // No conversion
2793   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2794     return getTruncateExpr(V, Ty);
2795   return getZeroExtendExpr(V, Ty);
2796 }
2797
2798 /// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2799 /// input value to the specified type.  If the type must be extended, it is sign
2800 /// extended.
2801 const SCEV *
2802 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
2803                                          Type *Ty) {
2804   Type *SrcTy = V->getType();
2805   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2806          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2807          "Cannot truncate or zero extend with non-integer arguments!");
2808   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2809     return V;  // No conversion
2810   if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2811     return getTruncateExpr(V, Ty);
2812   return getSignExtendExpr(V, Ty);
2813 }
2814
2815 /// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2816 /// input value to the specified type.  If the type must be extended, it is zero
2817 /// extended.  The conversion must not be narrowing.
2818 const SCEV *
2819 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
2820   Type *SrcTy = V->getType();
2821   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2822          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2823          "Cannot noop or zero extend with non-integer arguments!");
2824   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2825          "getNoopOrZeroExtend cannot truncate!");
2826   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2827     return V;  // No conversion
2828   return getZeroExtendExpr(V, Ty);
2829 }
2830
2831 /// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2832 /// input value to the specified type.  If the type must be extended, it is sign
2833 /// extended.  The conversion must not be narrowing.
2834 const SCEV *
2835 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
2836   Type *SrcTy = V->getType();
2837   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2838          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2839          "Cannot noop or sign extend with non-integer arguments!");
2840   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2841          "getNoopOrSignExtend cannot truncate!");
2842   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2843     return V;  // No conversion
2844   return getSignExtendExpr(V, Ty);
2845 }
2846
2847 /// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2848 /// the input value to the specified type. If the type must be extended,
2849 /// it is extended with unspecified bits. The conversion must not be
2850 /// narrowing.
2851 const SCEV *
2852 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
2853   Type *SrcTy = V->getType();
2854   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2855          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2856          "Cannot noop or any extend with non-integer arguments!");
2857   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2858          "getNoopOrAnyExtend cannot truncate!");
2859   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2860     return V;  // No conversion
2861   return getAnyExtendExpr(V, Ty);
2862 }
2863
2864 /// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2865 /// input value to the specified type.  The conversion must not be widening.
2866 const SCEV *
2867 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
2868   Type *SrcTy = V->getType();
2869   assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2870          (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2871          "Cannot truncate or noop with non-integer arguments!");
2872   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2873          "getTruncateOrNoop cannot extend!");
2874   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2875     return V;  // No conversion
2876   return getTruncateExpr(V, Ty);
2877 }
2878
2879 /// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2880 /// the types using zero-extension, and then perform a umax operation
2881 /// with them.
2882 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2883                                                         const SCEV *RHS) {
2884   const SCEV *PromotedLHS = LHS;
2885   const SCEV *PromotedRHS = RHS;
2886
2887   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2888     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2889   else
2890     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2891
2892   return getUMaxExpr(PromotedLHS, PromotedRHS);
2893 }
2894
2895 /// getUMinFromMismatchedTypes - Promote the operands to the wider of
2896 /// the types using zero-extension, and then perform a umin operation
2897 /// with them.
2898 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2899                                                         const SCEV *RHS) {
2900   const SCEV *PromotedLHS = LHS;
2901   const SCEV *PromotedRHS = RHS;
2902
2903   if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2904     PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2905   else
2906     PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2907
2908   return getUMinExpr(PromotedLHS, PromotedRHS);
2909 }
2910
2911 /// getPointerBase - Transitively follow the chain of pointer-type operands
2912 /// until reaching a SCEV that does not have a single pointer operand. This
2913 /// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
2914 /// but corner cases do exist.
2915 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
2916   // A pointer operand may evaluate to a nonpointer expression, such as null.
2917   if (!V->getType()->isPointerTy())
2918     return V;
2919
2920   if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
2921     return getPointerBase(Cast->getOperand());
2922   }
2923   else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
2924     const SCEV *PtrOp = 0;
2925     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
2926          I != E; ++I) {
2927       if ((*I)->getType()->isPointerTy()) {
2928         // Cannot find the base of an expression with multiple pointer operands.
2929         if (PtrOp)
2930           return V;
2931         PtrOp = *I;
2932       }
2933     }
2934     if (!PtrOp)
2935       return V;
2936     return getPointerBase(PtrOp);
2937   }
2938   return V;
2939 }
2940
2941 /// PushDefUseChildren - Push users of the given Instruction
2942 /// onto the given Worklist.
2943 static void
2944 PushDefUseChildren(Instruction *I,
2945                    SmallVectorImpl<Instruction *> &Worklist) {
2946   // Push the def-use children onto the Worklist stack.
2947   for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2948        UI != UE; ++UI)
2949     Worklist.push_back(cast<Instruction>(*UI));
2950 }
2951
2952 /// ForgetSymbolicValue - This looks up computed SCEV values for all
2953 /// instructions that depend on the given instruction and removes them from
2954 /// the ValueExprMapType map if they reference SymName. This is used during PHI
2955 /// resolution.
2956 void
2957 ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
2958   SmallVector<Instruction *, 16> Worklist;
2959   PushDefUseChildren(PN, Worklist);
2960
2961   SmallPtrSet<Instruction *, 8> Visited;
2962   Visited.insert(PN);
2963   while (!Worklist.empty()) {
2964     Instruction *I = Worklist.pop_back_val();
2965     if (!Visited.insert(I)) continue;
2966
2967     ValueExprMapType::iterator It =
2968       ValueExprMap.find(static_cast<Value *>(I));
2969     if (It != ValueExprMap.end()) {
2970       const SCEV *Old = It->second;
2971
2972       // Short-circuit the def-use traversal if the symbolic name
2973       // ceases to appear in expressions.
2974       if (Old != SymName && !hasOperand(Old, SymName))
2975         continue;
2976
2977       // SCEVUnknown for a PHI either means that it has an unrecognized
2978       // structure, it's a PHI that's in the progress of being computed
2979       // by createNodeForPHI, or it's a single-value PHI. In the first case,
2980       // additional loop trip count information isn't going to change anything.
2981       // In the second case, createNodeForPHI will perform the necessary
2982       // updates on its own when it gets to that point. In the third, we do
2983       // want to forget the SCEVUnknown.
2984       if (!isa<PHINode>(I) ||
2985           !isa<SCEVUnknown>(Old) ||
2986           (I != PN && Old == SymName)) {
2987         forgetMemoizedResults(Old);
2988         ValueExprMap.erase(It);
2989       }
2990     }
2991
2992     PushDefUseChildren(I, Worklist);
2993   }
2994 }
2995
2996 /// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2997 /// a loop header, making it a potential recurrence, or it doesn't.
2998 ///
2999 const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
3000   if (const Loop *L = LI->getLoopFor(PN->getParent()))
3001     if (L->getHeader() == PN->getParent()) {
3002       // The loop may have multiple entrances or multiple exits; we can analyze
3003       // this phi as an addrec if it has a unique entry value and a unique
3004       // backedge value.
3005       Value *BEValueV = 0, *StartValueV = 0;
3006       for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
3007         Value *V = PN->getIncomingValue(i);
3008         if (L->contains(PN->getIncomingBlock(i))) {
3009           if (!BEValueV) {
3010             BEValueV = V;
3011           } else if (BEValueV != V) {
3012             BEValueV = 0;
3013             break;
3014           }
3015         } else if (!StartValueV) {
3016           StartValueV = V;
3017         } else if (StartValueV != V) {
3018           StartValueV = 0;
3019           break;
3020         }
3021       }
3022       if (BEValueV && StartValueV) {
3023         // While we are analyzing this PHI node, handle its value symbolically.
3024         const SCEV *SymbolicName = getUnknown(PN);
3025         assert(ValueExprMap.find(PN) == ValueExprMap.end() &&
3026                "PHI node already processed?");
3027         ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
3028
3029         // Using this symbolic name for the PHI, analyze the value coming around
3030         // the back-edge.
3031         const SCEV *BEValue = getSCEV(BEValueV);
3032
3033         // NOTE: If BEValue is loop invariant, we know that the PHI node just
3034         // has a special value for the first iteration of the loop.
3035
3036         // If the value coming around the backedge is an add with the symbolic
3037         // value we just inserted, then we found a simple induction variable!
3038         if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
3039           // If there is a single occurrence of the symbolic value, replace it
3040           // with a recurrence.
3041           unsigned FoundIndex = Add->getNumOperands();
3042           for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3043             if (Add->getOperand(i) == SymbolicName)
3044               if (FoundIndex == e) {
3045                 FoundIndex = i;
3046                 break;
3047               }
3048
3049           if (FoundIndex != Add->getNumOperands()) {
3050             // Create an add with everything but the specified operand.
3051             SmallVector<const SCEV *, 8> Ops;
3052             for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
3053               if (i != FoundIndex)
3054                 Ops.push_back(Add->getOperand(i));
3055             const SCEV *Accum = getAddExpr(Ops);
3056
3057             // This is not a valid addrec if the step amount is varying each
3058             // loop iteration, but is not itself an addrec in this loop.
3059             if (isLoopInvariant(Accum, L) ||
3060                 (isa<SCEVAddRecExpr>(Accum) &&
3061                  cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
3062               SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
3063
3064               // If the increment doesn't overflow, then neither the addrec nor
3065               // the post-increment will overflow.
3066               if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
3067                 if (OBO->hasNoUnsignedWrap())
3068                   Flags = setFlags(Flags, SCEV::FlagNUW);
3069                 if (OBO->hasNoSignedWrap())
3070                   Flags = setFlags(Flags, SCEV::FlagNSW);
3071               } else if (const GEPOperator *GEP =
3072                          dyn_cast<GEPOperator>(BEValueV)) {
3073                 // If the increment is an inbounds GEP, then we know the address
3074                 // space cannot be wrapped around. We cannot make any guarantee
3075                 // about signed or unsigned overflow because pointers are
3076                 // unsigned but we may have a negative index from the base
3077                 // pointer.
3078                 if (GEP->isInBounds())
3079                   Flags = setFlags(Flags, SCEV::FlagNW);
3080               }
3081
3082               const SCEV *StartVal = getSCEV(StartValueV);
3083               const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
3084
3085               // Since the no-wrap flags are on the increment, they apply to the
3086               // post-incremented value as well.
3087               if (isLoopInvariant(Accum, L))
3088                 (void)getAddRecExpr(getAddExpr(StartVal, Accum),
3089                                     Accum, L, Flags);
3090
3091               // Okay, for the entire analysis of this edge we assumed the PHI
3092               // to be symbolic.  We now need to go back and purge all of the
3093               // entries for the scalars that use the symbolic expression.
3094               ForgetSymbolicName(PN, SymbolicName);
3095               ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3096               return PHISCEV;
3097             }
3098           }
3099         } else if (const SCEVAddRecExpr *AddRec =
3100                      dyn_cast<SCEVAddRecExpr>(BEValue)) {
3101           // Otherwise, this could be a loop like this:
3102           //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
3103           // In this case, j = {1,+,1}  and BEValue is j.
3104           // Because the other in-value of i (0) fits the evolution of BEValue
3105           // i really is an addrec evolution.
3106           if (AddRec->getLoop() == L && AddRec->isAffine()) {
3107             const SCEV *StartVal = getSCEV(StartValueV);
3108
3109             // If StartVal = j.start - j.stride, we can use StartVal as the
3110             // initial step of the addrec evolution.
3111             if (StartVal == getMinusSCEV(AddRec->getOperand(0),
3112                                          AddRec->getOperand(1))) {
3113               // FIXME: For constant StartVal, we should be able to infer
3114               // no-wrap flags.
3115               const SCEV *PHISCEV =
3116                 getAddRecExpr(StartVal, AddRec->getOperand(1), L,
3117                               SCEV::FlagAnyWrap);
3118
3119               // Okay, for the entire analysis of this edge we assumed the PHI
3120               // to be symbolic.  We now need to go back and purge all of the
3121               // entries for the scalars that use the symbolic expression.
3122               ForgetSymbolicName(PN, SymbolicName);
3123               ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3124               return PHISCEV;
3125             }
3126           }
3127         }
3128       }
3129     }
3130
3131   // If the PHI has a single incoming value, follow that value, unless the
3132   // PHI's incoming blocks are in a different loop, in which case doing so
3133   // risks breaking LCSSA form. Instcombine would normally zap these, but
3134   // it doesn't have DominatorTree information, so it may miss cases.
3135   if (Value *V = SimplifyInstruction(PN, TD, TLI, DT))
3136     if (LI->replacementPreservesLCSSAForm(PN, V))
3137       return getSCEV(V);
3138
3139   // If it's not a loop phi, we can't handle it yet.
3140   return getUnknown(PN);
3141 }
3142
3143 /// createNodeForGEP - Expand GEP instructions into add and multiply
3144 /// operations. This allows them to be analyzed by regular SCEV code.
3145 ///
3146 const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
3147
3148   // Don't blindly transfer the inbounds flag from the GEP instruction to the
3149   // Add expression, because the Instruction may be guarded by control flow
3150   // and the no-overflow bits may not be valid for the expression in any
3151   // context.
3152   bool isInBounds = GEP->isInBounds();
3153
3154   Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
3155   Value *Base = GEP->getOperand(0);
3156   // Don't attempt to analyze GEPs over unsized objects.
3157   if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
3158     return getUnknown(GEP);
3159   const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
3160   gep_type_iterator GTI = gep_type_begin(GEP);
3161   for (GetElementPtrInst::op_iterator I = llvm::next(GEP->op_begin()),
3162                                       E = GEP->op_end();
3163        I != E; ++I) {
3164     Value *Index = *I;
3165     // Compute the (potentially symbolic) offset in bytes for this index.
3166     if (StructType *STy = dyn_cast<StructType>(*GTI++)) {
3167       // For a struct, add the member offset.
3168       unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
3169       const SCEV *FieldOffset = getOffsetOfExpr(STy, FieldNo);
3170
3171       // Add the field offset to the running total offset.
3172       TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3173     } else {
3174       // For an array, add the element offset, explicitly scaled.
3175       const SCEV *ElementSize = getSizeOfExpr(*GTI);
3176       const SCEV *IndexS = getSCEV(Index);
3177       // Getelementptr indices are signed.
3178       IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
3179
3180       // Multiply the index by the element size to compute the element offset.
3181       const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize,
3182                                            isInBounds ? SCEV::FlagNSW :
3183                                            SCEV::FlagAnyWrap);
3184
3185       // Add the element offset to the running total offset.
3186       TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3187     }
3188   }
3189
3190   // Get the SCEV for the GEP base.
3191   const SCEV *BaseS = getSCEV(Base);
3192
3193   // Add the total offset from all the GEP indices to the base.
3194   return getAddExpr(BaseS, TotalOffset,
3195                     isInBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap);
3196 }
3197
3198 /// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
3199 /// guaranteed to end in (at every loop iteration).  It is, at the same time,
3200 /// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
3201 /// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
3202 uint32_t
3203 ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
3204   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3205     return C->getValue()->getValue().countTrailingZeros();
3206
3207   if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
3208     return std::min(GetMinTrailingZeros(T->getOperand()),
3209                     (uint32_t)getTypeSizeInBits(T->getType()));
3210
3211   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
3212     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3213     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3214              getTypeSizeInBits(E->getType()) : OpRes;
3215   }
3216
3217   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
3218     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3219     return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3220              getTypeSizeInBits(E->getType()) : OpRes;
3221   }
3222
3223   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
3224     // The result is the min of all operands results.
3225     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3226     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3227       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3228     return MinOpRes;
3229   }
3230
3231   if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
3232     // The result is the sum of all operands results.
3233     uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
3234     uint32_t BitWidth = getTypeSizeInBits(M->getType());
3235     for (unsigned i = 1, e = M->getNumOperands();
3236          SumOpRes != BitWidth && i != e; ++i)
3237       SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
3238                           BitWidth);
3239     return SumOpRes;
3240   }
3241
3242   if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
3243     // The result is the min of all operands results.
3244     uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3245     for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3246       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3247     return MinOpRes;
3248   }
3249
3250   if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
3251     // The result is the min of all operands results.
3252     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3253     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3254       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3255     return MinOpRes;
3256   }
3257
3258   if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
3259     // The result is the min of all operands results.
3260     uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3261     for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3262       MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3263     return MinOpRes;
3264   }
3265
3266   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3267     // For a SCEVUnknown, ask ValueTracking.
3268     unsigned BitWidth = getTypeSizeInBits(U->getType());
3269     APInt Mask = APInt::getAllOnesValue(BitWidth);
3270     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3271     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
3272     return Zeros.countTrailingOnes();
3273   }
3274
3275   // SCEVUDivExpr
3276   return 0;
3277 }
3278
3279 /// getUnsignedRange - Determine the unsigned range for a particular SCEV.
3280 ///
3281 ConstantRange
3282 ScalarEvolution::getUnsignedRange(const SCEV *S) {
3283   // See if we've computed this range already.
3284   DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S);
3285   if (I != UnsignedRanges.end())
3286     return I->second;
3287
3288   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3289     return setUnsignedRange(C, ConstantRange(C->getValue()->getValue()));
3290
3291   unsigned BitWidth = getTypeSizeInBits(S->getType());
3292   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3293
3294   // If the value has known zeros, the maximum unsigned value will have those
3295   // known zeros as well.
3296   uint32_t TZ = GetMinTrailingZeros(S);
3297   if (TZ != 0)
3298     ConservativeResult =
3299       ConstantRange(APInt::getMinValue(BitWidth),
3300                     APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
3301
3302   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3303     ConstantRange X = getUnsignedRange(Add->getOperand(0));
3304     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3305       X = X.add(getUnsignedRange(Add->getOperand(i)));
3306     return setUnsignedRange(Add, ConservativeResult.intersectWith(X));
3307   }
3308
3309   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3310     ConstantRange X = getUnsignedRange(Mul->getOperand(0));
3311     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3312       X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
3313     return setUnsignedRange(Mul, ConservativeResult.intersectWith(X));
3314   }
3315
3316   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3317     ConstantRange X = getUnsignedRange(SMax->getOperand(0));
3318     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3319       X = X.smax(getUnsignedRange(SMax->getOperand(i)));
3320     return setUnsignedRange(SMax, ConservativeResult.intersectWith(X));
3321   }
3322
3323   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3324     ConstantRange X = getUnsignedRange(UMax->getOperand(0));
3325     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3326       X = X.umax(getUnsignedRange(UMax->getOperand(i)));
3327     return setUnsignedRange(UMax, ConservativeResult.intersectWith(X));
3328   }
3329
3330   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3331     ConstantRange X = getUnsignedRange(UDiv->getLHS());
3332     ConstantRange Y = getUnsignedRange(UDiv->getRHS());
3333     return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3334   }
3335
3336   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3337     ConstantRange X = getUnsignedRange(ZExt->getOperand());
3338     return setUnsignedRange(ZExt,
3339       ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3340   }
3341
3342   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3343     ConstantRange X = getUnsignedRange(SExt->getOperand());
3344     return setUnsignedRange(SExt,
3345       ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3346   }
3347
3348   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3349     ConstantRange X = getUnsignedRange(Trunc->getOperand());
3350     return setUnsignedRange(Trunc,
3351       ConservativeResult.intersectWith(X.truncate(BitWidth)));
3352   }
3353
3354   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3355     // If there's no unsigned wrap, the value will never be less than its
3356     // initial value.
3357     if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
3358       if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
3359         if (!C->getValue()->isZero())
3360           ConservativeResult =
3361             ConservativeResult.intersectWith(
3362               ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
3363
3364     // TODO: non-affine addrec
3365     if (AddRec->isAffine()) {
3366       Type *Ty = AddRec->getType();
3367       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3368       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3369           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3370         MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3371
3372         const SCEV *Start = AddRec->getStart();
3373         const SCEV *Step = AddRec->getStepRecurrence(*this);
3374
3375         ConstantRange StartRange = getUnsignedRange(Start);
3376         ConstantRange StepRange = getSignedRange(Step);
3377         ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3378         ConstantRange EndRange =
3379           StartRange.add(MaxBECountRange.multiply(StepRange));
3380
3381         // Check for overflow. This must be done with ConstantRange arithmetic
3382         // because we could be called from within the ScalarEvolution overflow
3383         // checking code.
3384         ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1);
3385         ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3386         ConstantRange ExtMaxBECountRange =
3387           MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3388         ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1);
3389         if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3390             ExtEndRange)
3391           return setUnsignedRange(AddRec, ConservativeResult);
3392
3393         APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
3394                                    EndRange.getUnsignedMin());
3395         APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
3396                                    EndRange.getUnsignedMax());
3397         if (Min.isMinValue() && Max.isMaxValue())
3398           return setUnsignedRange(AddRec, ConservativeResult);
3399         return setUnsignedRange(AddRec,
3400           ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3401       }
3402     }
3403
3404     return setUnsignedRange(AddRec, ConservativeResult);
3405   }
3406
3407   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3408     // For a SCEVUnknown, ask ValueTracking.
3409     APInt Mask = APInt::getAllOnesValue(BitWidth);
3410     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3411     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
3412     if (Ones == ~Zeros + 1)
3413       return setUnsignedRange(U, ConservativeResult);
3414     return setUnsignedRange(U,
3415       ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)));
3416   }
3417
3418   return setUnsignedRange(S, ConservativeResult);
3419 }
3420
3421 /// getSignedRange - Determine the signed range for a particular SCEV.
3422 ///
3423 ConstantRange
3424 ScalarEvolution::getSignedRange(const SCEV *S) {
3425   // See if we've computed this range already.
3426   DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
3427   if (I != SignedRanges.end())
3428     return I->second;
3429
3430   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3431     return setSignedRange(C, ConstantRange(C->getValue()->getValue()));
3432
3433   unsigned BitWidth = getTypeSizeInBits(S->getType());
3434   ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3435
3436   // If the value has known zeros, the maximum signed value will have those
3437   // known zeros as well.
3438   uint32_t TZ = GetMinTrailingZeros(S);
3439   if (TZ != 0)
3440     ConservativeResult =
3441       ConstantRange(APInt::getSignedMinValue(BitWidth),
3442                     APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
3443
3444   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3445     ConstantRange X = getSignedRange(Add->getOperand(0));
3446     for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3447       X = X.add(getSignedRange(Add->getOperand(i)));
3448     return setSignedRange(Add, ConservativeResult.intersectWith(X));
3449   }
3450
3451   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3452     ConstantRange X = getSignedRange(Mul->getOperand(0));
3453     for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3454       X = X.multiply(getSignedRange(Mul->getOperand(i)));
3455     return setSignedRange(Mul, ConservativeResult.intersectWith(X));
3456   }
3457
3458   if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3459     ConstantRange X = getSignedRange(SMax->getOperand(0));
3460     for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3461       X = X.smax(getSignedRange(SMax->getOperand(i)));
3462     return setSignedRange(SMax, ConservativeResult.intersectWith(X));
3463   }
3464
3465   if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3466     ConstantRange X = getSignedRange(UMax->getOperand(0));
3467     for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3468       X = X.umax(getSignedRange(UMax->getOperand(i)));
3469     return setSignedRange(UMax, ConservativeResult.intersectWith(X));
3470   }
3471
3472   if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3473     ConstantRange X = getSignedRange(UDiv->getLHS());
3474     ConstantRange Y = getSignedRange(UDiv->getRHS());
3475     return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3476   }
3477
3478   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3479     ConstantRange X = getSignedRange(ZExt->getOperand());
3480     return setSignedRange(ZExt,
3481       ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3482   }
3483
3484   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3485     ConstantRange X = getSignedRange(SExt->getOperand());
3486     return setSignedRange(SExt,
3487       ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3488   }
3489
3490   if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3491     ConstantRange X = getSignedRange(Trunc->getOperand());
3492     return setSignedRange(Trunc,
3493       ConservativeResult.intersectWith(X.truncate(BitWidth)));
3494   }
3495
3496   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3497     // If there's no signed wrap, and all the operands have the same sign or
3498     // zero, the value won't ever change sign.
3499     if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
3500       bool AllNonNeg = true;
3501       bool AllNonPos = true;
3502       for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3503         if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3504         if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3505       }
3506       if (AllNonNeg)
3507         ConservativeResult = ConservativeResult.intersectWith(
3508           ConstantRange(APInt(BitWidth, 0),
3509                         APInt::getSignedMinValue(BitWidth)));
3510       else if (AllNonPos)
3511         ConservativeResult = ConservativeResult.intersectWith(
3512           ConstantRange(APInt::getSignedMinValue(BitWidth),
3513                         APInt(BitWidth, 1)));
3514     }
3515
3516     // TODO: non-affine addrec
3517     if (AddRec->isAffine()) {
3518       Type *Ty = AddRec->getType();
3519       const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3520       if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3521           getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3522         MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3523
3524         const SCEV *Start = AddRec->getStart();
3525         const SCEV *Step = AddRec->getStepRecurrence(*this);
3526
3527         ConstantRange StartRange = getSignedRange(Start);
3528         ConstantRange StepRange = getSignedRange(Step);
3529         ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3530         ConstantRange EndRange =
3531           StartRange.add(MaxBECountRange.multiply(StepRange));
3532
3533         // Check for overflow. This must be done with ConstantRange arithmetic
3534         // because we could be called from within the ScalarEvolution overflow
3535         // checking code.
3536         ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1);
3537         ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3538         ConstantRange ExtMaxBECountRange =
3539           MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3540         ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1);
3541         if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3542             ExtEndRange)
3543           return setSignedRange(AddRec, ConservativeResult);
3544
3545         APInt Min = APIntOps::smin(StartRange.getSignedMin(),
3546                                    EndRange.getSignedMin());
3547         APInt Max = APIntOps::smax(StartRange.getSignedMax(),
3548                                    EndRange.getSignedMax());
3549         if (Min.isMinSignedValue() && Max.isMaxSignedValue())
3550           return setSignedRange(AddRec, ConservativeResult);
3551         return setSignedRange(AddRec,
3552           ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3553       }
3554     }
3555
3556     return setSignedRange(AddRec, ConservativeResult);
3557   }
3558
3559   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3560     // For a SCEVUnknown, ask ValueTracking.
3561     if (!U->getValue()->getType()->isIntegerTy() && !TD)
3562       return setSignedRange(U, ConservativeResult);
3563     unsigned NS = ComputeNumSignBits(U->getValue(), TD);
3564     if (NS == 1)
3565       return setSignedRange(U, ConservativeResult);
3566     return setSignedRange(U, ConservativeResult.intersectWith(
3567       ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
3568                     APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)));
3569   }
3570
3571   return setSignedRange(S, ConservativeResult);
3572 }
3573
3574 /// createSCEV - We know that there is no SCEV for the specified value.
3575 /// Analyze the expression.
3576 ///
3577 const SCEV *ScalarEvolution::createSCEV(Value *V) {
3578   if (!isSCEVable(V->getType()))
3579     return getUnknown(V);
3580
3581   unsigned Opcode = Instruction::UserOp1;
3582   if (Instruction *I = dyn_cast<Instruction>(V)) {
3583     Opcode = I->getOpcode();
3584
3585     // Don't attempt to analyze instructions in blocks that aren't
3586     // reachable. Such instructions don't matter, and they aren't required
3587     // to obey basic rules for definitions dominating uses which this
3588     // analysis depends on.
3589     if (!DT->isReachableFromEntry(I->getParent()))
3590       return getUnknown(V);
3591   } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
3592     Opcode = CE->getOpcode();
3593   else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
3594     return getConstant(CI);
3595   else if (isa<ConstantPointerNull>(V))
3596     return getConstant(V->getType(), 0);
3597   else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
3598     return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
3599   else
3600     return getUnknown(V);
3601
3602   Operator *U = cast<Operator>(V);
3603   switch (Opcode) {
3604   case Instruction::Add: {
3605     // The simple thing to do would be to just call getSCEV on both operands
3606     // and call getAddExpr with the result. However if we're looking at a
3607     // bunch of things all added together, this can be quite inefficient,
3608     // because it leads to N-1 getAddExpr calls for N ultimate operands.
3609     // Instead, gather up all the operands and make a single getAddExpr call.
3610     // LLVM IR canonical form means we need only traverse the left operands.
3611     //
3612     // Don't apply this instruction's NSW or NUW flags to the new
3613     // expression. The instruction may be guarded by control flow that the
3614     // no-wrap behavior depends on. Non-control-equivalent instructions can be
3615     // mapped to the same SCEV expression, and it would be incorrect to transfer
3616     // NSW/NUW semantics to those operations.
3617     SmallVector<const SCEV *, 4> AddOps;
3618     AddOps.push_back(getSCEV(U->getOperand(1)));
3619     for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) {
3620       unsigned Opcode = Op->getValueID() - Value::InstructionVal;
3621       if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
3622         break;
3623       U = cast<Operator>(Op);
3624       const SCEV *Op1 = getSCEV(U->getOperand(1));
3625       if (Opcode == Instruction::Sub)
3626         AddOps.push_back(getNegativeSCEV(Op1));
3627       else
3628         AddOps.push_back(Op1);
3629     }
3630     AddOps.push_back(getSCEV(U->getOperand(0)));
3631     return getAddExpr(AddOps);
3632   }
3633   case Instruction::Mul: {
3634     // Don't transfer NSW/NUW for the same reason as AddExpr.
3635     SmallVector<const SCEV *, 4> MulOps;
3636     MulOps.push_back(getSCEV(U->getOperand(1)));
3637     for (Value *Op = U->getOperand(0);
3638          Op->getValueID() == Instruction::Mul + Value::InstructionVal;
3639          Op = U->getOperand(0)) {
3640       U = cast<Operator>(Op);
3641       MulOps.push_back(getSCEV(U->getOperand(1)));
3642     }
3643     MulOps.push_back(getSCEV(U->getOperand(0)));
3644     return getMulExpr(MulOps);
3645   }
3646   case Instruction::UDiv:
3647     return getUDivExpr(getSCEV(U->getOperand(0)),
3648                        getSCEV(U->getOperand(1)));
3649   case Instruction::Sub:
3650     return getMinusSCEV(getSCEV(U->getOperand(0)),
3651                         getSCEV(U->getOperand(1)));
3652   case Instruction::And:
3653     // For an expression like x&255 that merely masks off the high bits,
3654     // use zext(trunc(x)) as the SCEV expression.
3655     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3656       if (CI->isNullValue())
3657         return getSCEV(U->getOperand(1));
3658       if (CI->isAllOnesValue())
3659         return getSCEV(U->getOperand(0));
3660       const APInt &A = CI->getValue();
3661
3662       // Instcombine's ShrinkDemandedConstant may strip bits out of
3663       // constants, obscuring what would otherwise be a low-bits mask.
3664       // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
3665       // knew about to reconstruct a low-bits mask value.
3666       unsigned LZ = A.countLeadingZeros();
3667       unsigned BitWidth = A.getBitWidth();
3668       APInt AllOnes = APInt::getAllOnesValue(BitWidth);
3669       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
3670       ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
3671
3672       APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
3673
3674       if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
3675         return
3676           getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
3677                                 IntegerType::get(getContext(), BitWidth - LZ)),
3678                             U->getType());
3679     }
3680     break;
3681
3682   case Instruction::Or:
3683     // If the RHS of the Or is a constant, we may have something like:
3684     // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
3685     // optimizations will transparently handle this case.
3686     //
3687     // In order for this transformation to be safe, the LHS must be of the
3688     // form X*(2^n) and the Or constant must be less than 2^n.
3689     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3690       const SCEV *LHS = getSCEV(U->getOperand(0));
3691       const APInt &CIVal = CI->getValue();
3692       if (GetMinTrailingZeros(LHS) >=
3693           (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
3694         // Build a plain add SCEV.
3695         const SCEV *S = getAddExpr(LHS, getSCEV(CI));
3696         // If the LHS of the add was an addrec and it has no-wrap flags,
3697         // transfer the no-wrap flags, since an or won't introduce a wrap.
3698         if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
3699           const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
3700           const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
3701             OldAR->getNoWrapFlags());
3702         }
3703         return S;
3704       }
3705     }
3706     break;
3707   case Instruction::Xor:
3708     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3709       // If the RHS of the xor is a signbit, then this is just an add.
3710       // Instcombine turns add of signbit into xor as a strength reduction step.
3711       if (CI->getValue().isSignBit())
3712         return getAddExpr(getSCEV(U->getOperand(0)),
3713                           getSCEV(U->getOperand(1)));
3714
3715       // If the RHS of xor is -1, then this is a not operation.
3716       if (CI->isAllOnesValue())
3717         return getNotSCEV(getSCEV(U->getOperand(0)));
3718
3719       // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
3720       // This is a variant of the check for xor with -1, and it handles
3721       // the case where instcombine has trimmed non-demanded bits out
3722       // of an xor with -1.
3723       if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
3724         if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
3725           if (BO->getOpcode() == Instruction::And &&
3726               LCI->getValue() == CI->getValue())
3727             if (const SCEVZeroExtendExpr *Z =
3728                   dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
3729               Type *UTy = U->getType();
3730               const SCEV *Z0 = Z->getOperand();
3731               Type *Z0Ty = Z0->getType();
3732               unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
3733
3734               // If C is a low-bits mask, the zero extend is serving to
3735               // mask off the high bits. Complement the operand and
3736               // re-apply the zext.
3737               if (APIntOps::isMask(Z0TySize, CI->getValue()))
3738                 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
3739
3740               // If C is a single bit, it may be in the sign-bit position
3741               // before the zero-extend. In this case, represent the xor
3742               // using an add, which is equivalent, and re-apply the zext.
3743               APInt Trunc = CI->getValue().trunc(Z0TySize);
3744               if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
3745                   Trunc.isSignBit())
3746                 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
3747                                          UTy);
3748             }
3749     }
3750     break;
3751
3752   case Instruction::Shl:
3753     // Turn shift left of a constant amount into a multiply.
3754     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3755       uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3756
3757       // If the shift count is not less than the bitwidth, the result of
3758       // the shift is undefined. Don't try to analyze it, because the
3759       // resolution chosen here may differ from the resolution chosen in
3760       // other parts of the compiler.
3761       if (SA->getValue().uge(BitWidth))
3762         break;
3763
3764       Constant *X = ConstantInt::get(getContext(),
3765         APInt(BitWidth, 1).shl(SA->getZExtValue()));
3766       return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3767     }
3768     break;
3769
3770   case Instruction::LShr:
3771     // Turn logical shift right of a constant into a unsigned divide.
3772     if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3773       uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3774
3775       // If the shift count is not less than the bitwidth, the result of
3776       // the shift is undefined. Don't try to analyze it, because the
3777       // resolution chosen here may differ from the resolution chosen in
3778       // other parts of the compiler.
3779       if (SA->getValue().uge(BitWidth))
3780         break;
3781
3782       Constant *X = ConstantInt::get(getContext(),
3783         APInt(BitWidth, 1).shl(SA->getZExtValue()));
3784       return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3785     }
3786     break;
3787
3788   case Instruction::AShr:
3789     // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
3790     if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
3791       if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
3792         if (L->getOpcode() == Instruction::Shl &&
3793             L->getOperand(1) == U->getOperand(1)) {
3794           uint64_t BitWidth = getTypeSizeInBits(U->getType());
3795
3796           // If the shift count is not less than the bitwidth, the result of
3797           // the shift is undefined. Don't try to analyze it, because the
3798           // resolution chosen here may differ from the resolution chosen in
3799           // other parts of the compiler.
3800           if (CI->getValue().uge(BitWidth))
3801             break;
3802
3803           uint64_t Amt = BitWidth - CI->getZExtValue();
3804           if (Amt == BitWidth)
3805             return getSCEV(L->getOperand(0));       // shift by zero --> noop
3806           return
3807             getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
3808                                               IntegerType::get(getContext(),
3809                                                                Amt)),
3810                               U->getType());
3811         }
3812     break;
3813
3814   case Instruction::Trunc:
3815     return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
3816
3817   case Instruction::ZExt:
3818     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3819
3820   case Instruction::SExt:
3821     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3822
3823   case Instruction::BitCast:
3824     // BitCasts are no-op casts so we just eliminate the cast.
3825     if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
3826       return getSCEV(U->getOperand(0));
3827     break;
3828
3829   // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
3830   // lead to pointer expressions which cannot safely be expanded to GEPs,
3831   // because ScalarEvolution doesn't respect the GEP aliasing rules when
3832   // simplifying integer expressions.
3833
3834   case Instruction::GetElementPtr:
3835     return createNodeForGEP(cast<GEPOperator>(U));
3836
3837   case Instruction::PHI:
3838     return createNodeForPHI(cast<PHINode>(U));
3839
3840   case Instruction::Select:
3841     // This could be a smax or umax that was lowered earlier.
3842     // Try to recover it.
3843     if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
3844       Value *LHS = ICI->getOperand(0);
3845       Value *RHS = ICI->getOperand(1);
3846       switch (ICI->getPredicate()) {
3847       case ICmpInst::ICMP_SLT:
3848       case ICmpInst::ICMP_SLE:
3849         std::swap(LHS, RHS);
3850         // fall through
3851       case ICmpInst::ICMP_SGT:
3852       case ICmpInst::ICMP_SGE:
3853         // a >s b ? a+x : b+x  ->  smax(a, b)+x
3854         // a >s b ? b+x : a+x  ->  smin(a, b)+x
3855         if (LHS->getType() == U->getType()) {
3856           const SCEV *LS = getSCEV(LHS);
3857           const SCEV *RS = getSCEV(RHS);
3858           const SCEV *LA = getSCEV(U->getOperand(1));
3859           const SCEV *RA = getSCEV(U->getOperand(2));
3860           const SCEV *LDiff = getMinusSCEV(LA, LS);
3861           const SCEV *RDiff = getMinusSCEV(RA, RS);
3862           if (LDiff == RDiff)
3863             return getAddExpr(getSMaxExpr(LS, RS), LDiff);
3864           LDiff = getMinusSCEV(LA, RS);
3865           RDiff = getMinusSCEV(RA, LS);
3866           if (LDiff == RDiff)
3867             return getAddExpr(getSMinExpr(LS, RS), LDiff);
3868         }
3869         break;
3870       case ICmpInst::ICMP_ULT:
3871       case ICmpInst::ICMP_ULE:
3872         std::swap(LHS, RHS);
3873         // fall through
3874       case ICmpInst::ICMP_UGT:
3875       case ICmpInst::ICMP_UGE:
3876         // a >u b ? a+x : b+x  ->  umax(a, b)+x
3877         // a >u b ? b+x : a+x  ->  umin(a, b)+x
3878         if (LHS->getType() == U->getType()) {
3879           const SCEV *LS = getSCEV(LHS);
3880           const SCEV *RS = getSCEV(RHS);
3881           const SCEV *LA = getSCEV(U->getOperand(1));
3882           const SCEV *RA = getSCEV(U->getOperand(2));
3883           const SCEV *LDiff = getMinusSCEV(LA, LS);
3884           const SCEV *RDiff = getMinusSCEV(RA, RS);
3885           if (LDiff == RDiff)
3886             return getAddExpr(getUMaxExpr(LS, RS), LDiff);
3887           LDiff = getMinusSCEV(LA, RS);
3888           RDiff = getMinusSCEV(RA, LS);
3889           if (LDiff == RDiff)
3890             return getAddExpr(getUMinExpr(LS, RS), LDiff);
3891         }
3892         break;
3893       case ICmpInst::ICMP_NE:
3894         // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
3895         if (LHS->getType() == U->getType() &&
3896             isa<ConstantInt>(RHS) &&
3897             cast<ConstantInt>(RHS)->isZero()) {
3898           const SCEV *One = getConstant(LHS->getType(), 1);
3899           const SCEV *LS = getSCEV(LHS);
3900           const SCEV *LA = getSCEV(U->getOperand(1));
3901           const SCEV *RA = getSCEV(U->getOperand(2));
3902           const SCEV *LDiff = getMinusSCEV(LA, LS);
3903           const SCEV *RDiff = getMinusSCEV(RA, One);
3904           if (LDiff == RDiff)
3905             return getAddExpr(getUMaxExpr(One, LS), LDiff);
3906         }
3907         break;
3908       case ICmpInst::ICMP_EQ:
3909         // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
3910         if (LHS->getType() == U->getType() &&
3911             isa<ConstantInt>(RHS) &&
3912             cast<ConstantInt>(RHS)->isZero()) {
3913           const SCEV *One = getConstant(LHS->getType(), 1);
3914           const SCEV *LS = getSCEV(LHS);
3915           const SCEV *LA = getSCEV(U->getOperand(1));
3916           const SCEV *RA = getSCEV(U->getOperand(2));
3917           const SCEV *LDiff = getMinusSCEV(LA, One);
3918           const SCEV *RDiff = getMinusSCEV(RA, LS);
3919           if (LDiff == RDiff)
3920             return getAddExpr(getUMaxExpr(One, LS), LDiff);
3921         }
3922         break;
3923       default:
3924         break;
3925       }
3926     }
3927
3928   default: // We cannot analyze this expression.
3929     break;
3930   }
3931
3932   return getUnknown(V);
3933 }
3934
3935
3936
3937 //===----------------------------------------------------------------------===//
3938 //                   Iteration Count Computation Code
3939 //
3940
3941 /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
3942 /// normal unsigned value, if possible. Returns 0 if the trip count is unknown
3943 /// or not constant. Will also return 0 if the maximum trip count is very large
3944 /// (>= 2^32)
3945 unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,
3946                                                     BasicBlock *ExitBlock) {
3947   const SCEVConstant *ExitCount =
3948     dyn_cast<SCEVConstant>(getExitCount(L, ExitBlock));
3949   if (!ExitCount)
3950     return 0;
3951
3952   ConstantInt *ExitConst = ExitCount->getValue();
3953
3954   // Guard against huge trip counts.
3955   if (ExitConst->getValue().getActiveBits() > 32)
3956     return 0;
3957
3958   // In case of integer overflow, this returns 0, which is correct.
3959   return ((unsigned)ExitConst->getZExtValue()) + 1;
3960 }
3961
3962 /// getSmallConstantTripMultiple - Returns the largest constant divisor of the
3963 /// trip count of this loop as a normal unsigned value, if possible. This
3964 /// means that the actual trip count is always a multiple of the returned
3965 /// value (don't forget the trip count could very well be zero as well!).
3966 ///
3967 /// Returns 1 if the trip count is unknown or not guaranteed to be the
3968 /// multiple of a constant (which is also the case if the trip count is simply
3969 /// constant, use getSmallConstantTripCount for that case), Will also return 1
3970 /// if the trip count is very large (>= 2^32).
3971 unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L,
3972                                                        BasicBlock *ExitBlock) {
3973   const SCEV *ExitCount = getExitCount(L, ExitBlock);
3974   if (ExitCount == getCouldNotCompute())
3975     return 1;
3976
3977   // Get the trip count from the BE count by adding 1.
3978   const SCEV *TCMul = getAddExpr(ExitCount,
3979                                  getConstant(ExitCount->getType(), 1));
3980   // FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
3981   // to factor simple cases.
3982   if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
3983     TCMul = Mul->getOperand(0);
3984
3985   const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
3986   if (!MulC)
3987     return 1;
3988
3989   ConstantInt *Result = MulC->getValue();
3990
3991   // Guard against huge trip counts.
3992   if (!Result || Result->getValue().getActiveBits() > 32)
3993     return 1;
3994
3995   return (unsigned)Result->getZExtValue();
3996 }
3997
3998 // getExitCount - Get the expression for the number of loop iterations for which
3999 // this loop is guaranteed not to exit via ExitintBlock. Otherwise return
4000 // SCEVCouldNotCompute.
4001 const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) {
4002   return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
4003 }
4004
4005 /// getBackedgeTakenCount - If the specified loop has a predictable
4006 /// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
4007 /// object. The backedge-taken count is the number of times the loop header
4008 /// will be branched to from within the loop. This is one less than the
4009 /// trip count of the loop, since it doesn't count the first iteration,
4010 /// when the header is branched to from outside the loop.
4011 ///
4012 /// Note that it is not valid to call this method on a loop without a
4013 /// loop-invariant backedge-taken count (see
4014 /// hasLoopInvariantBackedgeTakenCount).
4015 ///
4016 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
4017   return getBackedgeTakenInfo(L).getExact(this);
4018 }
4019
4020 /// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
4021 /// return the least SCEV value that is known never to be less than the
4022 /// actual backedge taken count.
4023 const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
4024   return getBackedgeTakenInfo(L).getMax(this);
4025 }
4026
4027 /// PushLoopPHIs - Push PHI nodes in the header of the given loop
4028 /// onto the given Worklist.
4029 static void
4030 PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
4031   BasicBlock *Header = L->getHeader();
4032
4033   // Push all Loop-header PHIs onto the Worklist stack.
4034   for (BasicBlock::iterator I = Header->begin();
4035        PHINode *PN = dyn_cast<PHINode>(I); ++I)
4036     Worklist.push_back(PN);
4037 }
4038
4039 const ScalarEvolution::BackedgeTakenInfo &
4040 ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
4041   // Initially insert an invalid entry for this loop. If the insertion
4042   // succeeds, proceed to actually compute a backedge-taken count and
4043   // update the value. The temporary CouldNotCompute value tells SCEV
4044   // code elsewhere that it shouldn't attempt to request a new
4045   // backedge-taken count, which could result in infinite recursion.
4046   std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
4047     BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo()));
4048   if (!Pair.second)
4049     return Pair.first->second;
4050
4051   // ComputeBackedgeTakenCount may allocate memory for its result. Inserting it
4052   // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
4053   // must be cleared in this scope.
4054   BackedgeTakenInfo Result = ComputeBackedgeTakenCount(L);
4055
4056   if (Result.getExact(this) != getCouldNotCompute()) {
4057     assert(isLoopInvariant(Result.getExact(this), L) &&
4058            isLoopInvariant(Result.getMax(this), L) &&
4059            "Computed backedge-taken count isn't loop invariant for loop!");
4060     ++NumTripCountsComputed;
4061   }
4062   else if (Result.getMax(this) == getCouldNotCompute() &&
4063            isa<PHINode>(L->getHeader()->begin())) {
4064     // Only count loops that have phi nodes as not being computable.
4065     ++NumTripCountsNotComputed;
4066   }
4067
4068   // Now that we know more about the trip count for this loop, forget any
4069   // existing SCEV values for PHI nodes in this loop since they are only
4070   // conservative estimates made without the benefit of trip count
4071   // information. This is similar to the code in forgetLoop, except that
4072   // it handles SCEVUnknown PHI nodes specially.
4073   if (Result.hasAnyInfo()) {
4074     SmallVector<Instruction *, 16> Worklist;
4075     PushLoopPHIs(L, Worklist);
4076
4077     SmallPtrSet<Instruction *, 8> Visited;
4078     while (!Worklist.empty()) {
4079       Instruction *I = Worklist.pop_back_val();
4080       if (!Visited.insert(I)) continue;
4081
4082       ValueExprMapType::iterator It =
4083         ValueExprMap.find(static_cast<Value *>(I));
4084       if (It != ValueExprMap.end()) {
4085         const SCEV *Old = It->second;
4086
4087         // SCEVUnknown for a PHI either means that it has an unrecognized
4088         // structure, or it's a PHI that's in the progress of being computed
4089         // by createNodeForPHI.  In the former case, additional loop trip
4090         // count information isn't going to change anything. In the later
4091         // case, createNodeForPHI will perform the necessary updates on its
4092         // own when it gets to that point.
4093         if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
4094           forgetMemoizedResults(Old);
4095           ValueExprMap.erase(It);
4096         }
4097         if (PHINode *PN = dyn_cast<PHINode>(I))
4098           ConstantEvolutionLoopExitValue.erase(PN);
4099       }
4100
4101       PushDefUseChildren(I, Worklist);
4102     }
4103   }
4104
4105   // Re-lookup the insert position, since the call to
4106   // ComputeBackedgeTakenCount above could result in a
4107   // recusive call to getBackedgeTakenInfo (on a different
4108   // loop), which would invalidate the iterator computed
4109   // earlier.
4110   return BackedgeTakenCounts.find(L)->second = Result;
4111 }
4112
4113 /// forgetLoop - This method should be called by the client when it has
4114 /// changed a loop in a way that may effect ScalarEvolution's ability to
4115 /// compute a trip count, or if the loop is deleted.
4116 void ScalarEvolution::forgetLoop(const Loop *L) {
4117   // Drop any stored trip count value.
4118   DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos =
4119     BackedgeTakenCounts.find(L);
4120   if (BTCPos != BackedgeTakenCounts.end()) {
4121     BTCPos->second.clear();
4122     BackedgeTakenCounts.erase(BTCPos);
4123   }
4124
4125   // Drop information about expressions based on loop-header PHIs.
4126   SmallVector<Instruction *, 16> Worklist;
4127   PushLoopPHIs(L, Worklist);
4128
4129   SmallPtrSet<Instruction *, 8> Visited;
4130   while (!Worklist.empty()) {
4131     Instruction *I = Worklist.pop_back_val();
4132     if (!Visited.insert(I)) continue;
4133
4134     ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
4135     if (It != ValueExprMap.end()) {
4136       forgetMemoizedResults(It->second);
4137       ValueExprMap.erase(It);
4138       if (PHINode *PN = dyn_cast<PHINode>(I))
4139         ConstantEvolutionLoopExitValue.erase(PN);
4140     }
4141
4142     PushDefUseChildren(I, Worklist);
4143   }
4144
4145   // Forget all contained loops too, to avoid dangling entries in the
4146   // ValuesAtScopes map.
4147   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
4148     forgetLoop(*I);
4149 }
4150
4151 /// forgetValue - This method should be called by the client when it has
4152 /// changed a value in a way that may effect its value, or which may
4153 /// disconnect it from a def-use chain linking it to a loop.
4154 void ScalarEvolution::forgetValue(Value *V) {
4155   Instruction *I = dyn_cast<Instruction>(V);
4156   if (!I) return;
4157
4158   // Drop information about expressions based on loop-header PHIs.
4159   SmallVector<Instruction *, 16> Worklist;
4160   Worklist.push_back(I);
4161
4162   SmallPtrSet<Instruction *, 8> Visited;
4163   while (!Worklist.empty()) {
4164     I = Worklist.pop_back_val();
4165     if (!Visited.insert(I)) continue;
4166
4167     ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
4168     if (It != ValueExprMap.end()) {
4169       forgetMemoizedResults(It->second);
4170       ValueExprMap.erase(It);
4171       if (PHINode *PN = dyn_cast<PHINode>(I))
4172         ConstantEvolutionLoopExitValue.erase(PN);
4173     }
4174
4175     PushDefUseChildren(I, Worklist);
4176   }
4177 }
4178
4179 /// getExact - Get the exact loop backedge taken count considering all loop
4180 /// exits. A computable result can only be return for loops with a single exit.
4181 /// Returning the minimum taken count among all exits is incorrect because one
4182 /// of the loop's exit limit's may have been skipped. HowFarToZero assumes that
4183 /// the limit of each loop test is never skipped. This is a valid assumption as
4184 /// long as the loop exits via that test. For precise results, it is the
4185 /// caller's responsibility to specify the relevant loop exit using
4186 /// getExact(ExitingBlock, SE).
4187 const SCEV *
4188 ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
4189   // If any exits were not computable, the loop is not computable.
4190   if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute();
4191
4192   // We need exactly one computable exit.
4193   if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute();
4194   assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info");
4195
4196   const SCEV *BECount = 0;
4197   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4198        ENT != 0; ENT = ENT->getNextExit()) {
4199
4200     assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
4201
4202     if (!BECount)
4203       BECount = ENT->ExactNotTaken;
4204     else if (BECount != ENT->ExactNotTaken)
4205       return SE->getCouldNotCompute();
4206   }
4207   assert(BECount && "Invalid not taken count for loop exit");
4208   return BECount;
4209 }
4210
4211 /// getExact - Get the exact not taken count for this loop exit.
4212 const SCEV *
4213 ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
4214                                              ScalarEvolution *SE) const {
4215   for (const ExitNotTakenInfo *ENT = &ExitNotTaken;
4216        ENT != 0; ENT = ENT->getNextExit()) {
4217
4218     if (ENT->ExitingBlock == ExitingBlock)
4219       return ENT->ExactNotTaken;
4220   }
4221   return SE->getCouldNotCompute();
4222 }
4223
4224 /// getMax - Get the max backedge taken count for the loop.
4225 const SCEV *
4226 ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
4227   return Max ? Max : SE->getCouldNotCompute();
4228 }
4229
4230 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
4231 /// computable exit into a persistent ExitNotTakenInfo array.
4232 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
4233   SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts,
4234   bool Complete, const SCEV *MaxCount) : Max(MaxCount) {
4235
4236   if (!Complete)
4237     ExitNotTaken.setIncomplete();
4238
4239   unsigned NumExits = ExitCounts.size();
4240   if (NumExits == 0) return;
4241
4242   ExitNotTaken.ExitingBlock = ExitCounts[0].first;
4243   ExitNotTaken.ExactNotTaken = ExitCounts[0].second;
4244   if (NumExits == 1) return;
4245
4246   // Handle the rare case of multiple computable exits.
4247   ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1];
4248
4249   ExitNotTakenInfo *PrevENT = &ExitNotTaken;
4250   for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) {
4251     PrevENT->setNextExit(ENT);
4252     ENT->ExitingBlock = ExitCounts[i].first;
4253     ENT->ExactNotTaken = ExitCounts[i].second;
4254   }
4255 }
4256
4257 /// clear - Invalidate this result and free the ExitNotTakenInfo array.
4258 void ScalarEvolution::BackedgeTakenInfo::clear() {
4259   ExitNotTaken.ExitingBlock = 0;
4260   ExitNotTaken.ExactNotTaken = 0;
4261   delete[] ExitNotTaken.getNextExit();
4262 }
4263
4264 /// ComputeBackedgeTakenCount - Compute the number of times the backedge
4265 /// of the specified loop will execute.
4266 ScalarEvolution::BackedgeTakenInfo
4267 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
4268   SmallVector<BasicBlock *, 8> ExitingBlocks;
4269   L->getExitingBlocks(ExitingBlocks);
4270
4271   // Examine all exits and pick the most conservative values.
4272   const SCEV *MaxBECount = getCouldNotCompute();
4273   bool CouldComputeBECount = true;
4274   SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts;
4275   for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
4276     ExitLimit EL = ComputeExitLimit(L, ExitingBlocks[i]);
4277     if (EL.Exact == getCouldNotCompute())
4278       // We couldn't compute an exact value for this exit, so
4279       // we won't be able to compute an exact value for the loop.
4280       CouldComputeBECount = false;
4281     else
4282       ExitCounts.push_back(std::make_pair(ExitingBlocks[i], EL.Exact));
4283
4284     if (MaxBECount == getCouldNotCompute())
4285       MaxBECount = EL.Max;
4286     else if (EL.Max != getCouldNotCompute()) {
4287       // We cannot take the "min" MaxBECount, because non-unit stride loops may
4288       // skip some loop tests. Taking the max over the exits is sufficiently
4289       // conservative.  TODO: We could do better taking into consideration
4290       // that (1) the loop has unit stride (2) the last loop test is
4291       // less-than/greater-than (3) any loop test is less-than/greater-than AND
4292       // falls-through some constant times less then the other tests.
4293       MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, EL.Max);
4294     }
4295   }
4296
4297   return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount);
4298 }
4299
4300 /// ComputeExitLimit - Compute the number of times the backedge of the specified
4301 /// loop will execute if it exits via the specified block.
4302 ScalarEvolution::ExitLimit
4303 ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
4304
4305   // Okay, we've chosen an exiting block.  See what condition causes us to
4306   // exit at this block.
4307   //
4308   // FIXME: we should be able to handle switch instructions (with a single exit)
4309   BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
4310   if (ExitBr == 0) return getCouldNotCompute();
4311   assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
4312
4313   // At this point, we know we have a conditional branch that determines whether
4314   // the loop is exited.  However, we don't know if the branch is executed each
4315   // time through the loop.  If not, then the execution count of the branch will
4316   // not be equal to the trip count of the loop.
4317   //
4318   // Currently we check for this by checking to see if the Exit branch goes to
4319   // the loop header.  If so, we know it will always execute the same number of
4320   // times as the loop.  We also handle the case where the exit block *is* the
4321   // loop header.  This is common for un-rotated loops.
4322   //
4323   // If both of those tests fail, walk up the unique predecessor chain to the
4324   // header, stopping if there is an edge that doesn't exit the loop. If the
4325   // header is reached, the execution count of the branch will be equal to the
4326   // trip count of the loop.
4327   //
4328   //  More extensive analysis could be done to handle more cases here.
4329   //
4330   if (ExitBr->getSuccessor(0) != L->getHeader() &&
4331       ExitBr->getSuccessor(1) != L->getHeader() &&
4332       ExitBr->getParent() != L->getHeader()) {
4333     // The simple checks failed, try climbing the unique predecessor chain
4334     // up to the header.
4335     bool Ok = false;
4336     for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
4337       BasicBlock *Pred = BB->getUniquePredecessor();
4338       if (!Pred)
4339         return getCouldNotCompute();
4340       TerminatorInst *PredTerm = Pred->getTerminator();
4341       for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
4342         BasicBlock *PredSucc = PredTerm->getSuccessor(i);
4343         if (PredSucc == BB)
4344           continue;
4345         // If the predecessor has a successor that isn't BB and isn't
4346         // outside the loop, assume the worst.
4347         if (L->contains(PredSucc))
4348           return getCouldNotCompute();
4349       }
4350       if (Pred == L->getHeader()) {
4351         Ok = true;
4352         break;
4353       }
4354       BB = Pred;
4355     }
4356     if (!Ok)
4357       return getCouldNotCompute();
4358   }
4359
4360   // Proceed to the next level to examine the exit condition expression.
4361   return ComputeExitLimitFromCond(L, ExitBr->getCondition(),
4362                                   ExitBr->getSuccessor(0),
4363                                   ExitBr->getSuccessor(1));
4364 }
4365
4366 /// ComputeExitLimitFromCond - Compute the number of times the
4367 /// backedge of the specified loop will execute if its exit condition
4368 /// were a conditional branch of ExitCond, TBB, and FBB.
4369 ScalarEvolution::ExitLimit
4370 ScalarEvolution::ComputeExitLimitFromCond(const Loop *L,
4371                                           Value *ExitCond,
4372                                           BasicBlock *TBB,
4373                                           BasicBlock *FBB) {
4374   // Check if the controlling expression for this loop is an And or Or.
4375   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
4376     if (BO->getOpcode() == Instruction::And) {
4377       // Recurse on the operands of the and.
4378       ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB);
4379       ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB);
4380       const SCEV *BECount = getCouldNotCompute();
4381       const SCEV *MaxBECount = getCouldNotCompute();
4382       if (L->contains(TBB)) {
4383         // Both conditions must be true for the loop to continue executing.
4384         // Choose the less conservative count.
4385         if (EL0.Exact == getCouldNotCompute() ||
4386             EL1.Exact == getCouldNotCompute())
4387           BECount = getCouldNotCompute();
4388         else
4389           BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
4390         if (EL0.Max == getCouldNotCompute())
4391           MaxBECount = EL1.Max;
4392         else if (EL1.Max == getCouldNotCompute())
4393           MaxBECount = EL0.Max;
4394         else
4395           MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
4396       } else {
4397         // Both conditions must be true at the same time for the loop to exit.
4398         // For now, be conservative.
4399         assert(L->contains(FBB) && "Loop block has no successor in loop!");
4400         if (EL0.Max == EL1.Max)
4401           MaxBECount = EL0.Max;
4402         if (EL0.Exact == EL1.Exact)
4403           BECount = EL0.Exact;
4404       }
4405
4406       return ExitLimit(BECount, MaxBECount);
4407     }
4408     if (BO->getOpcode() == Instruction::Or) {
4409       // Recurse on the operands of the or.
4410       ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB);
4411       ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB);
4412       const SCEV *BECount = getCouldNotCompute();
4413       const SCEV *MaxBECount = getCouldNotCompute();
4414       if (L->contains(FBB)) {
4415         // Both conditions must be false for the loop to continue executing.
4416         // Choose the less conservative count.
4417         if (EL0.Exact == getCouldNotCompute() ||
4418             EL1.Exact == getCouldNotCompute())
4419           BECount = getCouldNotCompute();
4420         else
4421           BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact);
4422         if (EL0.Max == getCouldNotCompute())
4423           MaxBECount = EL1.Max;
4424         else if (EL1.Max == getCouldNotCompute())
4425           MaxBECount = EL0.Max;
4426         else
4427           MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max);
4428       } else {
4429         // Both conditions must be false at the same time for the loop to exit.
4430         // For now, be conservative.
4431         assert(L->contains(TBB) && "Loop block has no successor in loop!");
4432         if (EL0.Max == EL1.Max)
4433           MaxBECount = EL0.Max;
4434         if (EL0.Exact == EL1.Exact)
4435           BECount = EL0.Exact;
4436       }
4437
4438       return ExitLimit(BECount, MaxBECount);
4439     }
4440   }
4441
4442   // With an icmp, it may be feasible to compute an exact backedge-taken count.
4443   // Proceed to the next level to examine the icmp.
4444   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
4445     return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB);
4446
4447   // Check for a constant condition. These are normally stripped out by
4448   // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
4449   // preserve the CFG and is temporarily leaving constant conditions
4450   // in place.
4451   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
4452     if (L->contains(FBB) == !CI->getZExtValue())
4453       // The backedge is always taken.
4454       return getCouldNotCompute();
4455     else
4456       // The backedge is never taken.
4457       return getConstant(CI->getType(), 0);
4458   }
4459
4460   // If it's not an integer or pointer comparison then compute it the hard way.
4461   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
4462 }
4463
4464 /// ComputeExitLimitFromICmp - Compute the number of times the
4465 /// backedge of the specified loop will execute if its exit condition
4466 /// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
4467 ScalarEvolution::ExitLimit
4468 ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L,
4469                                           ICmpInst *ExitCond,
4470                                           BasicBlock *TBB,
4471                                           BasicBlock *FBB) {
4472
4473   // If the condition was exit on true, convert the condition to exit on false
4474   ICmpInst::Predicate Cond;
4475   if (!L->contains(FBB))
4476     Cond = ExitCond->getPredicate();
4477   else
4478     Cond = ExitCond->getInversePredicate();
4479
4480   // Handle common loops like: for (X = "string"; *X; ++X)
4481   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
4482     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
4483       ExitLimit ItCnt =
4484         ComputeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
4485       if (ItCnt.hasAnyInfo())
4486         return ItCnt;
4487     }
4488
4489   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
4490   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
4491
4492   // Try to evaluate any dependencies out of the loop.
4493   LHS = getSCEVAtScope(LHS, L);
4494   RHS = getSCEVAtScope(RHS, L);
4495
4496   // At this point, we would like to compute how many iterations of the
4497   // loop the predicate will return true for these inputs.
4498   if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
4499     // If there is a loop-invariant, force it into the RHS.
4500     std::swap(LHS, RHS);
4501     Cond = ICmpInst::getSwappedPredicate(Cond);
4502   }
4503
4504   // Simplify the operands before analyzing them.
4505   (void)SimplifyICmpOperands(Cond, LHS, RHS);
4506
4507   // If we have a comparison of a chrec against a constant, try to use value
4508   // ranges to answer this query.
4509   if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
4510     if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
4511       if (AddRec->getLoop() == L) {
4512         // Form the constant range.
4513         ConstantRange CompRange(
4514             ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
4515
4516         const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
4517         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
4518       }
4519
4520   switch (Cond) {
4521   case ICmpInst::ICMP_NE: {                     // while (X != Y)
4522     // Convert to: while (X-Y != 0)
4523     ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L);
4524     if (EL.hasAnyInfo()) return EL;
4525     break;
4526   }
4527   case ICmpInst::ICMP_EQ: {                     // while (X == Y)
4528     // Convert to: while (X-Y == 0)
4529     ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
4530     if (EL.hasAnyInfo()) return EL;
4531     break;
4532   }
4533   case ICmpInst::ICMP_SLT: {
4534     ExitLimit EL = HowManyLessThans(LHS, RHS, L, true);
4535     if (EL.hasAnyInfo()) return EL;
4536     break;
4537   }
4538   case ICmpInst::ICMP_SGT: {
4539     ExitLimit EL = HowManyLessThans(getNotSCEV(LHS),
4540                                              getNotSCEV(RHS), L, true);
4541     if (EL.hasAnyInfo()) return EL;
4542     break;
4543   }
4544   case ICmpInst::ICMP_ULT: {
4545     ExitLimit EL = HowManyLessThans(LHS, RHS, L, false);
4546     if (EL.hasAnyInfo()) return EL;
4547     break;
4548   }
4549   case ICmpInst::ICMP_UGT: {
4550     ExitLimit EL = HowManyLessThans(getNotSCEV(LHS),
4551                                              getNotSCEV(RHS), L, false);
4552     if (EL.hasAnyInfo()) return EL;
4553     break;
4554   }
4555   default:
4556 #if 0
4557     dbgs() << "ComputeBackedgeTakenCount ";
4558     if (ExitCond->getOperand(0)->getType()->isUnsigned())
4559       dbgs() << "[unsigned] ";
4560     dbgs() << *LHS << "   "
4561          << Instruction::getOpcodeName(Instruction::ICmp)
4562          << "   " << *RHS << "\n";
4563 #endif
4564     break;
4565   }
4566   return ComputeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
4567 }
4568
4569 static ConstantInt *
4570 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
4571                                 ScalarEvolution &SE) {
4572   const SCEV *InVal = SE.getConstant(C);
4573   const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
4574   assert(isa<SCEVConstant>(Val) &&
4575          "Evaluation of SCEV at constant didn't fold correctly?");
4576   return cast<SCEVConstant>(Val)->getValue();
4577 }
4578
4579 /// GetAddressedElementFromGlobal - Given a global variable with an initializer
4580 /// and a GEP expression (missing the pointer index) indexing into it, return
4581 /// the addressed element of the initializer or null if the index expression is
4582 /// invalid.
4583 static Constant *
4584 GetAddressedElementFromGlobal(GlobalVariable *GV,
4585                               const std::vector<ConstantInt*> &Indices) {
4586   Constant *Init = GV->getInitializer();
4587   for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
4588     uint64_t Idx = Indices[i]->getZExtValue();
4589     if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
4590       assert(Idx < CS->getNumOperands() && "Bad struct index!");
4591       Init = cast<Constant>(CS->getOperand(Idx));
4592     } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
4593       if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
4594       Init = cast<Constant>(CA->getOperand(Idx));
4595     } else if (isa<ConstantAggregateZero>(Init)) {
4596       if (StructType *STy = dyn_cast<StructType>(Init->getType())) {
4597         assert(Idx < STy->getNumElements() && "Bad struct index!");
4598         Init = Constant::getNullValue(STy->getElementType(Idx));
4599       } else if (ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
4600         if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
4601         Init = Constant::getNullValue(ATy->getElementType());
4602       } else {
4603         llvm_unreachable("Unknown constant aggregate type!");
4604       }
4605       return 0;
4606     } else {
4607       return 0; // Unknown initializer type
4608     }
4609   }
4610   return Init;
4611 }
4612
4613 /// ComputeLoadConstantCompareExitLimit - Given an exit condition of
4614 /// 'icmp op load X, cst', try to see if we can compute the backedge
4615 /// execution count.
4616 ScalarEvolution::ExitLimit
4617 ScalarEvolution::ComputeLoadConstantCompareExitLimit(
4618   LoadInst *LI,
4619   Constant *RHS,
4620   const Loop *L,
4621   ICmpInst::Predicate predicate) {
4622
4623   if (LI->isVolatile()) return getCouldNotCompute();
4624
4625   // Check to see if the loaded pointer is a getelementptr of a global.
4626   // TODO: Use SCEV instead of manually grubbing with GEPs.
4627   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
4628   if (!GEP) return getCouldNotCompute();
4629
4630   // Make sure that it is really a constant global we are gepping, with an
4631   // initializer, and make sure the first IDX is really 0.
4632   GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
4633   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
4634       GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
4635       !cast<Constant>(GEP->getOperand(1))->isNullValue())
4636     return getCouldNotCompute();
4637
4638   // Okay, we allow one non-constant index into the GEP instruction.
4639   Value *VarIdx = 0;
4640   std::vector<ConstantInt*> Indexes;
4641   unsigned VarIdxNum = 0;
4642   for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
4643     if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
4644       Indexes.push_back(CI);
4645     } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
4646       if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
4647       VarIdx = GEP->getOperand(i);
4648       VarIdxNum = i-2;
4649       Indexes.push_back(0);
4650     }
4651
4652   // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
4653   // Check to see if X is a loop variant variable value now.
4654   const SCEV *Idx = getSCEV(VarIdx);
4655   Idx = getSCEVAtScope(Idx, L);
4656
4657   // We can only recognize very limited forms of loop index expressions, in
4658   // particular, only affine AddRec's like {C1,+,C2}.
4659   const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
4660   if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
4661       !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
4662       !isa<SCEVConstant>(IdxExpr->getOperand(1)))
4663     return getCouldNotCompute();
4664
4665   unsigned MaxSteps = MaxBruteForceIterations;
4666   for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
4667     ConstantInt *ItCst = ConstantInt::get(
4668                            cast<IntegerType>(IdxExpr->getType()), IterationNum);
4669     ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
4670
4671     // Form the GEP offset.
4672     Indexes[VarIdxNum] = Val;
4673
4674     Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
4675     if (Result == 0) break;  // Cannot compute!
4676
4677     // Evaluate the condition for this iteration.
4678     Result = ConstantExpr::getICmp(predicate, Result, RHS);
4679     if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
4680     if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
4681 #if 0
4682       dbgs() << "\n***\n*** Computed loop count " << *ItCst
4683              << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
4684              << "***\n";
4685 #endif
4686       ++NumArrayLenItCounts;
4687       return getConstant(ItCst);   // Found terminating iteration!
4688     }
4689   }
4690   return getCouldNotCompute();
4691 }
4692
4693
4694 /// CanConstantFold - Return true if we can constant fold an instruction of the
4695 /// specified type, assuming that all operands were constants.
4696 static bool CanConstantFold(const Instruction *I) {
4697   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
4698       isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
4699       isa<LoadInst>(I))
4700     return true;
4701
4702   if (const CallInst *CI = dyn_cast<CallInst>(I))
4703     if (const Function *F = CI->getCalledFunction())
4704       return canConstantFoldCallTo(F);
4705   return false;
4706 }
4707
4708 /// Determine whether this instruction can constant evolve within this loop
4709 /// assuming its operands can all constant evolve.
4710 static bool canConstantEvolve(Instruction *I, const Loop *L) {
4711   // An instruction outside of the loop can't be derived from a loop PHI.
4712   if (!L->contains(I)) return false;
4713
4714   if (isa<PHINode>(I)) {
4715     if (L->getHeader() == I->getParent())
4716       return true;
4717     else
4718       // We don't currently keep track of the control flow needed to evaluate
4719       // PHIs, so we cannot handle PHIs inside of loops.
4720       return false;
4721   }
4722
4723   // If we won't be able to constant fold this expression even if the operands
4724   // are constants, bail early.
4725   return CanConstantFold(I);
4726 }
4727
4728 /// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
4729 /// recursing through each instruction operand until reaching a loop header phi.
4730 static PHINode *
4731 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
4732                                DenseMap<Instruction *, PHINode *> &PHIMap) {
4733
4734   // Otherwise, we can evaluate this instruction if all of its operands are
4735   // constant or derived from a PHI node themselves.
4736   PHINode *PHI = 0;
4737   for (Instruction::op_iterator OpI = UseInst->op_begin(),
4738          OpE = UseInst->op_end(); OpI != OpE; ++OpI) {
4739
4740     if (isa<Constant>(*OpI)) continue;
4741
4742     Instruction *OpInst = dyn_cast<Instruction>(*OpI);
4743     if (!OpInst || !canConstantEvolve(OpInst, L)) return 0;
4744
4745     PHINode *P = dyn_cast<PHINode>(OpInst);
4746     if (!P)
4747       // If this operand is already visited, reuse the prior result.
4748       // We may have P != PHI if this is the deepest point at which the
4749       // inconsistent paths meet.
4750       P = PHIMap.lookup(OpInst);
4751     if (!P) {
4752       // Recurse and memoize the results, whether a phi is found or not.
4753       // This recursive call invalidates pointers into PHIMap.
4754       P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap);
4755       PHIMap[OpInst] = P;
4756     }
4757     if (P == 0) return 0;        // Not evolving from PHI
4758     if (PHI && PHI != P) return 0;  // Evolving from multiple different PHIs.
4759     PHI = P;
4760   }
4761   // This is a expression evolving from a constant PHI!
4762   return PHI;
4763 }
4764
4765 /// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
4766 /// in the loop that V is derived from.  We allow arbitrary operations along the
4767 /// way, but the operands of an operation must either be constants or a value
4768 /// derived from a constant PHI.  If this expression does not fit with these
4769 /// constraints, return null.
4770 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
4771   Instruction *I = dyn_cast<Instruction>(V);
4772   if (I == 0 || !canConstantEvolve(I, L)) return 0;
4773
4774   if (PHINode *PN = dyn_cast<PHINode>(I)) {
4775     return PN;
4776   }
4777
4778   // Record non-constant instructions contained by the loop.
4779   DenseMap<Instruction *, PHINode *> PHIMap;
4780   return getConstantEvolvingPHIOperands(I, L, PHIMap);
4781 }
4782
4783 /// EvaluateExpression - Given an expression that passes the
4784 /// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
4785 /// in the loop has the value PHIVal.  If we can't fold this expression for some
4786 /// reason, return null.
4787 static Constant *EvaluateExpression(Value *V, const Loop *L,
4788                                     DenseMap<Instruction *, Constant *> &Vals,
4789                                     const TargetData *TD,
4790                                     const TargetLibraryInfo *TLI) {
4791   // Convenient constant check, but redundant for recursive calls.
4792   if (Constant *C = dyn_cast<Constant>(V)) return C;
4793   Instruction *I = dyn_cast<Instruction>(V);
4794   if (!I) return 0;
4795
4796   if (Constant *C = Vals.lookup(I)) return C;
4797
4798   // An instruction inside the loop depends on a value outside the loop that we
4799   // weren't given a mapping for, or a value such as a call inside the loop.
4800   if (!canConstantEvolve(I, L)) return 0;
4801
4802   // An unmapped PHI can be due to a branch or another loop inside this loop,
4803   // or due to this not being the initial iteration through a loop where we
4804   // couldn't compute the evolution of this particular PHI last time.
4805   if (isa<PHINode>(I)) return 0;
4806
4807   std::vector<Constant*> Operands(I->getNumOperands());
4808
4809   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4810     Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
4811     if (!Operand) {
4812       Operands[i] = dyn_cast<Constant>(I->getOperand(i));
4813       if (!Operands[i]) return 0;
4814       continue;
4815     }
4816     Constant *C = EvaluateExpression(Operand, L, Vals, TD, TLI);
4817     Vals[Operand] = C;
4818     if (!C) return 0;
4819     Operands[i] = C;
4820   }
4821
4822   if (CmpInst *CI = dyn_cast<CmpInst>(I))
4823     return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
4824                                            Operands[1], TD, TLI);
4825   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
4826     if (!LI->isVolatile())
4827       return ConstantFoldLoadFromConstPtr(Operands[0], TD);
4828   }
4829   return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, TD,
4830                                   TLI);
4831 }
4832
4833 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
4834 /// in the header of its containing loop, we know the loop executes a
4835 /// constant number of times, and the PHI node is just a recurrence
4836 /// involving constants, fold it.
4837 Constant *
4838 ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
4839                                                    const APInt &BEs,
4840                                                    const Loop *L) {
4841   DenseMap<PHINode*, Constant*>::const_iterator I =
4842     ConstantEvolutionLoopExitValue.find(PN);
4843   if (I != ConstantEvolutionLoopExitValue.end())
4844     return I->second;
4845
4846   if (BEs.ugt(MaxBruteForceIterations))
4847     return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
4848
4849   Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
4850
4851   DenseMap<Instruction *, Constant *> CurrentIterVals;
4852   BasicBlock *Header = L->getHeader();
4853   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
4854
4855   // Since the loop is canonicalized, the PHI node must have two entries.  One
4856   // entry must be a constant (coming in from outside of the loop), and the
4857   // second must be derived from the same PHI.
4858   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4859   PHINode *PHI = 0;
4860   for (BasicBlock::iterator I = Header->begin();
4861        (PHI = dyn_cast<PHINode>(I)); ++I) {
4862     Constant *StartCST =
4863       dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
4864     if (StartCST == 0) continue;
4865     CurrentIterVals[PHI] = StartCST;
4866   }
4867   if (!CurrentIterVals.count(PN))
4868     return RetVal = 0;
4869
4870   Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4871
4872   // Execute the loop symbolically to determine the exit value.
4873   if (BEs.getActiveBits() >= 32)
4874     return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
4875
4876   unsigned NumIterations = BEs.getZExtValue(); // must be in range
4877   unsigned IterationNum = 0;
4878   for (; ; ++IterationNum) {
4879     if (IterationNum == NumIterations)
4880       return RetVal = CurrentIterVals[PN];  // Got exit value!
4881
4882     // Compute the value of the PHIs for the next iteration.
4883     // EvaluateExpression adds non-phi values to the CurrentIterVals map.
4884     DenseMap<Instruction *, Constant *> NextIterVals;
4885     Constant *NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD,
4886                                            TLI);
4887     if (NextPHI == 0)
4888       return 0;        // Couldn't evaluate!
4889     NextIterVals[PN] = NextPHI;
4890
4891     bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
4892
4893     // Also evaluate the other PHI nodes.  However, we don't get to stop if we
4894     // cease to be able to evaluate one of them or if they stop evolving,
4895     // because that doesn't necessarily prevent us from computing PN.
4896     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
4897     for (DenseMap<Instruction *, Constant *>::const_iterator
4898            I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
4899       PHINode *PHI = dyn_cast<PHINode>(I->first);
4900       if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
4901       PHIsToCompute.push_back(std::make_pair(PHI, I->second));
4902     }
4903     // We use two distinct loops because EvaluateExpression may invalidate any
4904     // iterators into CurrentIterVals.
4905     for (SmallVectorImpl<std::pair<PHINode *, Constant*> >::const_iterator
4906              I = PHIsToCompute.begin(), E = PHIsToCompute.end(); I != E; ++I) {
4907       PHINode *PHI = I->first;
4908       Constant *&NextPHI = NextIterVals[PHI];
4909       if (!NextPHI) {   // Not already computed.
4910         Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
4911         NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD, TLI);
4912       }
4913       if (NextPHI != I->second)
4914         StoppedEvolving = false;
4915     }
4916
4917     // If all entries in CurrentIterVals == NextIterVals then we can stop
4918     // iterating, the loop can't continue to change.
4919     if (StoppedEvolving)
4920       return RetVal = CurrentIterVals[PN];
4921
4922     CurrentIterVals.swap(NextIterVals);
4923   }
4924 }
4925
4926 /// ComputeExitCountExhaustively - If the loop is known to execute a
4927 /// constant number of times (the condition evolves only from constants),
4928 /// try to evaluate a few iterations of the loop until we get the exit
4929 /// condition gets a value of ExitWhen (true or false).  If we cannot
4930 /// evaluate the trip count of the loop, return getCouldNotCompute().
4931 const SCEV *ScalarEvolution::ComputeExitCountExhaustively(const Loop *L,
4932                                                           Value *Cond,
4933                                                           bool ExitWhen) {
4934   PHINode *PN = getConstantEvolvingPHI(Cond, L);
4935   if (PN == 0) return getCouldNotCompute();
4936
4937   // If the loop is canonicalized, the PHI will have exactly two entries.
4938   // That's the only form we support here.
4939   if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
4940
4941   DenseMap<Instruction *, Constant *> CurrentIterVals;
4942   BasicBlock *Header = L->getHeader();
4943   assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
4944
4945   // One entry must be a constant (coming in from outside of the loop), and the
4946   // second must be derived from the same PHI.
4947   bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4948   PHINode *PHI = 0;
4949   for (BasicBlock::iterator I = Header->begin();
4950        (PHI = dyn_cast<PHINode>(I)); ++I) {
4951     Constant *StartCST =
4952       dyn_cast<Constant>(PHI->getIncomingValue(!SecondIsBackedge));
4953     if (StartCST == 0) continue;
4954     CurrentIterVals[PHI] = StartCST;
4955   }
4956   if (!CurrentIterVals.count(PN))
4957     return getCouldNotCompute();
4958
4959   // Okay, we find a PHI node that defines the trip count of this loop.  Execute
4960   // the loop symbolically to determine when the condition gets a value of
4961   // "ExitWhen".
4962
4963   unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
4964   for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
4965     ConstantInt *CondVal =
4966       dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, L, CurrentIterVals,
4967                                                        TD, TLI));
4968
4969     // Couldn't symbolically evaluate.
4970     if (!CondVal) return getCouldNotCompute();
4971
4972     if (CondVal->getValue() == uint64_t(ExitWhen)) {
4973       ++NumBruteForceTripCountsComputed;
4974       return getConstant(Type::getInt32Ty(getContext()), IterationNum);
4975     }
4976
4977     // Update all the PHI nodes for the next iteration.
4978     DenseMap<Instruction *, Constant *> NextIterVals;
4979
4980     // Create a list of which PHIs we need to compute. We want to do this before
4981     // calling EvaluateExpression on them because that may invalidate iterators
4982     // into CurrentIterVals.
4983     SmallVector<PHINode *, 8> PHIsToCompute;
4984     for (DenseMap<Instruction *, Constant *>::const_iterator
4985            I = CurrentIterVals.begin(), E = CurrentIterVals.end(); I != E; ++I){
4986       PHINode *PHI = dyn_cast<PHINode>(I->first);
4987       if (!PHI || PHI->getParent() != Header) continue;
4988       PHIsToCompute.push_back(PHI);
4989     }
4990     for (SmallVectorImpl<PHINode *>::const_iterator I = PHIsToCompute.begin(),
4991              E = PHIsToCompute.end(); I != E; ++I) {
4992       PHINode *PHI = *I;
4993       Constant *&NextPHI = NextIterVals[PHI];
4994       if (NextPHI) continue;    // Already computed!
4995
4996       Value *BEValue = PHI->getIncomingValue(SecondIsBackedge);
4997       NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, TD, TLI);
4998     }
4999     CurrentIterVals.swap(NextIterVals);
5000   }
5001
5002   // Too many iterations were needed to evaluate.
5003   return getCouldNotCompute();
5004 }
5005
5006 /// getSCEVAtScope - Return a SCEV expression for the specified value
5007 /// at the specified scope in the program.  The L value specifies a loop
5008 /// nest to evaluate the expression at, where null is the top-level or a
5009 /// specified loop is immediately inside of the loop.
5010 ///
5011 /// This method can be used to compute the exit value for a variable defined
5012 /// in a loop by querying what the value will hold in the parent loop.
5013 ///
5014 /// In the case that a relevant loop exit value cannot be computed, the
5015 /// original value V is returned.
5016 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
5017   // Check to see if we've folded this expression at this loop before.
5018   std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
5019   std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
5020     Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
5021   if (!Pair.second)
5022     return Pair.first->second ? Pair.first->second : V;
5023
5024   // Otherwise compute it.
5025   const SCEV *C = computeSCEVAtScope(V, L);
5026   ValuesAtScopes[V][L] = C;
5027   return C;
5028 }
5029
5030 /// This builds up a Constant using the ConstantExpr interface.  That way, we
5031 /// will return Constants for objects which aren't represented by a
5032 /// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
5033 /// Returns NULL if the SCEV isn't representable as a Constant.
5034 static Constant *BuildConstantFromSCEV(const SCEV *V) {
5035   switch (V->getSCEVType()) {
5036     default:  // TODO: smax, umax.
5037     case scCouldNotCompute:
5038     case scAddRecExpr:
5039       break;
5040     case scConstant:
5041       return cast<SCEVConstant>(V)->getValue();
5042     case scUnknown:
5043       return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
5044     case scSignExtend: {
5045       const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
5046       if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
5047         return ConstantExpr::getSExt(CastOp, SS->getType());
5048       break;
5049     }
5050     case scZeroExtend: {
5051       const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
5052       if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
5053         return ConstantExpr::getZExt(CastOp, SZ->getType());
5054       break;
5055     }
5056     case scTruncate: {
5057       const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
5058       if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
5059         return ConstantExpr::getTrunc(CastOp, ST->getType());
5060       break;
5061     }
5062     case scAddExpr: {
5063       const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
5064       if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
5065         if (C->getType()->isPointerTy())
5066           C = ConstantExpr::getBitCast(C, Type::getInt8PtrTy(C->getContext()));
5067         for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
5068           Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
5069           if (!C2) return 0;
5070
5071           // First pointer!
5072           if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
5073             std::swap(C, C2);
5074             // The offsets have been converted to bytes.  We can add bytes to an
5075             // i8* by GEP with the byte count in the first index.
5076             C = ConstantExpr::getBitCast(C,Type::getInt8PtrTy(C->getContext()));
5077           }
5078
5079           // Don't bother trying to sum two pointers. We probably can't
5080           // statically compute a load that results from it anyway.
5081           if (C2->getType()->isPointerTy())
5082             return 0;
5083
5084           if (C->getType()->isPointerTy()) {
5085             if (cast<PointerType>(C->getType())->getElementType()->isStructTy())
5086               C2 = ConstantExpr::getIntegerCast(
5087                   C2, Type::getInt32Ty(C->getContext()), true);
5088             C = ConstantExpr::getGetElementPtr(C, C2);
5089           } else
5090             C = ConstantExpr::getAdd(C, C2);
5091         }
5092         return C;
5093       }
5094       break;
5095     }
5096     case scMulExpr: {
5097       const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
5098       if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
5099         // Don't bother with pointers at all.
5100         if (C->getType()->isPointerTy()) return 0;
5101         for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
5102           Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
5103           if (!C2 || C2->getType()->isPointerTy()) return 0;
5104           C = ConstantExpr::getMul(C, C2);
5105         }
5106         return C;
5107       }
5108       break;
5109     }
5110     case scUDivExpr: {
5111       const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
5112       if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
5113         if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
5114           if (LHS->getType() == RHS->getType())
5115             return ConstantExpr::getUDiv(LHS, RHS);
5116       break;
5117     }
5118   }
5119   return 0;
5120 }
5121
5122 const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
5123   if (isa<SCEVConstant>(V)) return V;
5124
5125   // If this instruction is evolved from a constant-evolving PHI, compute the
5126   // exit value from the loop without using SCEVs.
5127   if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
5128     if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
5129       const Loop *LI = (*this->LI)[I->getParent()];
5130       if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
5131         if (PHINode *PN = dyn_cast<PHINode>(I))
5132           if (PN->getParent() == LI->getHeader()) {
5133             // Okay, there is no closed form solution for the PHI node.  Check
5134             // to see if the loop that contains it has a known backedge-taken
5135             // count.  If so, we may be able to force computation of the exit
5136             // value.
5137             const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
5138             if (const SCEVConstant *BTCC =
5139                   dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
5140               // Okay, we know how many times the containing loop executes.  If
5141               // this is a constant evolving PHI node, get the final value at
5142               // the specified iteration number.
5143               Constant *RV = getConstantEvolutionLoopExitValue(PN,
5144                                                    BTCC->getValue()->getValue(),
5145                                                                LI);
5146               if (RV) return getSCEV(RV);
5147             }
5148           }
5149
5150       // Okay, this is an expression that we cannot symbolically evaluate
5151       // into a SCEV.  Check to see if it's possible to symbolically evaluate
5152       // the arguments into constants, and if so, try to constant propagate the
5153       // result.  This is particularly useful for computing loop exit values.
5154       if (CanConstantFold(I)) {
5155         SmallVector<Constant *, 4> Operands;
5156         bool MadeImprovement = false;
5157         for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
5158           Value *Op = I->getOperand(i);
5159           if (Constant *C = dyn_cast<Constant>(Op)) {
5160             Operands.push_back(C);
5161             continue;
5162           }
5163
5164           // If any of the operands is non-constant and if they are
5165           // non-integer and non-pointer, don't even try to analyze them
5166           // with scev techniques.
5167           if (!isSCEVable(Op->getType()))
5168             return V;
5169
5170           const SCEV *OrigV = getSCEV(Op);
5171           const SCEV *OpV = getSCEVAtScope(OrigV, L);
5172           MadeImprovement |= OrigV != OpV;
5173
5174           Constant *C = BuildConstantFromSCEV(OpV);
5175           if (!C) return V;
5176           if (C->getType() != Op->getType())
5177             C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
5178                                                               Op->getType(),
5179                                                               false),
5180                                       C, Op->getType());
5181           Operands.push_back(C);
5182         }
5183
5184         // Check to see if getSCEVAtScope actually made an improvement.
5185         if (MadeImprovement) {
5186           Constant *C = 0;
5187           if (const CmpInst *CI = dyn_cast<CmpInst>(I))
5188             C = ConstantFoldCompareInstOperands(CI->getPredicate(),
5189                                                 Operands[0], Operands[1], TD,
5190                                                 TLI);
5191           else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
5192             if (!LI->isVolatile())
5193               C = ConstantFoldLoadFromConstPtr(Operands[0], TD);
5194           } else
5195             C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
5196                                          Operands, TD, TLI);
5197           if (!C) return V;
5198           return getSCEV(C);
5199         }
5200       }
5201     }
5202
5203     // This is some other type of SCEVUnknown, just return it.
5204     return V;
5205   }
5206
5207   if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
5208     // Avoid performing the look-up in the common case where the specified
5209     // expression has no loop-variant portions.
5210     for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
5211       const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5212       if (OpAtScope != Comm->getOperand(i)) {
5213         // Okay, at least one of these operands is loop variant but might be
5214         // foldable.  Build a new instance of the folded commutative expression.
5215         SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
5216                                             Comm->op_begin()+i);
5217         NewOps.push_back(OpAtScope);
5218
5219         for (++i; i != e; ++i) {
5220           OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
5221           NewOps.push_back(OpAtScope);
5222         }
5223         if (isa<SCEVAddExpr>(Comm))
5224           return getAddExpr(NewOps);
5225         if (isa<SCEVMulExpr>(Comm))
5226           return getMulExpr(NewOps);
5227         if (isa<SCEVSMaxExpr>(Comm))
5228           return getSMaxExpr(NewOps);
5229         if (isa<SCEVUMaxExpr>(Comm))
5230           return getUMaxExpr(NewOps);
5231         llvm_unreachable("Unknown commutative SCEV type!");
5232       }
5233     }
5234     // If we got here, all operands are loop invariant.
5235     return Comm;
5236   }
5237
5238   if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
5239     const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
5240     const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
5241     if (LHS == Div->getLHS() && RHS == Div->getRHS())
5242       return Div;   // must be loop invariant
5243     return getUDivExpr(LHS, RHS);
5244   }
5245
5246   // If this is a loop recurrence for a loop that does not contain L, then we
5247   // are dealing with the final value computed by the loop.
5248   if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5249     // First, attempt to evaluate each operand.
5250     // Avoid performing the look-up in the common case where the specified
5251     // expression has no loop-variant portions.
5252     for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
5253       const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
5254       if (OpAtScope == AddRec->getOperand(i))
5255         continue;
5256
5257       // Okay, at least one of these operands is loop variant but might be
5258       // foldable.  Build a new instance of the folded commutative expression.
5259       SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
5260                                           AddRec->op_begin()+i);
5261       NewOps.push_back(OpAtScope);
5262       for (++i; i != e; ++i)
5263         NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
5264
5265       const SCEV *FoldedRec =
5266         getAddRecExpr(NewOps, AddRec->getLoop(),
5267                       AddRec->getNoWrapFlags(SCEV::FlagNW));
5268       AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
5269       // The addrec may be folded to a nonrecurrence, for example, if the
5270       // induction variable is multiplied by zero after constant folding. Go
5271       // ahead and return the folded value.
5272       if (!AddRec)
5273         return FoldedRec;
5274       break;
5275     }
5276
5277     // If the scope is outside the addrec's loop, evaluate it by using the
5278     // loop exit value of the addrec.
5279     if (!AddRec->getLoop()->contains(L)) {
5280       // To evaluate this recurrence, we need to know how many times the AddRec
5281       // loop iterates.  Compute this now.
5282       const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
5283       if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
5284
5285       // Then, evaluate the AddRec.
5286       return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
5287     }
5288
5289     return AddRec;
5290   }
5291
5292   if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
5293     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5294     if (Op == Cast->getOperand())
5295       return Cast;  // must be loop invariant
5296     return getZeroExtendExpr(Op, Cast->getType());
5297   }
5298
5299   if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
5300     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5301     if (Op == Cast->getOperand())
5302       return Cast;  // must be loop invariant
5303     return getSignExtendExpr(Op, Cast->getType());
5304   }
5305
5306   if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
5307     const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
5308     if (Op == Cast->getOperand())
5309       return Cast;  // must be loop invariant
5310     return getTruncateExpr(Op, Cast->getType());
5311   }
5312
5313   llvm_unreachable("Unknown SCEV type!");
5314   return 0;
5315 }
5316
5317 /// getSCEVAtScope - This is a convenience function which does
5318 /// getSCEVAtScope(getSCEV(V), L).
5319 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
5320   return getSCEVAtScope(getSCEV(V), L);
5321 }
5322
5323 /// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
5324 /// following equation:
5325 ///
5326 ///     A * X = B (mod N)
5327 ///
5328 /// where N = 2^BW and BW is the common bit width of A and B. The signedness of
5329 /// A and B isn't important.
5330 ///
5331 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
5332 static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
5333                                                ScalarEvolution &SE) {
5334   uint32_t BW = A.getBitWidth();
5335   assert(BW == B.getBitWidth() && "Bit widths must be the same.");
5336   assert(A != 0 && "A must be non-zero.");
5337
5338   // 1. D = gcd(A, N)
5339   //
5340   // The gcd of A and N may have only one prime factor: 2. The number of
5341   // trailing zeros in A is its multiplicity
5342   uint32_t Mult2 = A.countTrailingZeros();
5343   // D = 2^Mult2
5344
5345   // 2. Check if B is divisible by D.
5346   //
5347   // B is divisible by D if and only if the multiplicity of prime factor 2 for B
5348   // is not less than multiplicity of this prime factor for D.
5349   if (B.countTrailingZeros() < Mult2)
5350     return SE.getCouldNotCompute();
5351
5352   // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
5353   // modulo (N / D).
5354   //
5355   // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
5356   // bit width during computations.
5357   APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
5358   APInt Mod(BW + 1, 0);
5359   Mod.setBit(BW - Mult2);  // Mod = N / D
5360   APInt I = AD.multiplicativeInverse(Mod);
5361
5362   // 4. Compute the minimum unsigned root of the equation:
5363   // I * (B / D) mod (N / D)
5364   APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
5365
5366   // The result is guaranteed to be less than 2^BW so we may truncate it to BW
5367   // bits.
5368   return SE.getConstant(Result.trunc(BW));
5369 }
5370
5371 /// SolveQuadraticEquation - Find the roots of the quadratic equation for the
5372 /// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
5373 /// might be the same) or two SCEVCouldNotCompute objects.
5374 ///
5375 static std::pair<const SCEV *,const SCEV *>
5376 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
5377   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
5378   const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
5379   const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
5380   const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
5381
5382   // We currently can only solve this if the coefficients are constants.
5383   if (!LC || !MC || !NC) {
5384     const SCEV *CNC = SE.getCouldNotCompute();
5385     return std::make_pair(CNC, CNC);
5386   }
5387
5388   uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
5389   const APInt &L = LC->getValue()->getValue();
5390   const APInt &M = MC->getValue()->getValue();
5391   const APInt &N = NC->getValue()->getValue();
5392   APInt Two(BitWidth, 2);
5393   APInt Four(BitWidth, 4);
5394
5395   {
5396     using namespace APIntOps;
5397     const APInt& C = L;
5398     // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
5399     // The B coefficient is M-N/2
5400     APInt B(M);
5401     B -= sdiv(N,Two);
5402
5403     // The A coefficient is N/2
5404     APInt A(N.sdiv(Two));
5405
5406     // Compute the B^2-4ac term.
5407     APInt SqrtTerm(B);
5408     SqrtTerm *= B;
5409     SqrtTerm -= Four * (A * C);
5410
5411     // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
5412     // integer value or else APInt::sqrt() will assert.
5413     APInt SqrtVal(SqrtTerm.sqrt());
5414
5415     // Compute the two solutions for the quadratic formula.
5416     // The divisions must be performed as signed divisions.
5417     APInt NegB(-B);
5418     APInt TwoA(A << 1);
5419     if (TwoA.isMinValue()) {
5420       const SCEV *CNC = SE.getCouldNotCompute();
5421       return std::make_pair(CNC, CNC);
5422     }
5423
5424     LLVMContext &Context = SE.getContext();
5425
5426     ConstantInt *Solution1 =
5427       ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
5428     ConstantInt *Solution2 =
5429       ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
5430
5431     return std::make_pair(SE.getConstant(Solution1),
5432                           SE.getConstant(Solution2));
5433   } // end APIntOps namespace
5434 }
5435
5436 /// HowFarToZero - Return the number of times a backedge comparing the specified
5437 /// value to zero will execute.  If not computable, return CouldNotCompute.
5438 ///
5439 /// This is only used for loops with a "x != y" exit test. The exit condition is
5440 /// now expressed as a single expression, V = x-y. So the exit test is
5441 /// effectively V != 0.  We know and take advantage of the fact that this
5442 /// expression only being used in a comparison by zero context.
5443 ScalarEvolution::ExitLimit
5444 ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
5445   // If the value is a constant
5446   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5447     // If the value is already zero, the branch will execute zero times.
5448     if (C->getValue()->isZero()) return C;
5449     return getCouldNotCompute();  // Otherwise it will loop infinitely.
5450   }
5451
5452   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
5453   if (!AddRec || AddRec->getLoop() != L)
5454     return getCouldNotCompute();
5455
5456   // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
5457   // the quadratic equation to solve it.
5458   if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
5459     std::pair<const SCEV *,const SCEV *> Roots =
5460       SolveQuadraticEquation(AddRec, *this);
5461     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
5462     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
5463     if (R1 && R2) {
5464 #if 0
5465       dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
5466              << "  sol#2: " << *R2 << "\n";
5467 #endif
5468       // Pick the smallest positive root value.
5469       if (ConstantInt *CB =
5470           dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
5471                                                       R1->getValue(),
5472                                                       R2->getValue()))) {
5473         if (CB->getZExtValue() == false)
5474           std::swap(R1, R2);   // R1 is the minimum root now.
5475
5476         // We can only use this value if the chrec ends up with an exact zero
5477         // value at this index.  When solving for "X*X != 5", for example, we
5478         // should not accept a root of 2.
5479         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
5480         if (Val->isZero())
5481           return R1;  // We found a quadratic root!
5482       }
5483     }
5484     return getCouldNotCompute();
5485   }
5486
5487   // Otherwise we can only handle this if it is affine.
5488   if (!AddRec->isAffine())
5489     return getCouldNotCompute();
5490
5491   // If this is an affine expression, the execution count of this branch is
5492   // the minimum unsigned root of the following equation:
5493   //
5494   //     Start + Step*N = 0 (mod 2^BW)
5495   //
5496   // equivalent to:
5497   //
5498   //             Step*N = -Start (mod 2^BW)
5499   //
5500   // where BW is the common bit width of Start and Step.
5501
5502   // Get the initial value for the loop.
5503   const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
5504   const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
5505
5506   // For now we handle only constant steps.
5507   //
5508   // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
5509   // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
5510   // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
5511   // We have not yet seen any such cases.
5512   const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
5513   if (StepC == 0)
5514     return getCouldNotCompute();
5515
5516   // For positive steps (counting up until unsigned overflow):
5517   //   N = -Start/Step (as unsigned)
5518   // For negative steps (counting down to zero):
5519   //   N = Start/-Step
5520   // First compute the unsigned distance from zero in the direction of Step.
5521   bool CountDown = StepC->getValue()->getValue().isNegative();
5522   const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
5523
5524   // Handle unitary steps, which cannot wraparound.
5525   // 1*N = -Start; -1*N = Start (mod 2^BW), so:
5526   //   N = Distance (as unsigned)
5527   if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
5528     ConstantRange CR = getUnsignedRange(Start);
5529     const SCEV *MaxBECount;
5530     if (!CountDown && CR.getUnsignedMin().isMinValue())
5531       // When counting up, the worst starting value is 1, not 0.
5532       MaxBECount = CR.getUnsignedMax().isMinValue()
5533         ? getConstant(APInt::getMinValue(CR.getBitWidth()))
5534         : getConstant(APInt::getMaxValue(CR.getBitWidth()));
5535     else
5536       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
5537                                          : -CR.getUnsignedMin());
5538     return ExitLimit(Distance, MaxBECount);
5539   }
5540
5541   // If the recurrence is known not to wraparound, unsigned divide computes the
5542   // back edge count. We know that the value will either become zero (and thus
5543   // the loop terminates), that the loop will terminate through some other exit
5544   // condition first, or that the loop has undefined behavior.  This means
5545   // we can't "miss" the exit value, even with nonunit stride.
5546   //
5547   // FIXME: Prove that loops always exhibits *acceptable* undefined
5548   // behavior. Loops must exhibit defined behavior until a wrapped value is
5549   // actually used. So the trip count computed by udiv could be smaller than the
5550   // number of well-defined iterations.
5551   if (AddRec->getNoWrapFlags(SCEV::FlagNW)) {
5552     // FIXME: We really want an "isexact" bit for udiv.
5553     return getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
5554   }
5555   // Then, try to solve the above equation provided that Start is constant.
5556   if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
5557     return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
5558                                         -StartC->getValue()->getValue(),
5559                                         *this);
5560   return getCouldNotCompute();
5561 }
5562
5563 /// HowFarToNonZero - Return the number of times a backedge checking the
5564 /// specified value for nonzero will execute.  If not computable, return
5565 /// CouldNotCompute
5566 ScalarEvolution::ExitLimit
5567 ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
5568   // Loops that look like: while (X == 0) are very strange indeed.  We don't
5569   // handle them yet except for the trivial case.  This could be expanded in the
5570   // future as needed.
5571
5572   // If the value is a constant, check to see if it is known to be non-zero
5573   // already.  If so, the backedge will execute zero times.
5574   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5575     if (!C->getValue()->isNullValue())
5576       return getConstant(C->getType(), 0);
5577     return getCouldNotCompute();  // Otherwise it will loop infinitely.
5578   }
5579
5580   // We could implement others, but I really doubt anyone writes loops like
5581   // this, and if they did, they would already be constant folded.
5582   return getCouldNotCompute();
5583 }
5584
5585 /// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
5586 /// (which may not be an immediate predecessor) which has exactly one
5587 /// successor from which BB is reachable, or null if no such block is
5588 /// found.
5589 ///
5590 std::pair<BasicBlock *, BasicBlock *>
5591 ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
5592   // If the block has a unique predecessor, then there is no path from the
5593   // predecessor to the block that does not go through the direct edge
5594   // from the predecessor to the block.
5595   if (BasicBlock *Pred = BB->getSinglePredecessor())
5596     return std::make_pair(Pred, BB);
5597
5598   // A loop's header is defined to be a block that dominates the loop.
5599   // If the header has a unique predecessor outside the loop, it must be
5600   // a block that has exactly one successor that can reach the loop.
5601   if (Loop *L = LI->getLoopFor(BB))
5602     return std::make_pair(L->getLoopPredecessor(), L->getHeader());
5603
5604   return std::pair<BasicBlock *, BasicBlock *>();
5605 }
5606
5607 /// HasSameValue - SCEV structural equivalence is usually sufficient for
5608 /// testing whether two expressions are equal, however for the purposes of
5609 /// looking for a condition guarding a loop, it can be useful to be a little
5610 /// more general, since a front-end may have replicated the controlling
5611 /// expression.
5612 ///
5613 static bool HasSameValue(const SCEV *A, const SCEV *B) {
5614   // Quick check to see if they are the same SCEV.
5615   if (A == B) return true;
5616
5617   // Otherwise, if they're both SCEVUnknown, it's possible that they hold
5618   // two different instructions with the same value. Check for this case.
5619   if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
5620     if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
5621       if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
5622         if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
5623           if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
5624             return true;
5625
5626   // Otherwise assume they may have a different value.
5627   return false;
5628 }
5629
5630 /// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
5631 /// predicate Pred. Return true iff any changes were made.
5632 ///
5633 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
5634                                            const SCEV *&LHS, const SCEV *&RHS) {
5635   bool Changed = false;
5636
5637   // Canonicalize a constant to the right side.
5638   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
5639     // Check for both operands constant.
5640     if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
5641       if (ConstantExpr::getICmp(Pred,
5642                                 LHSC->getValue(),
5643                                 RHSC->getValue())->isNullValue())
5644         goto trivially_false;
5645       else
5646         goto trivially_true;
5647     }
5648     // Otherwise swap the operands to put the constant on the right.
5649     std::swap(LHS, RHS);
5650     Pred = ICmpInst::getSwappedPredicate(Pred);
5651     Changed = true;
5652   }
5653
5654   // If we're comparing an addrec with a value which is loop-invariant in the
5655   // addrec's loop, put the addrec on the left. Also make a dominance check,
5656   // as both operands could be addrecs loop-invariant in each other's loop.
5657   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
5658     const Loop *L = AR->getLoop();
5659     if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
5660       std::swap(LHS, RHS);
5661       Pred = ICmpInst::getSwappedPredicate(Pred);
5662       Changed = true;
5663     }
5664   }
5665
5666   // If there's a constant operand, canonicalize comparisons with boundary
5667   // cases, and canonicalize *-or-equal comparisons to regular comparisons.
5668   if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
5669     const APInt &RA = RC->getValue()->getValue();
5670     switch (Pred) {
5671     default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5672     case ICmpInst::ICMP_EQ:
5673     case ICmpInst::ICMP_NE:
5674       break;
5675     case ICmpInst::ICMP_UGE:
5676       if ((RA - 1).isMinValue()) {
5677         Pred = ICmpInst::ICMP_NE;
5678         RHS = getConstant(RA - 1);
5679         Changed = true;
5680         break;
5681       }
5682       if (RA.isMaxValue()) {
5683         Pred = ICmpInst::ICMP_EQ;
5684         Changed = true;
5685         break;
5686       }
5687       if (RA.isMinValue()) goto trivially_true;
5688
5689       Pred = ICmpInst::ICMP_UGT;
5690       RHS = getConstant(RA - 1);
5691       Changed = true;
5692       break;
5693     case ICmpInst::ICMP_ULE:
5694       if ((RA + 1).isMaxValue()) {
5695         Pred = ICmpInst::ICMP_NE;
5696         RHS = getConstant(RA + 1);
5697         Changed = true;
5698         break;
5699       }
5700       if (RA.isMinValue()) {
5701         Pred = ICmpInst::ICMP_EQ;
5702         Changed = true;
5703         break;
5704       }
5705       if (RA.isMaxValue()) goto trivially_true;
5706
5707       Pred = ICmpInst::ICMP_ULT;
5708       RHS = getConstant(RA + 1);
5709       Changed = true;
5710       break;
5711     case ICmpInst::ICMP_SGE:
5712       if ((RA - 1).isMinSignedValue()) {
5713         Pred = ICmpInst::ICMP_NE;
5714         RHS = getConstant(RA - 1);
5715         Changed = true;
5716         break;
5717       }
5718       if (RA.isMaxSignedValue()) {
5719         Pred = ICmpInst::ICMP_EQ;
5720         Changed = true;
5721         break;
5722       }
5723       if (RA.isMinSignedValue()) goto trivially_true;
5724
5725       Pred = ICmpInst::ICMP_SGT;
5726       RHS = getConstant(RA - 1);
5727       Changed = true;
5728       break;
5729     case ICmpInst::ICMP_SLE:
5730       if ((RA + 1).isMaxSignedValue()) {
5731         Pred = ICmpInst::ICMP_NE;
5732         RHS = getConstant(RA + 1);
5733         Changed = true;
5734         break;
5735       }
5736       if (RA.isMinSignedValue()) {
5737         Pred = ICmpInst::ICMP_EQ;
5738         Changed = true;
5739         break;
5740       }
5741       if (RA.isMaxSignedValue()) goto trivially_true;
5742
5743       Pred = ICmpInst::ICMP_SLT;
5744       RHS = getConstant(RA + 1);
5745       Changed = true;
5746       break;
5747     case ICmpInst::ICMP_UGT:
5748       if (RA.isMinValue()) {
5749         Pred = ICmpInst::ICMP_NE;
5750         Changed = true;
5751         break;
5752       }
5753       if ((RA + 1).isMaxValue()) {
5754         Pred = ICmpInst::ICMP_EQ;
5755         RHS = getConstant(RA + 1);
5756         Changed = true;
5757         break;
5758       }
5759       if (RA.isMaxValue()) goto trivially_false;
5760       break;
5761     case ICmpInst::ICMP_ULT:
5762       if (RA.isMaxValue()) {
5763         Pred = ICmpInst::ICMP_NE;
5764         Changed = true;
5765         break;
5766       }
5767       if ((RA - 1).isMinValue()) {
5768         Pred = ICmpInst::ICMP_EQ;
5769         RHS = getConstant(RA - 1);
5770         Changed = true;
5771         break;
5772       }
5773       if (RA.isMinValue()) goto trivially_false;
5774       break;
5775     case ICmpInst::ICMP_SGT:
5776       if (RA.isMinSignedValue()) {
5777         Pred = ICmpInst::ICMP_NE;
5778         Changed = true;
5779         break;
5780       }
5781       if ((RA + 1).isMaxSignedValue()) {
5782         Pred = ICmpInst::ICMP_EQ;
5783         RHS = getConstant(RA + 1);
5784         Changed = true;
5785         break;
5786       }
5787       if (RA.isMaxSignedValue()) goto trivially_false;
5788       break;
5789     case ICmpInst::ICMP_SLT:
5790       if (RA.isMaxSignedValue()) {
5791         Pred = ICmpInst::ICMP_NE;
5792         Changed = true;
5793         break;
5794       }
5795       if ((RA - 1).isMinSignedValue()) {
5796        Pred = ICmpInst::ICMP_EQ;
5797        RHS = getConstant(RA - 1);
5798         Changed = true;
5799        break;
5800       }
5801       if (RA.isMinSignedValue()) goto trivially_false;
5802       break;
5803     }
5804   }
5805
5806   // Check for obvious equality.
5807   if (HasSameValue(LHS, RHS)) {
5808     if (ICmpInst::isTrueWhenEqual(Pred))
5809       goto trivially_true;
5810     if (ICmpInst::isFalseWhenEqual(Pred))
5811       goto trivially_false;
5812   }
5813
5814   // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
5815   // adding or subtracting 1 from one of the operands.
5816   switch (Pred) {
5817   case ICmpInst::ICMP_SLE:
5818     if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
5819       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5820                        SCEV::FlagNSW);
5821       Pred = ICmpInst::ICMP_SLT;
5822       Changed = true;
5823     } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
5824       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5825                        SCEV::FlagNSW);
5826       Pred = ICmpInst::ICMP_SLT;
5827       Changed = true;
5828     }
5829     break;
5830   case ICmpInst::ICMP_SGE:
5831     if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
5832       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5833                        SCEV::FlagNSW);
5834       Pred = ICmpInst::ICMP_SGT;
5835       Changed = true;
5836     } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
5837       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5838                        SCEV::FlagNSW);
5839       Pred = ICmpInst::ICMP_SGT;
5840       Changed = true;
5841     }
5842     break;
5843   case ICmpInst::ICMP_ULE:
5844     if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
5845       RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5846                        SCEV::FlagNUW);
5847       Pred = ICmpInst::ICMP_ULT;
5848       Changed = true;
5849     } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
5850       LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5851                        SCEV::FlagNUW);
5852       Pred = ICmpInst::ICMP_ULT;
5853       Changed = true;
5854     }
5855     break;
5856   case ICmpInst::ICMP_UGE:
5857     if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
5858       RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5859                        SCEV::FlagNUW);
5860       Pred = ICmpInst::ICMP_UGT;
5861       Changed = true;
5862     } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
5863       LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5864                        SCEV::FlagNUW);
5865       Pred = ICmpInst::ICMP_UGT;
5866       Changed = true;
5867     }
5868     break;
5869   default:
5870     break;
5871   }
5872
5873   // TODO: More simplifications are possible here.
5874
5875   return Changed;
5876
5877 trivially_true:
5878   // Return 0 == 0.
5879   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5880   Pred = ICmpInst::ICMP_EQ;
5881   return true;
5882
5883 trivially_false:
5884   // Return 0 != 0.
5885   LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5886   Pred = ICmpInst::ICMP_NE;
5887   return true;
5888 }
5889
5890 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
5891   return getSignedRange(S).getSignedMax().isNegative();
5892 }
5893
5894 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
5895   return getSignedRange(S).getSignedMin().isStrictlyPositive();
5896 }
5897
5898 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
5899   return !getSignedRange(S).getSignedMin().isNegative();
5900 }
5901
5902 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
5903   return !getSignedRange(S).getSignedMax().isStrictlyPositive();
5904 }
5905
5906 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
5907   return isKnownNegative(S) || isKnownPositive(S);
5908 }
5909
5910 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
5911                                        const SCEV *LHS, const SCEV *RHS) {
5912   // Canonicalize the inputs first.
5913   (void)SimplifyICmpOperands(Pred, LHS, RHS);
5914
5915   // If LHS or RHS is an addrec, check to see if the condition is true in
5916   // every iteration of the loop.
5917   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
5918     if (isLoopEntryGuardedByCond(
5919           AR->getLoop(), Pred, AR->getStart(), RHS) &&
5920         isLoopBackedgeGuardedByCond(
5921           AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS))
5922       return true;
5923   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
5924     if (isLoopEntryGuardedByCond(
5925           AR->getLoop(), Pred, LHS, AR->getStart()) &&
5926         isLoopBackedgeGuardedByCond(
5927           AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this)))
5928       return true;
5929
5930   // Otherwise see what can be done with known constant ranges.
5931   return isKnownPredicateWithRanges(Pred, LHS, RHS);
5932 }
5933
5934 bool
5935 ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
5936                                             const SCEV *LHS, const SCEV *RHS) {
5937   if (HasSameValue(LHS, RHS))
5938     return ICmpInst::isTrueWhenEqual(Pred);
5939
5940   // This code is split out from isKnownPredicate because it is called from
5941   // within isLoopEntryGuardedByCond.
5942   switch (Pred) {
5943   default:
5944     llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5945     break;
5946   case ICmpInst::ICMP_SGT:
5947     Pred = ICmpInst::ICMP_SLT;
5948     std::swap(LHS, RHS);
5949   case ICmpInst::ICMP_SLT: {
5950     ConstantRange LHSRange = getSignedRange(LHS);
5951     ConstantRange RHSRange = getSignedRange(RHS);
5952     if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
5953       return true;
5954     if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
5955       return false;
5956     break;
5957   }
5958   case ICmpInst::ICMP_SGE:
5959     Pred = ICmpInst::ICMP_SLE;
5960     std::swap(LHS, RHS);
5961   case ICmpInst::ICMP_SLE: {
5962     ConstantRange LHSRange = getSignedRange(LHS);
5963     ConstantRange RHSRange = getSignedRange(RHS);
5964     if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
5965       return true;
5966     if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
5967       return false;
5968     break;
5969   }
5970   case ICmpInst::ICMP_UGT:
5971     Pred = ICmpInst::ICMP_ULT;
5972     std::swap(LHS, RHS);
5973   case ICmpInst::ICMP_ULT: {
5974     ConstantRange LHSRange = getUnsignedRange(LHS);
5975     ConstantRange RHSRange = getUnsignedRange(RHS);
5976     if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
5977       return true;
5978     if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
5979       return false;
5980     break;
5981   }
5982   case ICmpInst::ICMP_UGE:
5983     Pred = ICmpInst::ICMP_ULE;
5984     std::swap(LHS, RHS);
5985   case ICmpInst::ICMP_ULE: {
5986     ConstantRange LHSRange = getUnsignedRange(LHS);
5987     ConstantRange RHSRange = getUnsignedRange(RHS);
5988     if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
5989       return true;
5990     if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
5991       return false;
5992     break;
5993   }
5994   case ICmpInst::ICMP_NE: {
5995     if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
5996       return true;
5997     if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
5998       return true;
5999
6000     const SCEV *Diff = getMinusSCEV(LHS, RHS);
6001     if (isKnownNonZero(Diff))
6002       return true;
6003     break;
6004   }
6005   case ICmpInst::ICMP_EQ:
6006     // The check at the top of the function catches the case where
6007     // the values are known to be equal.
6008     break;
6009   }
6010   return false;
6011 }
6012
6013 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
6014 /// protected by a conditional between LHS and RHS.  This is used to
6015 /// to eliminate casts.
6016 bool
6017 ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
6018                                              ICmpInst::Predicate Pred,
6019                                              const SCEV *LHS, const SCEV *RHS) {
6020   // Interpret a null as meaning no loop, where there is obviously no guard
6021   // (interprocedural conditions notwithstanding).
6022   if (!L) return true;
6023
6024   BasicBlock *Latch = L->getLoopLatch();
6025   if (!Latch)
6026     return false;
6027
6028   BranchInst *LoopContinuePredicate =
6029     dyn_cast<BranchInst>(Latch->getTerminator());
6030   if (!LoopContinuePredicate ||
6031       LoopContinuePredicate->isUnconditional())
6032     return false;
6033
6034   return isImpliedCond(Pred, LHS, RHS,
6035                        LoopContinuePredicate->getCondition(),
6036                        LoopContinuePredicate->getSuccessor(0) != L->getHeader());
6037 }
6038
6039 /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
6040 /// by a conditional between LHS and RHS.  This is used to help avoid max
6041 /// expressions in loop trip counts, and to eliminate casts.
6042 bool
6043 ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
6044                                           ICmpInst::Predicate Pred,
6045                                           const SCEV *LHS, const SCEV *RHS) {
6046   // Interpret a null as meaning no loop, where there is obviously no guard
6047   // (interprocedural conditions notwithstanding).
6048   if (!L) return false;
6049
6050   // Starting at the loop predecessor, climb up the predecessor chain, as long
6051   // as there are predecessors that can be found that have unique successors
6052   // leading to the original header.
6053   for (std::pair<BasicBlock *, BasicBlock *>
6054          Pair(L->getLoopPredecessor(), L->getHeader());
6055        Pair.first;
6056        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
6057
6058     BranchInst *LoopEntryPredicate =
6059       dyn_cast<BranchInst>(Pair.first->getTerminator());
6060     if (!LoopEntryPredicate ||
6061         LoopEntryPredicate->isUnconditional())
6062       continue;
6063
6064     if (isImpliedCond(Pred, LHS, RHS,
6065                       LoopEntryPredicate->getCondition(),
6066                       LoopEntryPredicate->getSuccessor(0) != Pair.second))
6067       return true;
6068   }
6069
6070   return false;
6071 }
6072
6073 /// isImpliedCond - Test whether the condition described by Pred, LHS,
6074 /// and RHS is true whenever the given Cond value evaluates to true.
6075 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
6076                                     const SCEV *LHS, const SCEV *RHS,
6077                                     Value *FoundCondValue,
6078                                     bool Inverse) {
6079   // Recursively handle And and Or conditions.
6080   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
6081     if (BO->getOpcode() == Instruction::And) {
6082       if (!Inverse)
6083         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6084                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6085     } else if (BO->getOpcode() == Instruction::Or) {
6086       if (Inverse)
6087         return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
6088                isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
6089     }
6090   }
6091
6092   ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
6093   if (!ICI) return false;
6094
6095   // Bail if the ICmp's operands' types are wider than the needed type
6096   // before attempting to call getSCEV on them. This avoids infinite
6097   // recursion, since the analysis of widening casts can require loop
6098   // exit condition information for overflow checking, which would
6099   // lead back here.
6100   if (getTypeSizeInBits(LHS->getType()) <
6101       getTypeSizeInBits(ICI->getOperand(0)->getType()))
6102     return false;
6103
6104   // Now that we found a conditional branch that dominates the loop, check to
6105   // see if it is the comparison we are looking for.
6106   ICmpInst::Predicate FoundPred;
6107   if (Inverse)
6108     FoundPred = ICI->getInversePredicate();
6109   else
6110     FoundPred = ICI->getPredicate();
6111
6112   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
6113   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
6114
6115   // Balance the types. The case where FoundLHS' type is wider than
6116   // LHS' type is checked for above.
6117   if (getTypeSizeInBits(LHS->getType()) >
6118       getTypeSizeInBits(FoundLHS->getType())) {
6119     if (CmpInst::isSigned(Pred)) {
6120       FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
6121       FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
6122     } else {
6123       FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
6124       FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
6125     }
6126   }
6127
6128   // Canonicalize the query to match the way instcombine will have
6129   // canonicalized the comparison.
6130   if (SimplifyICmpOperands(Pred, LHS, RHS))
6131     if (LHS == RHS)
6132       return CmpInst::isTrueWhenEqual(Pred);
6133   if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
6134     if (FoundLHS == FoundRHS)
6135       return CmpInst::isFalseWhenEqual(Pred);
6136
6137   // Check to see if we can make the LHS or RHS match.
6138   if (LHS == FoundRHS || RHS == FoundLHS) {
6139     if (isa<SCEVConstant>(RHS)) {
6140       std::swap(FoundLHS, FoundRHS);
6141       FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
6142     } else {
6143       std::swap(LHS, RHS);
6144       Pred = ICmpInst::getSwappedPredicate(Pred);
6145     }
6146   }
6147
6148   // Check whether the found predicate is the same as the desired predicate.
6149   if (FoundPred == Pred)
6150     return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
6151
6152   // Check whether swapping the found predicate makes it the same as the
6153   // desired predicate.
6154   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
6155     if (isa<SCEVConstant>(RHS))
6156       return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
6157     else
6158       return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
6159                                    RHS, LHS, FoundLHS, FoundRHS);
6160   }
6161
6162   // Check whether the actual condition is beyond sufficient.
6163   if (FoundPred == ICmpInst::ICMP_EQ)
6164     if (ICmpInst::isTrueWhenEqual(Pred))
6165       if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
6166         return true;
6167   if (Pred == ICmpInst::ICMP_NE)
6168     if (!ICmpInst::isTrueWhenEqual(FoundPred))
6169       if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
6170         return true;
6171
6172   // Otherwise assume the worst.
6173   return false;
6174 }
6175
6176 /// isImpliedCondOperands - Test whether the condition described by Pred,
6177 /// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
6178 /// and FoundRHS is true.
6179 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
6180                                             const SCEV *LHS, const SCEV *RHS,
6181                                             const SCEV *FoundLHS,
6182                                             const SCEV *FoundRHS) {
6183   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
6184                                      FoundLHS, FoundRHS) ||
6185          // ~x < ~y --> x > y
6186          isImpliedCondOperandsHelper(Pred, LHS, RHS,
6187                                      getNotSCEV(FoundRHS),
6188                                      getNotSCEV(FoundLHS));
6189 }
6190
6191 /// isImpliedCondOperandsHelper - Test whether the condition described by
6192 /// Pred, LHS, and RHS is true whenever the condition described by Pred,
6193 /// FoundLHS, and FoundRHS is true.
6194 bool
6195 ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
6196                                              const SCEV *LHS, const SCEV *RHS,
6197                                              const SCEV *FoundLHS,
6198                                              const SCEV *FoundRHS) {
6199   switch (Pred) {
6200   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
6201   case ICmpInst::ICMP_EQ:
6202   case ICmpInst::ICMP_NE:
6203     if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
6204       return true;
6205     break;
6206   case ICmpInst::ICMP_SLT:
6207   case ICmpInst::ICMP_SLE:
6208     if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
6209         isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
6210       return true;
6211     break;
6212   case ICmpInst::ICMP_SGT:
6213   case ICmpInst::ICMP_SGE:
6214     if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
6215         isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
6216       return true;
6217     break;
6218   case ICmpInst::ICMP_ULT:
6219   case ICmpInst::ICMP_ULE:
6220     if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
6221         isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
6222       return true;
6223     break;
6224   case ICmpInst::ICMP_UGT:
6225   case ICmpInst::ICMP_UGE:
6226     if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
6227         isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
6228       return true;
6229     break;
6230   }
6231
6232   return false;
6233 }
6234
6235 /// getBECount - Subtract the end and start values and divide by the step,
6236 /// rounding up, to get the number of times the backedge is executed. Return
6237 /// CouldNotCompute if an intermediate computation overflows.
6238 const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
6239                                         const SCEV *End,
6240                                         const SCEV *Step,
6241                                         bool NoWrap) {
6242   assert(!isKnownNegative(Step) &&
6243          "This code doesn't handle negative strides yet!");
6244
6245   Type *Ty = Start->getType();
6246
6247   // When Start == End, we have an exact BECount == 0. Short-circuit this case
6248   // here because SCEV may not be able to determine that the unsigned division
6249   // after rounding is zero.
6250   if (Start == End)
6251     return getConstant(Ty, 0);
6252
6253   const SCEV *NegOne = getConstant(Ty, (uint64_t)-1);
6254   const SCEV *Diff = getMinusSCEV(End, Start);
6255   const SCEV *RoundUp = getAddExpr(Step, NegOne);
6256
6257   // Add an adjustment to the difference between End and Start so that
6258   // the division will effectively round up.
6259   const SCEV *Add = getAddExpr(Diff, RoundUp);
6260
6261   if (!NoWrap) {
6262     // Check Add for unsigned overflow.
6263     // TODO: More sophisticated things could be done here.
6264     Type *WideTy = IntegerType::get(getContext(),
6265                                           getTypeSizeInBits(Ty) + 1);
6266     const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
6267     const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
6268     const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
6269     if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
6270       return getCouldNotCompute();
6271   }
6272
6273   return getUDivExpr(Add, Step);
6274 }
6275
6276 /// HowManyLessThans - Return the number of times a backedge containing the
6277 /// specified less-than comparison will execute.  If not computable, return
6278 /// CouldNotCompute.
6279 ScalarEvolution::ExitLimit
6280 ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
6281                                   const Loop *L, bool isSigned) {
6282   // Only handle:  "ADDREC < LoopInvariant".
6283   if (!isLoopInvariant(RHS, L)) return getCouldNotCompute();
6284
6285   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
6286   if (!AddRec || AddRec->getLoop() != L)
6287     return getCouldNotCompute();
6288
6289   // Check to see if we have a flag which makes analysis easy.
6290   bool NoWrap = isSigned ?
6291     AddRec->getNoWrapFlags((SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNW)) :
6292     AddRec->getNoWrapFlags((SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNW));
6293
6294   if (AddRec->isAffine()) {
6295     unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
6296     const SCEV *Step = AddRec->getStepRecurrence(*this);
6297
6298     if (Step->isZero())
6299       return getCouldNotCompute();
6300     if (Step->isOne()) {
6301       // With unit stride, the iteration never steps past the limit value.
6302     } else if (isKnownPositive(Step)) {
6303       // Test whether a positive iteration can step past the limit
6304       // value and past the maximum value for its type in a single step.
6305       // Note that it's not sufficient to check NoWrap here, because even
6306       // though the value after a wrap is undefined, it's not undefined
6307       // behavior, so if wrap does occur, the loop could either terminate or
6308       // loop infinitely, but in either case, the loop is guaranteed to
6309       // iterate at least until the iteration where the wrapping occurs.
6310       const SCEV *One = getConstant(Step->getType(), 1);
6311       if (isSigned) {
6312         APInt Max = APInt::getSignedMaxValue(BitWidth);
6313         if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax())
6314               .slt(getSignedRange(RHS).getSignedMax()))
6315           return getCouldNotCompute();
6316       } else {
6317         APInt Max = APInt::getMaxValue(BitWidth);
6318         if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax())
6319               .ult(getUnsignedRange(RHS).getUnsignedMax()))
6320           return getCouldNotCompute();
6321       }
6322     } else
6323       // TODO: Handle negative strides here and below.
6324       return getCouldNotCompute();
6325
6326     // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
6327     // m.  So, we count the number of iterations in which {n,+,s} < m is true.
6328     // Note that we cannot simply return max(m-n,0)/s because it's not safe to
6329     // treat m-n as signed nor unsigned due to overflow possibility.
6330
6331     // First, we get the value of the LHS in the first iteration: n
6332     const SCEV *Start = AddRec->getOperand(0);
6333
6334     // Determine the minimum constant start value.
6335     const SCEV *MinStart = getConstant(isSigned ?
6336       getSignedRange(Start).getSignedMin() :
6337       getUnsignedRange(Start).getUnsignedMin());
6338
6339     // If we know that the condition is true in order to enter the loop,
6340     // then we know that it will run exactly (m-n)/s times. Otherwise, we
6341     // only know that it will execute (max(m,n)-n)/s times. In both cases,
6342     // the division must round up.
6343     const SCEV *End = RHS;
6344     if (!isLoopEntryGuardedByCond(L,
6345                                   isSigned ? ICmpInst::ICMP_SLT :
6346                                              ICmpInst::ICMP_ULT,
6347                                   getMinusSCEV(Start, Step), RHS))
6348       End = isSigned ? getSMaxExpr(RHS, Start)
6349                      : getUMaxExpr(RHS, Start);
6350
6351     // Determine the maximum constant end value.
6352     const SCEV *MaxEnd = getConstant(isSigned ?
6353       getSignedRange(End).getSignedMax() :
6354       getUnsignedRange(End).getUnsignedMax());
6355
6356     // If MaxEnd is within a step of the maximum integer value in its type,
6357     // adjust it down to the minimum value which would produce the same effect.
6358     // This allows the subsequent ceiling division of (N+(step-1))/step to
6359     // compute the correct value.
6360     const SCEV *StepMinusOne = getMinusSCEV(Step,
6361                                             getConstant(Step->getType(), 1));
6362     MaxEnd = isSigned ?
6363       getSMinExpr(MaxEnd,
6364                   getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)),
6365                                StepMinusOne)) :
6366       getUMinExpr(MaxEnd,
6367                   getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)),
6368                                StepMinusOne));
6369
6370     // Finally, we subtract these two values and divide, rounding up, to get
6371     // the number of times the backedge is executed.
6372     const SCEV *BECount = getBECount(Start, End, Step, NoWrap);
6373
6374     // The maximum backedge count is similar, except using the minimum start
6375     // value and the maximum end value.
6376     // If we already have an exact constant BECount, use it instead.
6377     const SCEV *MaxBECount = isa<SCEVConstant>(BECount) ? BECount
6378       : getBECount(MinStart, MaxEnd, Step, NoWrap);
6379
6380     // If the stride is nonconstant, and NoWrap == true, then
6381     // getBECount(MinStart, MaxEnd) may not compute. This would result in an
6382     // exact BECount and invalid MaxBECount, which should be avoided to catch
6383     // more optimization opportunities.
6384     if (isa<SCEVCouldNotCompute>(MaxBECount))
6385       MaxBECount = BECount;
6386
6387     return ExitLimit(BECount, MaxBECount);
6388   }
6389
6390   return getCouldNotCompute();
6391 }
6392
6393 /// getNumIterationsInRange - Return the number of iterations of this loop that
6394 /// produce values in the specified constant range.  Another way of looking at
6395 /// this is that it returns the first iteration number where the value is not in
6396 /// the condition, thus computing the exit count. If the iteration count can't
6397 /// be computed, an instance of SCEVCouldNotCompute is returned.
6398 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
6399                                                     ScalarEvolution &SE) const {
6400   if (Range.isFullSet())  // Infinite loop.
6401     return SE.getCouldNotCompute();
6402
6403   // If the start is a non-zero constant, shift the range to simplify things.
6404   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
6405     if (!SC->getValue()->isZero()) {
6406       SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
6407       Operands[0] = SE.getConstant(SC->getType(), 0);
6408       const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
6409                                              getNoWrapFlags(FlagNW));
6410       if (const SCEVAddRecExpr *ShiftedAddRec =
6411             dyn_cast<SCEVAddRecExpr>(Shifted))
6412         return ShiftedAddRec->getNumIterationsInRange(
6413                            Range.subtract(SC->getValue()->getValue()), SE);
6414       // This is strange and shouldn't happen.
6415       return SE.getCouldNotCompute();
6416     }
6417
6418   // The only time we can solve this is when we have all constant indices.
6419   // Otherwise, we cannot determine the overflow conditions.
6420   for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
6421     if (!isa<SCEVConstant>(getOperand(i)))
6422       return SE.getCouldNotCompute();
6423
6424
6425   // Okay at this point we know that all elements of the chrec are constants and
6426   // that the start element is zero.
6427
6428   // First check to see if the range contains zero.  If not, the first
6429   // iteration exits.
6430   unsigned BitWidth = SE.getTypeSizeInBits(getType());
6431   if (!Range.contains(APInt(BitWidth, 0)))
6432     return SE.getConstant(getType(), 0);
6433
6434   if (isAffine()) {
6435     // If this is an affine expression then we have this situation:
6436     //   Solve {0,+,A} in Range  ===  Ax in Range
6437
6438     // We know that zero is in the range.  If A is positive then we know that
6439     // the upper value of the range must be the first possible exit value.
6440     // If A is negative then the lower of the range is the last possible loop
6441     // value.  Also note that we already checked for a full range.
6442     APInt One(BitWidth,1);
6443     APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
6444     APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
6445
6446     // The exit value should be (End+A)/A.
6447     APInt ExitVal = (End + A).udiv(A);
6448     ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
6449
6450     // Evaluate at the exit value.  If we really did fall out of the valid
6451     // range, then we computed our trip count, otherwise wrap around or other
6452     // things must have happened.
6453     ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
6454     if (Range.contains(Val->getValue()))
6455       return SE.getCouldNotCompute();  // Something strange happened
6456
6457     // Ensure that the previous value is in the range.  This is a sanity check.
6458     assert(Range.contains(
6459            EvaluateConstantChrecAtConstant(this,
6460            ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
6461            "Linear scev computation is off in a bad way!");
6462     return SE.getConstant(ExitValue);
6463   } else if (isQuadratic()) {
6464     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
6465     // quadratic equation to solve it.  To do this, we must frame our problem in
6466     // terms of figuring out when zero is crossed, instead of when
6467     // Range.getUpper() is crossed.
6468     SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
6469     NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
6470     const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
6471                                              // getNoWrapFlags(FlagNW)
6472                                              FlagAnyWrap);
6473
6474     // Next, solve the constructed addrec
6475     std::pair<const SCEV *,const SCEV *> Roots =
6476       SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
6477     const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
6478     const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
6479     if (R1) {
6480       // Pick the smallest positive root value.
6481       if (ConstantInt *CB =
6482           dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
6483                          R1->getValue(), R2->getValue()))) {
6484         if (CB->getZExtValue() == false)
6485           std::swap(R1, R2);   // R1 is the minimum root now.
6486
6487         // Make sure the root is not off by one.  The returned iteration should
6488         // not be in the range, but the previous one should be.  When solving
6489         // for "X*X < 5", for example, we should not return a root of 2.
6490         ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
6491                                                              R1->getValue(),
6492                                                              SE);
6493         if (Range.contains(R1Val->getValue())) {
6494           // The next iteration must be out of the range...
6495           ConstantInt *NextVal =
6496                 ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
6497
6498           R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6499           if (!Range.contains(R1Val->getValue()))
6500             return SE.getConstant(NextVal);
6501           return SE.getCouldNotCompute();  // Something strange happened
6502         }
6503
6504         // If R1 was not in the range, then it is a good return value.  Make
6505         // sure that R1-1 WAS in the range though, just in case.
6506         ConstantInt *NextVal =
6507                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
6508         R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6509         if (Range.contains(R1Val->getValue()))
6510           return R1;
6511         return SE.getCouldNotCompute();  // Something strange happened
6512       }
6513     }
6514   }
6515
6516   return SE.getCouldNotCompute();
6517 }
6518
6519
6520
6521 //===----------------------------------------------------------------------===//
6522 //                   SCEVCallbackVH Class Implementation
6523 //===----------------------------------------------------------------------===//
6524
6525 void ScalarEvolution::SCEVCallbackVH::deleted() {
6526   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6527   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
6528     SE->ConstantEvolutionLoopExitValue.erase(PN);
6529   SE->ValueExprMap.erase(getValPtr());
6530   // this now dangles!
6531 }
6532
6533 void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
6534   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6535
6536   // Forget all the expressions associated with users of the old value,
6537   // so that future queries will recompute the expressions using the new
6538   // value.
6539   Value *Old = getValPtr();
6540   SmallVector<User *, 16> Worklist;
6541   SmallPtrSet<User *, 8> Visited;
6542   for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
6543        UI != UE; ++UI)
6544     Worklist.push_back(*UI);
6545   while (!Worklist.empty()) {
6546     User *U = Worklist.pop_back_val();
6547     // Deleting the Old value will cause this to dangle. Postpone
6548     // that until everything else is done.
6549     if (U == Old)
6550       continue;
6551     if (!Visited.insert(U))
6552       continue;
6553     if (PHINode *PN = dyn_cast<PHINode>(U))
6554       SE->ConstantEvolutionLoopExitValue.erase(PN);
6555     SE->ValueExprMap.erase(U);
6556     for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
6557          UI != UE; ++UI)
6558       Worklist.push_back(*UI);
6559   }
6560   // Delete the Old value.
6561   if (PHINode *PN = dyn_cast<PHINode>(Old))
6562     SE->ConstantEvolutionLoopExitValue.erase(PN);
6563   SE->ValueExprMap.erase(Old);
6564   // this now dangles!
6565 }
6566
6567 ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
6568   : CallbackVH(V), SE(se) {}
6569
6570 //===----------------------------------------------------------------------===//
6571 //                   ScalarEvolution Class Implementation
6572 //===----------------------------------------------------------------------===//
6573
6574 ScalarEvolution::ScalarEvolution()
6575   : FunctionPass(ID), FirstUnknown(0) {
6576   initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
6577 }
6578
6579 bool ScalarEvolution::runOnFunction(Function &F) {
6580   this->F = &F;
6581   LI = &getAnalysis<LoopInfo>();
6582   TD = getAnalysisIfAvailable<TargetData>();
6583   TLI = &getAnalysis<TargetLibraryInfo>();
6584   DT = &getAnalysis<DominatorTree>();
6585   return false;
6586 }
6587
6588 void ScalarEvolution::releaseMemory() {
6589   // Iterate through all the SCEVUnknown instances and call their
6590   // destructors, so that they release their references to their values.
6591   for (SCEVUnknown *U = FirstUnknown; U; U = U->Next)
6592     U->~SCEVUnknown();
6593   FirstUnknown = 0;
6594
6595   ValueExprMap.clear();
6596
6597   // Free any extra memory created for ExitNotTakenInfo in the unlikely event
6598   // that a loop had multiple computable exits.
6599   for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I =
6600          BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end();
6601        I != E; ++I) {
6602     I->second.clear();
6603   }
6604
6605   BackedgeTakenCounts.clear();
6606   ConstantEvolutionLoopExitValue.clear();
6607   ValuesAtScopes.clear();
6608   LoopDispositions.clear();
6609   BlockDispositions.clear();
6610   UnsignedRanges.clear();
6611   SignedRanges.clear();
6612   UniqueSCEVs.clear();
6613   SCEVAllocator.Reset();
6614 }
6615
6616 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
6617   AU.setPreservesAll();
6618   AU.addRequiredTransitive<LoopInfo>();
6619   AU.addRequiredTransitive<DominatorTree>();
6620   AU.addRequired<TargetLibraryInfo>();
6621 }
6622
6623 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
6624   return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
6625 }
6626
6627 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
6628                           const Loop *L) {
6629   // Print all inner loops first
6630   for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
6631     PrintLoopInfo(OS, SE, *I);
6632
6633   OS << "Loop ";
6634   WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6635   OS << ": ";
6636
6637   SmallVector<BasicBlock *, 8> ExitBlocks;
6638   L->getExitBlocks(ExitBlocks);
6639   if (ExitBlocks.size() != 1)
6640     OS << "<multiple exits> ";
6641
6642   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
6643     OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
6644   } else {
6645     OS << "Unpredictable backedge-taken count. ";
6646   }
6647
6648   OS << "\n"
6649         "Loop ";
6650   WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6651   OS << ": ";
6652
6653   if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
6654     OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
6655   } else {
6656     OS << "Unpredictable max backedge-taken count. ";
6657   }
6658
6659   OS << "\n";
6660 }
6661
6662 void ScalarEvolution::print(raw_ostream &OS, const Module *) const {
6663   // ScalarEvolution's implementation of the print method is to print
6664   // out SCEV values of all instructions that are interesting. Doing
6665   // this potentially causes it to create new SCEV objects though,
6666   // which technically conflicts with the const qualifier. This isn't
6667   // observable from outside the class though, so casting away the
6668   // const isn't dangerous.
6669   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
6670
6671   OS << "Classifying expressions for: ";
6672   WriteAsOperand(OS, F, /*PrintType=*/false);
6673   OS << "\n";
6674   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
6675     if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
6676       OS << *I << '\n';
6677       OS << "  -->  ";
6678       const SCEV *SV = SE.getSCEV(&*I);
6679       SV->print(OS);
6680
6681       const Loop *L = LI->getLoopFor((*I).getParent());
6682
6683       const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
6684       if (AtUse != SV) {
6685         OS << "  -->  ";
6686         AtUse->print(OS);
6687       }
6688
6689       if (L) {
6690         OS << "\t\t" "Exits: ";
6691         const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
6692         if (!SE.isLoopInvariant(ExitValue, L)) {
6693           OS << "<<Unknown>>";
6694         } else {
6695           OS << *ExitValue;
6696         }
6697       }
6698
6699       OS << "\n";
6700     }
6701
6702   OS << "Determining loop execution counts for: ";
6703   WriteAsOperand(OS, F, /*PrintType=*/false);
6704   OS << "\n";
6705   for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
6706     PrintLoopInfo(OS, &SE, *I);
6707 }
6708
6709 ScalarEvolution::LoopDisposition
6710 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
6711   std::map<const Loop *, LoopDisposition> &Values = LoopDispositions[S];
6712   std::pair<std::map<const Loop *, LoopDisposition>::iterator, bool> Pair =
6713     Values.insert(std::make_pair(L, LoopVariant));
6714   if (!Pair.second)
6715     return Pair.first->second;
6716
6717   LoopDisposition D = computeLoopDisposition(S, L);
6718   return LoopDispositions[S][L] = D;
6719 }
6720
6721 ScalarEvolution::LoopDisposition
6722 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
6723   switch (S->getSCEVType()) {
6724   case scConstant:
6725     return LoopInvariant;
6726   case scTruncate:
6727   case scZeroExtend:
6728   case scSignExtend:
6729     return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
6730   case scAddRecExpr: {
6731     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6732
6733     // If L is the addrec's loop, it's computable.
6734     if (AR->getLoop() == L)
6735       return LoopComputable;
6736
6737     // Add recurrences are never invariant in the function-body (null loop).
6738     if (!L)
6739       return LoopVariant;
6740
6741     // This recurrence is variant w.r.t. L if L contains AR's loop.
6742     if (L->contains(AR->getLoop()))
6743       return LoopVariant;
6744
6745     // This recurrence is invariant w.r.t. L if AR's loop contains L.
6746     if (AR->getLoop()->contains(L))
6747       return LoopInvariant;
6748
6749     // This recurrence is variant w.r.t. L if any of its operands
6750     // are variant.
6751     for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
6752          I != E; ++I)
6753       if (!isLoopInvariant(*I, L))
6754         return LoopVariant;
6755
6756     // Otherwise it's loop-invariant.
6757     return LoopInvariant;
6758   }
6759   case scAddExpr:
6760   case scMulExpr:
6761   case scUMaxExpr:
6762   case scSMaxExpr: {
6763     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6764     bool HasVarying = false;
6765     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6766          I != E; ++I) {
6767       LoopDisposition D = getLoopDisposition(*I, L);
6768       if (D == LoopVariant)
6769         return LoopVariant;
6770       if (D == LoopComputable)
6771         HasVarying = true;
6772     }
6773     return HasVarying ? LoopComputable : LoopInvariant;
6774   }
6775   case scUDivExpr: {
6776     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6777     LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
6778     if (LD == LoopVariant)
6779       return LoopVariant;
6780     LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
6781     if (RD == LoopVariant)
6782       return LoopVariant;
6783     return (LD == LoopInvariant && RD == LoopInvariant) ?
6784            LoopInvariant : LoopComputable;
6785   }
6786   case scUnknown:
6787     // All non-instruction values are loop invariant.  All instructions are loop
6788     // invariant if they are not contained in the specified loop.
6789     // Instructions are never considered invariant in the function body
6790     // (null loop) because they are defined within the "loop".
6791     if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
6792       return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
6793     return LoopInvariant;
6794   case scCouldNotCompute:
6795     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6796     return LoopVariant;
6797   default: break;
6798   }
6799   llvm_unreachable("Unknown SCEV kind!");
6800   return LoopVariant;
6801 }
6802
6803 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
6804   return getLoopDisposition(S, L) == LoopInvariant;
6805 }
6806
6807 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
6808   return getLoopDisposition(S, L) == LoopComputable;
6809 }
6810
6811 ScalarEvolution::BlockDisposition
6812 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6813   std::map<const BasicBlock *, BlockDisposition> &Values = BlockDispositions[S];
6814   std::pair<std::map<const BasicBlock *, BlockDisposition>::iterator, bool>
6815     Pair = Values.insert(std::make_pair(BB, DoesNotDominateBlock));
6816   if (!Pair.second)
6817     return Pair.first->second;
6818
6819   BlockDisposition D = computeBlockDisposition(S, BB);
6820   return BlockDispositions[S][BB] = D;
6821 }
6822
6823 ScalarEvolution::BlockDisposition
6824 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6825   switch (S->getSCEVType()) {
6826   case scConstant:
6827     return ProperlyDominatesBlock;
6828   case scTruncate:
6829   case scZeroExtend:
6830   case scSignExtend:
6831     return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
6832   case scAddRecExpr: {
6833     // This uses a "dominates" query instead of "properly dominates" query
6834     // to test for proper dominance too, because the instruction which
6835     // produces the addrec's value is a PHI, and a PHI effectively properly
6836     // dominates its entire containing block.
6837     const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6838     if (!DT->dominates(AR->getLoop()->getHeader(), BB))
6839       return DoesNotDominateBlock;
6840   }
6841   // FALL THROUGH into SCEVNAryExpr handling.
6842   case scAddExpr:
6843   case scMulExpr:
6844   case scUMaxExpr:
6845   case scSMaxExpr: {
6846     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6847     bool Proper = true;
6848     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6849          I != E; ++I) {
6850       BlockDisposition D = getBlockDisposition(*I, BB);
6851       if (D == DoesNotDominateBlock)
6852         return DoesNotDominateBlock;
6853       if (D == DominatesBlock)
6854         Proper = false;
6855     }
6856     return Proper ? ProperlyDominatesBlock : DominatesBlock;
6857   }
6858   case scUDivExpr: {
6859     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6860     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6861     BlockDisposition LD = getBlockDisposition(LHS, BB);
6862     if (LD == DoesNotDominateBlock)
6863       return DoesNotDominateBlock;
6864     BlockDisposition RD = getBlockDisposition(RHS, BB);
6865     if (RD == DoesNotDominateBlock)
6866       return DoesNotDominateBlock;
6867     return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
6868       ProperlyDominatesBlock : DominatesBlock;
6869   }
6870   case scUnknown:
6871     if (Instruction *I =
6872           dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
6873       if (I->getParent() == BB)
6874         return DominatesBlock;
6875       if (DT->properlyDominates(I->getParent(), BB))
6876         return ProperlyDominatesBlock;
6877       return DoesNotDominateBlock;
6878     }
6879     return ProperlyDominatesBlock;
6880   case scCouldNotCompute:
6881     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6882     return DoesNotDominateBlock;
6883   default: break;
6884   }
6885   llvm_unreachable("Unknown SCEV kind!");
6886   return DoesNotDominateBlock;
6887 }
6888
6889 bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
6890   return getBlockDisposition(S, BB) >= DominatesBlock;
6891 }
6892
6893 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
6894   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
6895 }
6896
6897 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
6898   switch (S->getSCEVType()) {
6899   case scConstant:
6900     return false;
6901   case scTruncate:
6902   case scZeroExtend:
6903   case scSignExtend: {
6904     const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6905     const SCEV *CastOp = Cast->getOperand();
6906     return Op == CastOp || hasOperand(CastOp, Op);
6907   }
6908   case scAddRecExpr:
6909   case scAddExpr:
6910   case scMulExpr:
6911   case scUMaxExpr:
6912   case scSMaxExpr: {
6913     const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6914     for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6915          I != E; ++I) {
6916       const SCEV *NAryOp = *I;
6917       if (NAryOp == Op || hasOperand(NAryOp, Op))
6918         return true;
6919     }
6920     return false;
6921   }
6922   case scUDivExpr: {
6923     const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6924     const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6925     return LHS == Op || hasOperand(LHS, Op) ||
6926            RHS == Op || hasOperand(RHS, Op);
6927   }
6928   case scUnknown:
6929     return false;
6930   case scCouldNotCompute:
6931     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6932     return false;
6933   default: break;
6934   }
6935   llvm_unreachable("Unknown SCEV kind!");
6936   return false;
6937 }
6938
6939 void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
6940   ValuesAtScopes.erase(S);
6941   LoopDispositions.erase(S);
6942   BlockDispositions.erase(S);
6943   UnsignedRanges.erase(S);
6944   SignedRanges.erase(S);
6945 }