Change the ExitBlocks list from being explicitly contained in the Loop
[oota-llvm.git] / lib / Transforms / Scalar / IndVarSimplify.cpp
1 //===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This transformation analyzes and transforms the induction variables (and
11 // computations derived from them) into simpler forms suitable for subsequent
12 // analysis and transformation.
13 //
14 // This transformation make the following changes to each loop with an
15 // identifiable induction variable:
16 //   1. All loops are transformed to have a SINGLE canonical induction variable
17 //      which starts at zero and steps by one.
18 //   2. The canonical induction variable is guaranteed to be the first PHI node
19 //      in the loop header block.
20 //   3. Any pointer arithmetic recurrences are raised to use array subscripts.
21 //
22 // If the trip count of a loop is computable, this pass also makes the following
23 // changes:
24 //   1. The exit condition for the loop is canonicalized to compare the
25 //      induction value against the exit value.  This turns loops like:
26 //        'for (i = 7; i*i < 1000; ++i)' into 'for (i = 0; i != 25; ++i)'
27 //   2. Any use outside of the loop of an expression derived from the indvar
28 //      is changed to compute the derived value outside of the loop, eliminating
29 //      the dependence on the exit value of the induction variable.  If the only
30 //      purpose of the loop is to compute the exit value of some derived
31 //      expression, this transformation will make the loop dead.
32 //
33 // This transformation should be followed by strength reduction after all of the
34 // desired loop transformations have been performed.  Additionally, on targets
35 // where it is profitable, the loop could be transformed to count down to zero
36 // (the "do loop" optimization).
37 //
38 //===----------------------------------------------------------------------===//
39
40 #include "llvm/Transforms/Scalar.h"
41 #include "llvm/BasicBlock.h"
42 #include "llvm/Constants.h"
43 #include "llvm/Instructions.h"
44 #include "llvm/Type.h"
45 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
46 #include "llvm/Analysis/LoopInfo.h"
47 #include "llvm/Support/CFG.h"
48 #include "llvm/Transforms/Utils/Local.h"
49 #include "Support/CommandLine.h"
50 #include "Support/Statistic.h"
51 using namespace llvm;
52
53 namespace {
54   Statistic<> NumRemoved ("indvars", "Number of aux indvars removed");
55   Statistic<> NumPointer ("indvars", "Number of pointer indvars promoted");
56   Statistic<> NumInserted("indvars", "Number of canonical indvars added");
57   Statistic<> NumReplaced("indvars", "Number of exit values replaced");
58   Statistic<> NumLFTR    ("indvars", "Number of loop exit tests replaced");
59
60   class IndVarSimplify : public FunctionPass {
61     LoopInfo        *LI;
62     ScalarEvolution *SE;
63     bool Changed;
64   public:
65     virtual bool runOnFunction(Function &) {
66       LI = &getAnalysis<LoopInfo>();
67       SE = &getAnalysis<ScalarEvolution>();
68       Changed = false;
69
70       // Induction Variables live in the header nodes of loops
71       for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
72         runOnLoop(*I);
73       return Changed;
74     }
75
76     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
77       AU.addRequiredID(LoopSimplifyID);
78       AU.addRequired<ScalarEvolution>();
79       AU.addRequired<LoopInfo>();
80       AU.addPreservedID(LoopSimplifyID);
81       AU.setPreservesCFG();
82     }
83   private:
84     void runOnLoop(Loop *L);
85     void EliminatePointerRecurrence(PHINode *PN, BasicBlock *Preheader,
86                                     std::set<Instruction*> &DeadInsts);
87     void LinearFunctionTestReplace(Loop *L, SCEV *IterationCount,
88                                    ScalarEvolutionRewriter &RW);
89     void RewriteLoopExitValues(Loop *L);
90
91     void DeleteTriviallyDeadInstructions(std::set<Instruction*> &Insts);
92   };
93   RegisterOpt<IndVarSimplify> X("indvars", "Canonicalize Induction Variables");
94 }
95
96 Pass *llvm::createIndVarSimplifyPass() {
97   return new IndVarSimplify();
98 }
99
100
101 /// DeleteTriviallyDeadInstructions - If any of the instructions is the
102 /// specified set are trivially dead, delete them and see if this makes any of
103 /// their operands subsequently dead.
104 void IndVarSimplify::
105 DeleteTriviallyDeadInstructions(std::set<Instruction*> &Insts) {
106   while (!Insts.empty()) {
107     Instruction *I = *Insts.begin();
108     Insts.erase(Insts.begin());
109     if (isInstructionTriviallyDead(I)) {
110       for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
111         if (Instruction *U = dyn_cast<Instruction>(I->getOperand(i)))
112           Insts.insert(U);
113       SE->deleteInstructionFromRecords(I);
114       I->getParent()->getInstList().erase(I);
115       Changed = true;
116     }
117   }
118 }
119
120
121 /// EliminatePointerRecurrence - Check to see if this is a trivial GEP pointer
122 /// recurrence.  If so, change it into an integer recurrence, permitting
123 /// analysis by the SCEV routines.
124 void IndVarSimplify::EliminatePointerRecurrence(PHINode *PN, 
125                                                 BasicBlock *Preheader,
126                                             std::set<Instruction*> &DeadInsts) {
127   assert(PN->getNumIncomingValues() == 2 && "Noncanonicalized loop!");
128   unsigned PreheaderIdx = PN->getBasicBlockIndex(Preheader);
129   unsigned BackedgeIdx = PreheaderIdx^1;
130   if (GetElementPtrInst *GEPI =
131       dyn_cast<GetElementPtrInst>(PN->getIncomingValue(BackedgeIdx)))
132     if (GEPI->getOperand(0) == PN) {
133       assert(GEPI->getNumOperands() == 2 && "GEP types must mismatch!");
134           
135       // Okay, we found a pointer recurrence.  Transform this pointer
136       // recurrence into an integer recurrence.  Compute the value that gets
137       // added to the pointer at every iteration.
138       Value *AddedVal = GEPI->getOperand(1);
139
140       // Insert a new integer PHI node into the top of the block.
141       PHINode *NewPhi = new PHINode(AddedVal->getType(),
142                                     PN->getName()+".rec", PN);
143       NewPhi->addIncoming(Constant::getNullValue(NewPhi->getType()),
144                           Preheader);
145       // Create the new add instruction.
146       Value *NewAdd = BinaryOperator::create(Instruction::Add, NewPhi,
147                                              AddedVal,
148                                              GEPI->getName()+".rec", GEPI);
149       NewPhi->addIncoming(NewAdd, PN->getIncomingBlock(BackedgeIdx));
150           
151       // Update the existing GEP to use the recurrence.
152       GEPI->setOperand(0, PN->getIncomingValue(PreheaderIdx));
153           
154       // Update the GEP to use the new recurrence we just inserted.
155       GEPI->setOperand(1, NewAdd);
156
157       // Finally, if there are any other users of the PHI node, we must
158       // insert a new GEP instruction that uses the pre-incremented version
159       // of the induction amount.
160       if (!PN->use_empty()) {
161         BasicBlock::iterator InsertPos = PN; ++InsertPos;
162         while (isa<PHINode>(InsertPos)) ++InsertPos;
163         std::string Name = PN->getName(); PN->setName("");
164         Value *PreInc =
165           new GetElementPtrInst(PN->getIncomingValue(PreheaderIdx),
166                                 std::vector<Value*>(1, NewPhi), Name,
167                                 InsertPos);
168         PN->replaceAllUsesWith(PreInc);
169       }
170
171       // Delete the old PHI for sure, and the GEP if its otherwise unused.
172       DeadInsts.insert(PN);
173
174       ++NumPointer;
175       Changed = true;
176     }
177 }
178
179 /// LinearFunctionTestReplace - This method rewrites the exit condition of the
180 /// loop to be a canonical != comparison against the incremented loop induction
181 /// variable.  This pass is able to rewrite the exit tests of any loop where the
182 /// SCEV analysis can determine a loop-invariant trip count of the loop, which
183 /// is actually a much broader range than just linear tests.
184 void IndVarSimplify::LinearFunctionTestReplace(Loop *L, SCEV *IterationCount,
185                                                ScalarEvolutionRewriter &RW) {
186   // Find the exit block for the loop.  We can currently only handle loops with
187   // a single exit.
188   std::vector<BasicBlock*> ExitBlocks;
189   L->getExitBlocks(ExitBlocks);
190   if (ExitBlocks.size() != 1) return;
191   BasicBlock *ExitBlock = ExitBlocks[0];
192
193   // Make sure there is only one predecessor block in the loop.
194   BasicBlock *ExitingBlock = 0;
195   for (pred_iterator PI = pred_begin(ExitBlock), PE = pred_end(ExitBlock);
196        PI != PE; ++PI)
197     if (L->contains(*PI)) {
198       if (ExitingBlock == 0)
199         ExitingBlock = *PI;
200       else
201         return;  // Multiple exits from loop to this block.
202     }
203   assert(ExitingBlock && "Loop info is broken");
204
205   if (!isa<BranchInst>(ExitingBlock->getTerminator()))
206     return;  // Can't rewrite non-branch yet
207   BranchInst *BI = cast<BranchInst>(ExitingBlock->getTerminator());
208   assert(BI->isConditional() && "Must be conditional to be part of loop!");
209
210   std::set<Instruction*> InstructionsToDelete;
211   if (Instruction *Cond = dyn_cast<Instruction>(BI->getCondition()))
212     InstructionsToDelete.insert(Cond);
213
214   // If the exiting block is not the same as the backedge block, we must compare
215   // against the preincremented value, otherwise we prefer to compare against
216   // the post-incremented value.
217   BasicBlock *Header = L->getHeader();
218   pred_iterator HPI = pred_begin(Header);
219   assert(HPI != pred_end(Header) && "Loop with zero preds???");
220   if (!L->contains(*HPI)) ++HPI;
221   assert(HPI != pred_end(Header) && L->contains(*HPI) &&
222          "No backedge in loop?");
223
224   SCEVHandle TripCount = IterationCount;
225   Value *IndVar;
226   if (*HPI == ExitingBlock) {
227     // The IterationCount expression contains the number of times that the
228     // backedge actually branches to the loop header.  This is one less than the
229     // number of times the loop executes, so add one to it.
230     Constant *OneC = ConstantInt::get(IterationCount->getType(), 1);
231     TripCount = SCEVAddExpr::get(IterationCount, SCEVUnknown::get(OneC));
232     IndVar = L->getCanonicalInductionVariableIncrement();
233   } else {
234     // We have to use the preincremented value...
235     IndVar = L->getCanonicalInductionVariable();
236   }
237
238   // Expand the code for the iteration count into the preheader of the loop.
239   BasicBlock *Preheader = L->getLoopPreheader();
240   Value *ExitCnt = RW.ExpandCodeFor(TripCount, Preheader->getTerminator(),
241                                     IndVar->getType());
242
243   // Insert a new setne or seteq instruction before the branch.
244   Instruction::BinaryOps Opcode;
245   if (L->contains(BI->getSuccessor(0)))
246     Opcode = Instruction::SetNE;
247   else
248     Opcode = Instruction::SetEQ;
249
250   Value *Cond = new SetCondInst(Opcode, IndVar, ExitCnt, "exitcond", BI);
251   BI->setCondition(Cond);
252   ++NumLFTR;
253   Changed = true;
254
255   DeleteTriviallyDeadInstructions(InstructionsToDelete);
256 }
257
258
259 /// RewriteLoopExitValues - Check to see if this loop has a computable
260 /// loop-invariant execution count.  If so, this means that we can compute the
261 /// final value of any expressions that are recurrent in the loop, and
262 /// substitute the exit values from the loop into any instructions outside of
263 /// the loop that use the final values of the current expressions.
264 void IndVarSimplify::RewriteLoopExitValues(Loop *L) {
265   BasicBlock *Preheader = L->getLoopPreheader();
266
267   // Scan all of the instructions in the loop, looking at those that have
268   // extra-loop users and which are recurrences.
269   ScalarEvolutionRewriter Rewriter(*SE, *LI);
270
271   // We insert the code into the preheader of the loop if the loop contains
272   // multiple exit blocks, or in the exit block if there is exactly one.
273   BasicBlock *BlockToInsertInto;
274   std::vector<BasicBlock*> ExitBlocks;
275   L->getExitBlocks(ExitBlocks);
276   if (ExitBlocks.size() == 1)
277     BlockToInsertInto = ExitBlocks[0];
278   else
279     BlockToInsertInto = Preheader;
280   BasicBlock::iterator InsertPt = BlockToInsertInto->begin();
281   while (isa<PHINode>(InsertPt)) ++InsertPt;
282
283   bool HasConstantItCount = isa<SCEVConstant>(SE->getIterationCount(L));
284
285   std::set<Instruction*> InstructionsToDelete;
286   
287   for (unsigned i = 0, e = L->getBlocks().size(); i != e; ++i)
288     if (LI->getLoopFor(L->getBlocks()[i]) == L) {  // Not in a subloop...
289       BasicBlock *BB = L->getBlocks()[i];
290       for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
291         if (I->getType()->isInteger()) {      // Is an integer instruction
292           SCEVHandle SH = SE->getSCEV(I);
293           if (SH->hasComputableLoopEvolution(L) ||    // Varies predictably
294               HasConstantItCount) {
295             // Find out if this predictably varying value is actually used
296             // outside of the loop.  "extra" as opposed to "intra".
297             std::vector<User*> ExtraLoopUsers;
298             for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
299                  UI != E; ++UI)
300               if (!L->contains(cast<Instruction>(*UI)->getParent()))
301                 ExtraLoopUsers.push_back(*UI);
302             if (!ExtraLoopUsers.empty()) {
303               // Okay, this instruction has a user outside of the current loop
304               // and varies predictably in this loop.  Evaluate the value it
305               // contains when the loop exits, and insert code for it.
306               SCEVHandle ExitValue = SE->getSCEVAtScope(I, L->getParentLoop());
307               if (!isa<SCEVCouldNotCompute>(ExitValue)) {
308                 Changed = true;
309                 ++NumReplaced;
310                 Value *NewVal = Rewriter.ExpandCodeFor(ExitValue, InsertPt,
311                                                        I->getType());
312
313                 // Rewrite any users of the computed value outside of the loop
314                 // with the newly computed value.
315                 for (unsigned i = 0, e = ExtraLoopUsers.size(); i != e; ++i)
316                   ExtraLoopUsers[i]->replaceUsesOfWith(I, NewVal);
317
318                 // If this instruction is dead now, schedule it to be removed.
319                 if (I->use_empty())
320                   InstructionsToDelete.insert(I);
321               }
322             }
323           }
324         }
325     }
326
327   DeleteTriviallyDeadInstructions(InstructionsToDelete);
328 }
329
330
331 void IndVarSimplify::runOnLoop(Loop *L) {
332   // First step.  Check to see if there are any trivial GEP pointer recurrences.
333   // If there are, change them into integer recurrences, permitting analysis by
334   // the SCEV routines.
335   //
336   BasicBlock *Header    = L->getHeader();
337   BasicBlock *Preheader = L->getLoopPreheader();
338   
339   std::set<Instruction*> DeadInsts;
340   for (BasicBlock::iterator I = Header->begin();
341        PHINode *PN = dyn_cast<PHINode>(I); ++I)
342     if (isa<PointerType>(PN->getType()))
343       EliminatePointerRecurrence(PN, Preheader, DeadInsts);
344
345   if (!DeadInsts.empty())
346     DeleteTriviallyDeadInstructions(DeadInsts);
347
348
349   // Next, transform all loops nesting inside of this loop.
350   for (LoopInfo::iterator I = L->begin(), E = L->end(); I != E; ++I)
351     runOnLoop(*I);
352
353   // Check to see if this loop has a computable loop-invariant execution count.
354   // If so, this means that we can compute the final value of any expressions
355   // that are recurrent in the loop, and substitute the exit values from the
356   // loop into any instructions outside of the loop that use the final values of
357   // the current expressions.
358   //
359   SCEVHandle IterationCount = SE->getIterationCount(L);
360   if (!isa<SCEVCouldNotCompute>(IterationCount))
361     RewriteLoopExitValues(L);
362
363   // Next, analyze all of the induction variables in the loop, canonicalizing
364   // auxillary induction variables.
365   std::vector<std::pair<PHINode*, SCEVHandle> > IndVars;
366
367   for (BasicBlock::iterator I = Header->begin();
368        PHINode *PN = dyn_cast<PHINode>(I); ++I)
369     if (PN->getType()->isInteger()) {  // FIXME: when we have fast-math, enable!
370       SCEVHandle SCEV = SE->getSCEV(PN);
371       if (SCEV->hasComputableLoopEvolution(L))
372         if (SE->shouldSubstituteIndVar(SCEV))  // HACK!
373           IndVars.push_back(std::make_pair(PN, SCEV));
374     }
375
376   // If there are no induction variables in the loop, there is nothing more to
377   // do.
378   if (IndVars.empty()) {
379     // Actually, if we know how many times the loop iterates, lets insert a
380     // canonical induction variable to help subsequent passes.
381     if (!isa<SCEVCouldNotCompute>(IterationCount)) {
382       ScalarEvolutionRewriter Rewriter(*SE, *LI);
383       Rewriter.GetOrInsertCanonicalInductionVariable(L,
384                                                      IterationCount->getType());
385       LinearFunctionTestReplace(L, IterationCount, Rewriter);
386     }
387     return;
388   }
389
390   // Compute the type of the largest recurrence expression.
391   //
392   const Type *LargestType = IndVars[0].first->getType();
393   bool DifferingSizes = false;
394   for (unsigned i = 1, e = IndVars.size(); i != e; ++i) {
395     const Type *Ty = IndVars[i].first->getType();
396     DifferingSizes |= Ty->getPrimitiveSize() != LargestType->getPrimitiveSize();
397     if (Ty->getPrimitiveSize() > LargestType->getPrimitiveSize())
398       LargestType = Ty;
399   }
400
401   // Create a rewriter object which we'll use to transform the code with.
402   ScalarEvolutionRewriter Rewriter(*SE, *LI);
403
404   // Now that we know the largest of of the induction variables in this loop,
405   // insert a canonical induction variable of the largest size.
406   LargestType = LargestType->getUnsignedVersion();
407   Value *IndVar = Rewriter.GetOrInsertCanonicalInductionVariable(L,LargestType);
408   ++NumInserted;
409   Changed = true;
410
411   if (!isa<SCEVCouldNotCompute>(IterationCount))
412     LinearFunctionTestReplace(L, IterationCount, Rewriter);
413
414 #if 0
415   // If there were induction variables of other sizes, cast the primary
416   // induction variable to the right size for them, avoiding the need for the
417   // code evaluation methods to insert induction variables of different sizes.
418   // FIXME!
419   if (DifferingSizes) {
420     std::map<unsigned, Value*> InsertedSizes;
421     for (unsigned i = 0, e = IndVars.size(); i != e; ++i) {
422     }    
423   }
424 #endif
425
426   // Now that we have a canonical induction variable, we can rewrite any
427   // recurrences in terms of the induction variable.  Start with the auxillary
428   // induction variables, and recursively rewrite any of their uses.
429   BasicBlock::iterator InsertPt = Header->begin();
430   while (isa<PHINode>(InsertPt)) ++InsertPt;
431
432   while (!IndVars.empty()) {
433     PHINode *PN = IndVars.back().first;
434     Value *NewVal = Rewriter.ExpandCodeFor(IndVars.back().second, InsertPt,
435                                            PN->getType());
436     // Replace the old PHI Node with the inserted computation.
437     PN->replaceAllUsesWith(NewVal);
438     DeadInsts.insert(PN);
439     IndVars.pop_back();
440     ++NumRemoved;
441     Changed = true;
442   }
443
444   DeleteTriviallyDeadInstructions(DeadInsts);
445
446   // TODO: In the future we could replace all instructions in the loop body with
447   // simpler expressions.  It's not clear how useful this would be though or if
448   // the code expansion cost would be worth it!  We probably shouldn't do this
449   // until we have a way to reuse expressions already in the code.
450 #if 0
451   for (unsigned i = 0, e = L->getBlocks().size(); i != e; ++i)
452     if (LI->getLoopFor(L->getBlocks()[i]) == L) {  // Not in a subloop...
453       BasicBlock *BB = L->getBlocks()[i];
454       for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
455         if (I->getType()->isInteger() &&      // Is an integer instruction
456             !Rewriter.isInsertedInstruction(I)) {
457           SCEVHandle SH = SE->getSCEV(I);
458         }
459     }
460 #endif
461 }