Revert r188119 "Kill some duplicated code for removing unreachable BBs."
[oota-llvm.git] / lib / Transforms / Scalar / SimplifyCFGPass.cpp
1 //===- SimplifyCFGPass.cpp - CFG Simplification Pass ----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements dead code elimination and basic block merging, along
11 // with a collection of other peephole control flow optimizations.  For example:
12 //
13 //   * Removes basic blocks with no predecessors.
14 //   * Merges a basic block into its predecessor if there is only one and the
15 //     predecessor only has one successor.
16 //   * Eliminates PHI nodes for basic blocks with a single predecessor.
17 //   * Eliminates a basic block that only contains an unconditional branch.
18 //   * Changes invoke instructions to nounwind functions to be calls.
19 //   * Change things like "if (x) if (y)" into "if (x&y)".
20 //   * etc..
21 //
22 //===----------------------------------------------------------------------===//
23
24 #define DEBUG_TYPE "simplifycfg"
25 #include "llvm/Transforms/Scalar.h"
26 #include "llvm/ADT/SmallPtrSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/Statistic.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/IR/Attributes.h"
31 #include "llvm/IR/Constants.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/CFG.h"
38 #include "llvm/Transforms/Utils/Local.h"
39 using namespace llvm;
40
41 STATISTIC(NumSimpl, "Number of blocks simplified");
42
43 namespace {
44 struct CFGSimplifyPass : public FunctionPass {
45   static char ID; // Pass identification, replacement for typeid
46   CFGSimplifyPass() : FunctionPass(ID) {
47     initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry());
48   }
49   virtual bool runOnFunction(Function &F);
50
51   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
52     AU.addRequired<TargetTransformInfo>();
53   }
54 };
55 }
56
57 char CFGSimplifyPass::ID = 0;
58 INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
59                       false)
60 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
61 INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false,
62                     false)
63
64 // Public interface to the CFGSimplification pass
65 FunctionPass *llvm::createCFGSimplificationPass() {
66   return new CFGSimplifyPass();
67 }
68
69 /// changeToUnreachable - Insert an unreachable instruction before the specified
70 /// instruction, making it and the rest of the code in the block dead.
71 static void changeToUnreachable(Instruction *I, bool UseLLVMTrap) {
72   BasicBlock *BB = I->getParent();
73   // Loop over all of the successors, removing BB's entry from any PHI
74   // nodes.
75   for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI)
76     (*SI)->removePredecessor(BB);
77
78   // Insert a call to llvm.trap right before this.  This turns the undefined
79   // behavior into a hard fail instead of falling through into random code.
80   if (UseLLVMTrap) {
81     Function *TrapFn =
82       Intrinsic::getDeclaration(BB->getParent()->getParent(), Intrinsic::trap);
83     CallInst *CallTrap = CallInst::Create(TrapFn, "", I);
84     CallTrap->setDebugLoc(I->getDebugLoc());
85   }
86   new UnreachableInst(I->getContext(), I);
87
88   // All instructions after this are dead.
89   BasicBlock::iterator BBI = I, BBE = BB->end();
90   while (BBI != BBE) {
91     if (!BBI->use_empty())
92       BBI->replaceAllUsesWith(UndefValue::get(BBI->getType()));
93     BB->getInstList().erase(BBI++);
94   }
95 }
96
97 /// changeToCall - Convert the specified invoke into a normal call.
98 static void changeToCall(InvokeInst *II) {
99   SmallVector<Value*, 8> Args(II->op_begin(), II->op_end() - 3);
100   CallInst *NewCall = CallInst::Create(II->getCalledValue(), Args, "", II);
101   NewCall->takeName(II);
102   NewCall->setCallingConv(II->getCallingConv());
103   NewCall->setAttributes(II->getAttributes());
104   NewCall->setDebugLoc(II->getDebugLoc());
105   II->replaceAllUsesWith(NewCall);
106
107   // Follow the call by a branch to the normal destination.
108   BranchInst::Create(II->getNormalDest(), II);
109
110   // Update PHI nodes in the unwind destination
111   II->getUnwindDest()->removePredecessor(II->getParent());
112   II->eraseFromParent();
113 }
114
115 static bool markAliveBlocks(BasicBlock *BB,
116                             SmallPtrSet<BasicBlock*, 128> &Reachable) {
117
118   SmallVector<BasicBlock*, 128> Worklist;
119   Worklist.push_back(BB);
120   Reachable.insert(BB);
121   bool Changed = false;
122   do {
123     BB = Worklist.pop_back_val();
124
125     // Do a quick scan of the basic block, turning any obviously unreachable
126     // instructions into LLVM unreachable insts.  The instruction combining pass
127     // canonicalizes unreachable insts into stores to null or undef.
128     for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E;++BBI){
129       if (CallInst *CI = dyn_cast<CallInst>(BBI)) {
130         if (CI->doesNotReturn()) {
131           // If we found a call to a no-return function, insert an unreachable
132           // instruction after it.  Make sure there isn't *already* one there
133           // though.
134           ++BBI;
135           if (!isa<UnreachableInst>(BBI)) {
136             // Don't insert a call to llvm.trap right before the unreachable.
137             changeToUnreachable(BBI, false);
138             Changed = true;
139           }
140           break;
141         }
142       }
143
144       // Store to undef and store to null are undefined and used to signal that
145       // they should be changed to unreachable by passes that can't modify the
146       // CFG.
147       if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) {
148         // Don't touch volatile stores.
149         if (SI->isVolatile()) continue;
150
151         Value *Ptr = SI->getOperand(1);
152
153         if (isa<UndefValue>(Ptr) ||
154             (isa<ConstantPointerNull>(Ptr) &&
155              SI->getPointerAddressSpace() == 0)) {
156           changeToUnreachable(SI, true);
157           Changed = true;
158           break;
159         }
160       }
161     }
162
163     // Turn invokes that call 'nounwind' functions into ordinary calls.
164     if (InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
165       Value *Callee = II->getCalledValue();
166       if (isa<ConstantPointerNull>(Callee) || isa<UndefValue>(Callee)) {
167         changeToUnreachable(II, true);
168         Changed = true;
169       } else if (II->doesNotThrow()) {
170         if (II->use_empty() && II->onlyReadsMemory()) {
171           // jump to the normal destination branch.
172           BranchInst::Create(II->getNormalDest(), II);
173           II->getUnwindDest()->removePredecessor(II->getParent());
174           II->eraseFromParent();
175         } else
176           changeToCall(II);
177         Changed = true;
178       }
179     }
180
181     Changed |= ConstantFoldTerminator(BB, true);
182     for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI)
183       if (Reachable.insert(*SI))
184         Worklist.push_back(*SI);
185   } while (!Worklist.empty());
186   return Changed;
187 }
188
189 /// removeUnreachableBlocksFromFn - Remove blocks that are not reachable, even
190 /// if they are in a dead cycle.  Return true if a change was made, false
191 /// otherwise.
192 static bool removeUnreachableBlocksFromFn(Function &F) {
193   SmallPtrSet<BasicBlock*, 128> Reachable;
194   bool Changed = markAliveBlocks(F.begin(), Reachable);
195
196   // If there are unreachable blocks in the CFG...
197   if (Reachable.size() == F.size())
198     return Changed;
199
200   assert(Reachable.size() < F.size());
201   NumSimpl += F.size()-Reachable.size();
202
203   // Loop over all of the basic blocks that are not reachable, dropping all of
204   // their internal references...
205   for (Function::iterator BB = ++F.begin(), E = F.end(); BB != E; ++BB) {
206     if (Reachable.count(BB))
207       continue;
208
209     for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE; ++SI)
210       if (Reachable.count(*SI))
211         (*SI)->removePredecessor(BB);
212     BB->dropAllReferences();
213   }
214
215   for (Function::iterator I = ++F.begin(); I != F.end();)
216     if (!Reachable.count(I))
217       I = F.getBasicBlockList().erase(I);
218     else
219       ++I;
220
221   return true;
222 }
223
224 /// mergeEmptyReturnBlocks - If we have more than one empty (other than phi
225 /// node) return blocks, merge them together to promote recursive block merging.
226 static bool mergeEmptyReturnBlocks(Function &F) {
227   bool Changed = false;
228
229   BasicBlock *RetBlock = 0;
230
231   // Scan all the blocks in the function, looking for empty return blocks.
232   for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; ) {
233     BasicBlock &BB = *BBI++;
234
235     // Only look at return blocks.
236     ReturnInst *Ret = dyn_cast<ReturnInst>(BB.getTerminator());
237     if (Ret == 0) continue;
238
239     // Only look at the block if it is empty or the only other thing in it is a
240     // single PHI node that is the operand to the return.
241     if (Ret != &BB.front()) {
242       // Check for something else in the block.
243       BasicBlock::iterator I = Ret;
244       --I;
245       // Skip over debug info.
246       while (isa<DbgInfoIntrinsic>(I) && I != BB.begin())
247         --I;
248       if (!isa<DbgInfoIntrinsic>(I) &&
249           (!isa<PHINode>(I) || I != BB.begin() ||
250            Ret->getNumOperands() == 0 ||
251            Ret->getOperand(0) != I))
252         continue;
253     }
254
255     // If this is the first returning block, remember it and keep going.
256     if (RetBlock == 0) {
257       RetBlock = &BB;
258       continue;
259     }
260
261     // Otherwise, we found a duplicate return block.  Merge the two.
262     Changed = true;
263
264     // Case when there is no input to the return or when the returned values
265     // agree is trivial.  Note that they can't agree if there are phis in the
266     // blocks.
267     if (Ret->getNumOperands() == 0 ||
268         Ret->getOperand(0) ==
269           cast<ReturnInst>(RetBlock->getTerminator())->getOperand(0)) {
270       BB.replaceAllUsesWith(RetBlock);
271       BB.eraseFromParent();
272       continue;
273     }
274
275     // If the canonical return block has no PHI node, create one now.
276     PHINode *RetBlockPHI = dyn_cast<PHINode>(RetBlock->begin());
277     if (RetBlockPHI == 0) {
278       Value *InVal = cast<ReturnInst>(RetBlock->getTerminator())->getOperand(0);
279       pred_iterator PB = pred_begin(RetBlock), PE = pred_end(RetBlock);
280       RetBlockPHI = PHINode::Create(Ret->getOperand(0)->getType(),
281                                     std::distance(PB, PE), "merge",
282                                     &RetBlock->front());
283
284       for (pred_iterator PI = PB; PI != PE; ++PI)
285         RetBlockPHI->addIncoming(InVal, *PI);
286       RetBlock->getTerminator()->setOperand(0, RetBlockPHI);
287     }
288
289     // Turn BB into a block that just unconditionally branches to the return
290     // block.  This handles the case when the two return blocks have a common
291     // predecessor but that return different things.
292     RetBlockPHI->addIncoming(Ret->getOperand(0), &BB);
293     BB.getTerminator()->eraseFromParent();
294     BranchInst::Create(RetBlock, &BB);
295   }
296
297   return Changed;
298 }
299
300 /// iterativelySimplifyCFG - Call SimplifyCFG on all the blocks in the function,
301 /// iterating until no more changes are made.
302 static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
303                                    const DataLayout *TD) {
304   bool Changed = false;
305   bool LocalChange = true;
306   while (LocalChange) {
307     LocalChange = false;
308
309     // Loop over all of the basic blocks and remove them if they are unneeded...
310     //
311     for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) {
312       if (SimplifyCFG(BBIt++, TTI, TD)) {
313         LocalChange = true;
314         ++NumSimpl;
315       }
316     }
317     Changed |= LocalChange;
318   }
319   return Changed;
320 }
321
322 // It is possible that we may require multiple passes over the code to fully
323 // simplify the CFG.
324 //
325 bool CFGSimplifyPass::runOnFunction(Function &F) {
326   const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>();
327   const DataLayout *TD = getAnalysisIfAvailable<DataLayout>();
328   bool EverChanged = removeUnreachableBlocksFromFn(F);
329   EverChanged |= mergeEmptyReturnBlocks(F);
330   EverChanged |= iterativelySimplifyCFG(F, TTI, TD);
331
332   // If neither pass changed anything, we're done.
333   if (!EverChanged) return false;
334
335   // iterativelySimplifyCFG can (rarely) make some loops dead.  If this happens,
336   // removeUnreachableBlocksFromFn is needed to nuke them, which means we should
337   // iterate between the two optimizations.  We structure the code like this to
338   // avoid reruning iterativelySimplifyCFG if the second pass of
339   // removeUnreachableBlocksFromFn doesn't do anything.
340   if (!removeUnreachableBlocksFromFn(F))
341     return true;
342
343   do {
344     EverChanged = iterativelySimplifyCFG(F, TTI, TD);
345     EverChanged |= removeUnreachableBlocksFromFn(F);
346   } while (EverChanged);
347
348   return true;
349 }