faaab5c12e4cfb1807e6c552935d04f8eba63531
[oota-llvm.git] / lib / Transforms / Utils / LoopUnroll.cpp
1 //===-- UnrollLoop.cpp - Loop unrolling utilities -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements some loop unrolling utilities. It does not define any
11 // actual pass or policy, but provides a single function to perform loop
12 // unrolling.
13 //
14 // The process of unrolling can produce extraneous basic blocks linked with
15 // unconditional branches.  This will be corrected in the future.
16 //
17 //===----------------------------------------------------------------------===//
18
19 #include "llvm/Transforms/Utils/UnrollLoop.h"
20 #include "llvm/ADT/Statistic.h"
21 #include "llvm/Analysis/InstructionSimplify.h"
22 #include "llvm/Analysis/LoopIterator.h"
23 #include "llvm/Analysis/LoopPass.h"
24 #include "llvm/Analysis/ScalarEvolution.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Dominators.h"
27 #include "llvm/IR/LLVMContext.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/raw_ostream.h"
30 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
31 #include "llvm/Transforms/Utils/Cloning.h"
32 #include "llvm/Transforms/Utils/Local.h"
33 #include "llvm/Transforms/Utils/LoopUtils.h"
34 #include "llvm/Transforms/Utils/SimplifyIndVar.h"
35 using namespace llvm;
36
37 #define DEBUG_TYPE "loop-unroll"
38
39 // TODO: Should these be here or in LoopUnroll?
40 STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled");
41 STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)");
42
43 /// RemapInstruction - Convert the instruction operands from referencing the
44 /// current values into those specified by VMap.
45 static inline void RemapInstruction(Instruction *I,
46                                     ValueToValueMapTy &VMap) {
47   for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) {
48     Value *Op = I->getOperand(op);
49     ValueToValueMapTy::iterator It = VMap.find(Op);
50     if (It != VMap.end())
51       I->setOperand(op, It->second);
52   }
53
54   if (PHINode *PN = dyn_cast<PHINode>(I)) {
55     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
56       ValueToValueMapTy::iterator It = VMap.find(PN->getIncomingBlock(i));
57       if (It != VMap.end())
58         PN->setIncomingBlock(i, cast<BasicBlock>(It->second));
59     }
60   }
61 }
62
63 /// FoldBlockIntoPredecessor - Folds a basic block into its predecessor if it
64 /// only has one predecessor, and that predecessor only has one successor.
65 /// The LoopInfo Analysis that is passed will be kept consistent.
66 /// Returns the new combined block.
67 static BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB, LoopInfo* LI,
68                                             LPPassManager *LPM) {
69   // Merge basic blocks into their predecessor if there is only one distinct
70   // pred, and if there is only one distinct successor of the predecessor, and
71   // if there are no PHI nodes.
72   BasicBlock *OnlyPred = BB->getSinglePredecessor();
73   if (!OnlyPred) return nullptr;
74
75   if (OnlyPred->getTerminator()->getNumSuccessors() != 1)
76     return nullptr;
77
78   DEBUG(dbgs() << "Merging: " << *BB << "into: " << *OnlyPred);
79
80   // Resolve any PHI nodes at the start of the block.  They are all
81   // guaranteed to have exactly one entry if they exist, unless there are
82   // multiple duplicate (but guaranteed to be equal) entries for the
83   // incoming edges.  This occurs when there are multiple edges from
84   // OnlyPred to OnlySucc.
85   FoldSingleEntryPHINodes(BB);
86
87   // Delete the unconditional branch from the predecessor...
88   OnlyPred->getInstList().pop_back();
89
90   // Make all PHI nodes that referred to BB now refer to Pred as their
91   // source...
92   BB->replaceAllUsesWith(OnlyPred);
93
94   // Move all definitions in the successor to the predecessor...
95   OnlyPred->getInstList().splice(OnlyPred->end(), BB->getInstList());
96
97   // OldName will be valid until erased.
98   StringRef OldName = BB->getName();
99
100   // Erase basic block from the function...
101
102   // ScalarEvolution holds references to loop exit blocks.
103   if (LPM) {
104     if (ScalarEvolution *SE = LPM->getAnalysisIfAvailable<ScalarEvolution>()) {
105       if (Loop *L = LI->getLoopFor(BB))
106         SE->forgetLoop(L);
107     }
108   }
109   LI->removeBlock(BB);
110
111   // Inherit predecessor's name if it exists...
112   if (!OldName.empty() && !OnlyPred->hasName())
113     OnlyPred->setName(OldName);
114
115   BB->eraseFromParent();
116
117   return OnlyPred;
118 }
119
120 /// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true
121 /// if unrolling was successful, or false if the loop was unmodified. Unrolling
122 /// can only fail when the loop's latch block is not terminated by a conditional
123 /// branch instruction. However, if the trip count (and multiple) are not known,
124 /// loop unrolling will mostly produce more code that is no faster.
125 ///
126 /// TripCount is generally defined as the number of times the loop header
127 /// executes. UnrollLoop relaxes the definition to permit early exits: here
128 /// TripCount is the iteration on which control exits LatchBlock if no early
129 /// exits were taken. Note that UnrollLoop assumes that the loop counter test
130 /// terminates LatchBlock in order to remove unnecesssary instances of the
131 /// test. In other words, control may exit the loop prior to TripCount
132 /// iterations via an early branch, but control may not exit the loop from the
133 /// LatchBlock's terminator prior to TripCount iterations.
134 ///
135 /// Similarly, TripMultiple divides the number of times that the LatchBlock may
136 /// execute without exiting the loop.
137 ///
138 /// The LoopInfo Analysis that is passed will be kept consistent.
139 ///
140 /// If a LoopPassManager is passed in, and the loop is fully removed, it will be
141 /// removed from the LoopPassManager as well. LPM can also be NULL.
142 ///
143 /// This utility preserves LoopInfo. If DominatorTree or ScalarEvolution are
144 /// available from the Pass it must also preserve those analyses.
145 bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount,
146                       bool AllowRuntime, unsigned TripMultiple,
147                       LoopInfo *LI, Pass *PP, LPPassManager *LPM) {
148   BasicBlock *Preheader = L->getLoopPreheader();
149   if (!Preheader) {
150     DEBUG(dbgs() << "  Can't unroll; loop preheader-insertion failed.\n");
151     return false;
152   }
153
154   BasicBlock *LatchBlock = L->getLoopLatch();
155   if (!LatchBlock) {
156     DEBUG(dbgs() << "  Can't unroll; loop exit-block-insertion failed.\n");
157     return false;
158   }
159
160   // Loops with indirectbr cannot be cloned.
161   if (!L->isSafeToClone()) {
162     DEBUG(dbgs() << "  Can't unroll; Loop body cannot be cloned.\n");
163     return false;
164   }
165
166   BasicBlock *Header = L->getHeader();
167   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
168
169   if (!BI || BI->isUnconditional()) {
170     // The loop-rotate pass can be helpful to avoid this in many cases.
171     DEBUG(dbgs() <<
172              "  Can't unroll; loop not terminated by a conditional branch.\n");
173     return false;
174   }
175
176   if (Header->hasAddressTaken()) {
177     // The loop-rotate pass can be helpful to avoid this in many cases.
178     DEBUG(dbgs() <<
179           "  Won't unroll loop: address of header block is taken.\n");
180     return false;
181   }
182
183   if (TripCount != 0)
184     DEBUG(dbgs() << "  Trip Count = " << TripCount << "\n");
185   if (TripMultiple != 1)
186     DEBUG(dbgs() << "  Trip Multiple = " << TripMultiple << "\n");
187
188   // Effectively "DCE" unrolled iterations that are beyond the tripcount
189   // and will never be executed.
190   if (TripCount != 0 && Count > TripCount)
191     Count = TripCount;
192
193   // Don't enter the unroll code if there is nothing to do. This way we don't
194   // need to support "partial unrolling by 1".
195   if (TripCount == 0 && Count < 2)
196     return false;
197
198   assert(Count > 0);
199   assert(TripMultiple > 0);
200   assert(TripCount == 0 || TripCount % TripMultiple == 0);
201
202   // Are we eliminating the loop control altogether?
203   bool CompletelyUnroll = Count == TripCount;
204
205   // We assume a run-time trip count if the compiler cannot
206   // figure out the loop trip count and the unroll-runtime
207   // flag is specified.
208   bool RuntimeTripCount = (TripCount == 0 && Count > 0 && AllowRuntime);
209
210   if (RuntimeTripCount && !UnrollRuntimeLoopProlog(L, Count, LI, LPM))
211     return false;
212
213   // Notify ScalarEvolution that the loop will be substantially changed,
214   // if not outright eliminated.
215   if (PP) {
216     ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>();
217     if (SE)
218       SE->forgetLoop(L);
219   }
220
221   // If we know the trip count, we know the multiple...
222   unsigned BreakoutTrip = 0;
223   if (TripCount != 0) {
224     BreakoutTrip = TripCount % Count;
225     TripMultiple = 0;
226   } else {
227     // Figure out what multiple to use.
228     BreakoutTrip = TripMultiple =
229       (unsigned)GreatestCommonDivisor64(Count, TripMultiple);
230   }
231
232   // Report the unrolling decision.
233   DebugLoc LoopLoc = L->getStartLoc();
234   Function *F = Header->getParent();
235   LLVMContext &Ctx = F->getContext();
236
237   if (CompletelyUnroll) {
238     DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName()
239           << " with trip count " << TripCount << "!\n");
240     Ctx.emitOptimizationRemark(DEBUG_TYPE, *F, LoopLoc,
241                                Twine("completely unrolled loop with ") +
242                                    Twine(TripCount) + " iterations");
243   } else {
244     DEBUG(dbgs() << "UNROLLING loop %" << Header->getName()
245           << " by " << Count);
246     Twine DiagMsg("unrolled loop by a factor of " + Twine(Count));
247     if (TripMultiple == 0 || BreakoutTrip != TripMultiple) {
248       DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip);
249       DiagMsg.concat(" with a breakout at trip " + Twine(BreakoutTrip));
250     } else if (TripMultiple != 1) {
251       DEBUG(dbgs() << " with " << TripMultiple << " trips per branch");
252       DiagMsg.concat(" with " + Twine(TripMultiple) + " trips per branch");
253     } else if (RuntimeTripCount) {
254       DEBUG(dbgs() << " with run-time trip count");
255       DiagMsg.concat(" with run-time trip count");
256     }
257     DEBUG(dbgs() << "!\n");
258     Ctx.emitOptimizationRemark(DEBUG_TYPE, *F, LoopLoc, DiagMsg);
259   }
260
261   bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
262   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
263
264   // For the first iteration of the loop, we should use the precloned values for
265   // PHI nodes.  Insert associations now.
266   ValueToValueMapTy LastValueMap;
267   std::vector<PHINode*> OrigPHINode;
268   for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
269     OrigPHINode.push_back(cast<PHINode>(I));
270   }
271
272   std::vector<BasicBlock*> Headers;
273   std::vector<BasicBlock*> Latches;
274   Headers.push_back(Header);
275   Latches.push_back(LatchBlock);
276
277   // The current on-the-fly SSA update requires blocks to be processed in
278   // reverse postorder so that LastValueMap contains the correct value at each
279   // exit.
280   LoopBlocksDFS DFS(L);
281   DFS.perform(LI);
282
283   // Stash the DFS iterators before adding blocks to the loop.
284   LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
285   LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
286
287   for (unsigned It = 1; It != Count; ++It) {
288     std::vector<BasicBlock*> NewBlocks;
289
290     for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
291       ValueToValueMapTy VMap;
292       BasicBlock *New = CloneBasicBlock(*BB, VMap, "." + Twine(It));
293       Header->getParent()->getBasicBlockList().push_back(New);
294
295       // Loop over all of the PHI nodes in the block, changing them to use the
296       // incoming values from the previous block.
297       if (*BB == Header)
298         for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
299           PHINode *NewPHI = cast<PHINode>(VMap[OrigPHINode[i]]);
300           Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock);
301           if (Instruction *InValI = dyn_cast<Instruction>(InVal))
302             if (It > 1 && L->contains(InValI))
303               InVal = LastValueMap[InValI];
304           VMap[OrigPHINode[i]] = InVal;
305           New->getInstList().erase(NewPHI);
306         }
307
308       // Update our running map of newest clones
309       LastValueMap[*BB] = New;
310       for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
311            VI != VE; ++VI)
312         LastValueMap[VI->first] = VI->second;
313
314       L->addBasicBlockToLoop(New, LI->getBase());
315
316       // Add phi entries for newly created values to all exit blocks.
317       for (succ_iterator SI = succ_begin(*BB), SE = succ_end(*BB);
318            SI != SE; ++SI) {
319         if (L->contains(*SI))
320           continue;
321         for (BasicBlock::iterator BBI = (*SI)->begin();
322              PHINode *phi = dyn_cast<PHINode>(BBI); ++BBI) {
323           Value *Incoming = phi->getIncomingValueForBlock(*BB);
324           ValueToValueMapTy::iterator It = LastValueMap.find(Incoming);
325           if (It != LastValueMap.end())
326             Incoming = It->second;
327           phi->addIncoming(Incoming, New);
328         }
329       }
330       // Keep track of new headers and latches as we create them, so that
331       // we can insert the proper branches later.
332       if (*BB == Header)
333         Headers.push_back(New);
334       if (*BB == LatchBlock)
335         Latches.push_back(New);
336
337       NewBlocks.push_back(New);
338     }
339
340     // Remap all instructions in the most recent iteration
341     for (unsigned i = 0; i < NewBlocks.size(); ++i)
342       for (BasicBlock::iterator I = NewBlocks[i]->begin(),
343            E = NewBlocks[i]->end(); I != E; ++I)
344         ::RemapInstruction(I, LastValueMap);
345   }
346
347   // Loop over the PHI nodes in the original block, setting incoming values.
348   for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
349     PHINode *PN = OrigPHINode[i];
350     if (CompletelyUnroll) {
351       PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
352       Header->getInstList().erase(PN);
353     }
354     else if (Count > 1) {
355       Value *InVal = PN->removeIncomingValue(LatchBlock, false);
356       // If this value was defined in the loop, take the value defined by the
357       // last iteration of the loop.
358       if (Instruction *InValI = dyn_cast<Instruction>(InVal)) {
359         if (L->contains(InValI))
360           InVal = LastValueMap[InVal];
361       }
362       assert(Latches.back() == LastValueMap[LatchBlock] && "bad last latch");
363       PN->addIncoming(InVal, Latches.back());
364     }
365   }
366
367   // Now that all the basic blocks for the unrolled iterations are in place,
368   // set up the branches to connect them.
369   for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
370     // The original branch was replicated in each unrolled iteration.
371     BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator());
372
373     // The branch destination.
374     unsigned j = (i + 1) % e;
375     BasicBlock *Dest = Headers[j];
376     bool NeedConditional = true;
377
378     if (RuntimeTripCount && j != 0) {
379       NeedConditional = false;
380     }
381
382     // For a complete unroll, make the last iteration end with a branch
383     // to the exit block.
384     if (CompletelyUnroll && j == 0) {
385       Dest = LoopExit;
386       NeedConditional = false;
387     }
388
389     // If we know the trip count or a multiple of it, we can safely use an
390     // unconditional branch for some iterations.
391     if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
392       NeedConditional = false;
393     }
394
395     if (NeedConditional) {
396       // Update the conditional branch's successor for the following
397       // iteration.
398       Term->setSuccessor(!ContinueOnTrue, Dest);
399     } else {
400       // Remove phi operands at this loop exit
401       if (Dest != LoopExit) {
402         BasicBlock *BB = Latches[i];
403         for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB);
404              SI != SE; ++SI) {
405           if (*SI == Headers[i])
406             continue;
407           for (BasicBlock::iterator BBI = (*SI)->begin();
408                PHINode *Phi = dyn_cast<PHINode>(BBI); ++BBI) {
409             Phi->removeIncomingValue(BB, false);
410           }
411         }
412       }
413       // Replace the conditional branch with an unconditional one.
414       BranchInst::Create(Dest, Term);
415       Term->eraseFromParent();
416     }
417   }
418
419   // Merge adjacent basic blocks, if possible.
420   for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
421     BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator());
422     if (Term->isUnconditional()) {
423       BasicBlock *Dest = Term->getSuccessor(0);
424       if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest, LI, LPM))
425         std::replace(Latches.begin(), Latches.end(), Dest, Fold);
426     }
427   }
428
429   DominatorTree *DT = nullptr;
430   if (PP) {
431     // FIXME: Reconstruct dom info, because it is not preserved properly.
432     // Incrementally updating domtree after loop unrolling would be easy.
433     if (DominatorTreeWrapperPass *DTWP =
434             PP->getAnalysisIfAvailable<DominatorTreeWrapperPass>()) {
435       DT = &DTWP->getDomTree();
436       DT->recalculate(*L->getHeader()->getParent());
437     }
438
439     // Simplify any new induction variables in the partially unrolled loop.
440     ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>();
441     if (SE && !CompletelyUnroll) {
442       SmallVector<WeakVH, 16> DeadInsts;
443       simplifyLoopIVs(L, SE, LPM, DeadInsts);
444
445       // Aggressively clean up dead instructions that simplifyLoopIVs already
446       // identified. Any remaining should be cleaned up below.
447       while (!DeadInsts.empty())
448         if (Instruction *Inst =
449             dyn_cast_or_null<Instruction>(&*DeadInsts.pop_back_val()))
450           RecursivelyDeleteTriviallyDeadInstructions(Inst);
451     }
452   }
453   // At this point, the code is well formed.  We now do a quick sweep over the
454   // inserted code, doing constant propagation and dead code elimination as we
455   // go.
456   const std::vector<BasicBlock*> &NewLoopBlocks = L->getBlocks();
457   for (std::vector<BasicBlock*>::const_iterator BB = NewLoopBlocks.begin(),
458        BBE = NewLoopBlocks.end(); BB != BBE; ++BB)
459     for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ) {
460       Instruction *Inst = I++;
461
462       if (isInstructionTriviallyDead(Inst))
463         (*BB)->getInstList().erase(Inst);
464       else if (Value *V = SimplifyInstruction(Inst))
465         if (LI->replacementPreservesLCSSAForm(Inst, V)) {
466           Inst->replaceAllUsesWith(V);
467           (*BB)->getInstList().erase(Inst);
468         }
469     }
470
471   NumCompletelyUnrolled += CompletelyUnroll;
472   ++NumUnrolled;
473
474   Loop *OuterL = L->getParentLoop();
475   // Remove the loop from the LoopPassManager if it's completely removed.
476   if (CompletelyUnroll && LPM != nullptr)
477     LPM->deleteLoopFromQueue(L);
478
479   // If we have a pass and a DominatorTree we should re-simplify impacted loops
480   // to ensure subsequent analyses can rely on this form. We want to simplify
481   // at least one layer outside of the loop that was unrolled so that any
482   // changes to the parent loop exposed by the unrolling are considered.
483   if (PP && DT) {
484     if (!OuterL && !CompletelyUnroll)
485       OuterL = L;
486     if (OuterL) {
487       ScalarEvolution *SE = PP->getAnalysisIfAvailable<ScalarEvolution>();
488       simplifyLoop(OuterL, DT, LI, PP, /*AliasAnalysis*/ nullptr, SE);
489       formLCSSARecursively(*OuterL, *DT, SE);
490     }
491   }
492
493   return true;
494 }