Fix a really obvious huge gaping bug, add a comment
[oota-llvm.git] / lib / Transforms / Scalar / TailRecursionElimination.cpp
1 //===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===//
2 //
3 // This file implements tail recursion elimination.
4 //
5 // Caveats: The algorithm implemented is trivially simple.  There are several
6 // improvements that could be made:
7 //
8 //  1. If the function has any alloca instructions, these instructions will not
9 //     remain in the entry block of the function.  Doing this requires analysis
10 //     to prove that the alloca is not reachable by the recursively invoked
11 //     function call.
12 //  2. Tail recursion is only performed if the call immediately preceeds the
13 //     return instruction.  Would it be useful to generalize this somehow?
14 //  3. TRE is only performed if the function returns void or if the return
15 //     returns the result returned by the call.  It is possible, but unlikely,
16 //     that the return returns something else (like constant 0), and can still
17 //     be TRE'd.  It can be TRE'd if ALL OTHER return instructions in the
18 //     function return the exact same value.
19 //
20 //===----------------------------------------------------------------------===//
21
22 #include "llvm/Transforms/Scalar.h"
23 #include "llvm/DerivedTypes.h"
24 #include "llvm/Function.h"
25 #include "llvm/Instructions.h"
26 #include "llvm/Pass.h"
27 #include "Support/Statistic.h"
28
29 namespace {
30   Statistic<> NumEliminated("tailcallelim", "Number of tail calls removed");
31
32   struct TailCallElim : public FunctionPass {
33     virtual bool runOnFunction(Function &F);
34   };
35   RegisterOpt<TailCallElim> X("tailcallelim", "Tail Call Elimination");
36 }
37
38 FunctionPass *createTailCallEliminationPass() { return new TailCallElim(); }
39
40
41 bool TailCallElim::runOnFunction(Function &F) {
42   // If this function is a varargs function, we won't be able to PHI the args
43   // right, so don't even try to convert it...
44   if (F.getFunctionType()->isVarArg()) return false;
45
46   BasicBlock *OldEntry = 0;
47   std::vector<PHINode*> ArgumentPHIs;
48   bool MadeChange = false;
49
50   // Loop over the function, looking for any returning blocks...
51   for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
52     if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator()))
53       if (Ret != BB->begin())  // Make sure there is something before the ret...
54         if (CallInst *CI = dyn_cast<CallInst>(Ret->getPrev()))
55           // Make sure the tail call is to the current function, and that the
56           // return either returns void or returns the value computed by the
57           // call.
58           if (CI->getCalledFunction() == &F &&
59               (Ret->getNumOperands() == 0 || Ret->getReturnValue() == CI)) {
60             // Ohh, it looks like we found a tail call, is this the first?
61             if (!OldEntry) {
62               // Ok, so this is the first tail call we have found in this
63               // function.  Insert a new entry block into the function, allowing
64               // us to branch back to the old entry block.
65               OldEntry = &F.getEntryNode();
66               BasicBlock *NewEntry = new BasicBlock("tailrecurse", OldEntry);
67               NewEntry->getInstList().push_back(new BranchInst(OldEntry));
68               
69               // Now that we have created a new block, which jumps to the entry
70               // block, insert a PHI node for each argument of the function.
71               // For now, we initialize each PHI to only have the real arguments
72               // which are passed in.
73               Instruction *InsertPos = OldEntry->begin();
74               for (Function::aiterator I = F.abegin(), E = F.aend(); I!=E; ++I){
75                 PHINode *PN = new PHINode(I->getType(), I->getName()+".tr",
76                                           InsertPos);
77                 I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
78                 PN->addIncoming(I, NewEntry);
79                 ArgumentPHIs.push_back(PN);
80               }
81             }
82             
83             // Ok, now that we know we have a pseudo-entry block WITH all of the
84             // required PHI nodes, add entries into the PHI node for the actual
85             // parameters passed into the tail-recursive call.
86             for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i)
87               ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB);
88
89             // Now that all of the PHI nodes are in place, remove the call and
90             // ret instructions, replacing them with an unconditional branch.
91             new BranchInst(OldEntry, CI);
92             BB->getInstList().pop_back();  // Remove return.
93             BB->getInstList().pop_back();  // Remove call.
94             MadeChange = true;
95             NumEliminated++;
96           }
97   
98   return MadeChange;
99 }
100