* Rename MethodPass class to FunctionPass
[oota-llvm.git] / lib / Transforms / IPO / DeadTypeElimination.cpp
1 //===- CleanupGCCOutput.cpp - Cleanup GCC Output --------------------------===//
2 //
3 // This pass is used to cleanup the output of GCC.  GCC's output is
4 // unneccessarily gross for a couple of reasons. This pass does the following
5 // things to try to clean it up:
6 //
7 // * Eliminate names for GCC types that we know can't be needed by the user.
8 // * Eliminate names for types that are unused in the entire translation unit
9 // * Fix various problems that we might have in PHI nodes and casts
10 // * Link uses of 'void %foo(...)' to 'void %foo(sometypes)'
11 //
12 // Note:  This code produces dead declarations, it is a good idea to run DCE
13 //        after this pass.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "llvm/Transforms/CleanupGCCOutput.h"
18 #include "llvm/Analysis/FindUsedTypes.h"
19 #include "TransformInternals.h"
20 #include "llvm/Module.h"
21 #include "llvm/SymbolTable.h"
22 #include "llvm/DerivedTypes.h"
23 #include "llvm/iPHINode.h"
24 #include "llvm/iMemory.h"
25 #include "llvm/iTerminators.h"
26 #include "llvm/iOther.h"
27 #include "llvm/Support/CFG.h"
28 #include "llvm/Pass.h"
29 #include <algorithm>
30 #include <iostream>
31 using std::vector;
32 using std::string;
33 using std::cerr;
34
35 static const Type *PtrSByte = 0;    // 'sbyte*' type
36
37 namespace {
38   struct CleanupGCCOutput : public FunctionPass {
39     // doPassInitialization - For this pass, it removes global symbol table
40     // entries for primitive types.  These are never used for linking in GCC and
41     // they make the output uglier to look at, so we nuke them.
42     //
43     // Also, initialize instance variables.
44     //
45     bool doInitialization(Module *M);
46     
47     // runOnFunction - This method simplifies the specified function hopefully.
48     //
49     bool runOnFunction(Function *F);
50     
51     // doPassFinalization - Strip out type names that are unused by the program
52     bool doFinalization(Module *M);
53     
54     // getAnalysisUsage - This function needs FindUsedTypes to do its job...
55     //
56     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
57       AU.addRequired(FindUsedTypes::ID);
58     }
59   };
60 }
61
62 Pass *createCleanupGCCOutputPass() {
63   return new CleanupGCCOutput();
64 }
65
66
67
68 // ShouldNukSymtabEntry - Return true if this module level symbol table entry
69 // should be eliminated.
70 //
71 static inline bool ShouldNukeSymtabEntry(const std::pair<string, Value*> &E) {
72   // Nuke all names for primitive types!
73   if (cast<Type>(E.second)->isPrimitiveType()) return true;
74
75   // Nuke all pointers to primitive types as well...
76   if (const PointerType *PT = dyn_cast<PointerType>(E.second))
77     if (PT->getElementType()->isPrimitiveType()) return true;
78
79   // The only types that could contain .'s in the program are things generated
80   // by GCC itself, including "complex.float" and friends.  Nuke them too.
81   if (E.first.find('.') != string::npos) return true;
82
83   return false;
84 }
85
86 // doInitialization - For this pass, it removes global symbol table
87 // entries for primitive types.  These are never used for linking in GCC and
88 // they make the output uglier to look at, so we nuke them.
89 //
90 bool CleanupGCCOutput::doInitialization(Module *M) {
91   bool Changed = false;
92
93   if (PtrSByte == 0)
94     PtrSByte = PointerType::get(Type::SByteTy);
95
96   if (M->hasSymbolTable()) {
97     SymbolTable *ST = M->getSymbolTable();
98
99     // Check the symbol table for superfluous type entries...
100     //
101     // Grab the 'type' plane of the module symbol...
102     SymbolTable::iterator STI = ST->find(Type::TypeTy);
103     if (STI != ST->end()) {
104       // Loop over all entries in the type plane...
105       SymbolTable::VarMap &Plane = STI->second;
106       for (SymbolTable::VarMap::iterator PI = Plane.begin(); PI != Plane.end();)
107         if (ShouldNukeSymtabEntry(*PI)) {    // Should we remove this entry?
108 #if MAP_IS_NOT_BRAINDEAD
109           PI = Plane.erase(PI);     // STD C++ Map should support this!
110 #else
111           Plane.erase(PI);          // Alas, GCC 2.95.3 doesn't  *SIGH*
112           PI = Plane.begin();
113 #endif
114           Changed = true;
115         } else {
116           ++PI;
117         }
118     }
119   }
120
121   return Changed;
122 }
123
124
125 // FixCastsAndPHIs - The LLVM GCC has a tendancy to intermix Cast instructions
126 // in with the PHI nodes.  These cast instructions are potentially there for two
127 // different reasons:
128 //
129 //   1. The cast could be for an early PHI, and be accidentally inserted before
130 //      another PHI node.  In this case, the PHI node should be moved to the end
131 //      of the PHI nodes in the basic block.  We know that it is this case if
132 //      the source for the cast is a PHI node in this basic block.
133 //
134 //   2. If not #1, the cast must be a source argument for one of the PHI nodes
135 //      in the current basic block.  If this is the case, the cast should be
136 //      lifted into the basic block for the appropriate predecessor. 
137 //
138 static inline bool FixCastsAndPHIs(BasicBlock *BB) {
139   bool Changed = false;
140
141   BasicBlock::iterator InsertPos = BB->begin();
142
143   // Find the end of the interesting instructions...
144   while (isa<PHINode>(*InsertPos) || isa<CastInst>(*InsertPos)) ++InsertPos;
145
146   // Back the InsertPos up to right after the last PHI node.
147   while (InsertPos != BB->begin() && isa<CastInst>(*(InsertPos-1))) --InsertPos;
148
149   // No PHI nodes, quick exit.
150   if (InsertPos == BB->begin()) return false;
151
152   // Loop over all casts trapped between the PHI's...
153   BasicBlock::iterator I = BB->begin();
154   while (I != InsertPos) {
155     if (CastInst *CI = dyn_cast<CastInst>(*I)) { // Fix all cast instructions
156       Value *Src = CI->getOperand(0);
157
158       // Move the cast instruction to the current insert position...
159       --InsertPos;                 // New position for cast to go...
160       std::swap(*InsertPos, *I);   // Cast goes down, PHI goes up
161
162       if (isa<PHINode>(Src) &&                                // Handle case #1
163           cast<PHINode>(Src)->getParent() == BB) {
164         // We're done for case #1
165       } else {                                                // Handle case #2
166         // In case #2, we have to do a few things:
167         //   1. Remove the cast from the current basic block.
168         //   2. Identify the PHI node that the cast is for.
169         //   3. Find out which predecessor the value is for.
170         //   4. Move the cast to the end of the basic block that it SHOULD be
171         //
172
173         // Remove the cast instruction from the basic block.  The remove only
174         // invalidates iterators in the basic block that are AFTER the removed
175         // element.  Because we just moved the CastInst to the InsertPos, no
176         // iterators get invalidated.
177         //
178         BB->getInstList().remove(InsertPos);
179
180         // Find the PHI node.  Since this cast was generated specifically for a
181         // PHI node, there can only be a single PHI node using it.
182         //
183         assert(CI->use_size() == 1 && "Exactly one PHI node should use cast!");
184         PHINode *PN = cast<PHINode>(*CI->use_begin());
185
186         // Find out which operand of the PHI it is...
187         unsigned i;
188         for (i = 0; i < PN->getNumIncomingValues(); ++i)
189           if (PN->getIncomingValue(i) == CI)
190             break;
191         assert(i != PN->getNumIncomingValues() && "PHI doesn't use cast!");
192
193         // Get the predecessor the value is for...
194         BasicBlock *Pred = PN->getIncomingBlock(i);
195
196         // Reinsert the cast right before the terminator in Pred.
197         Pred->getInstList().insert(Pred->end()-1, CI);
198       }
199     } else {
200       ++I;
201     }
202   }
203
204   return Changed;
205 }
206
207 // RefactorPredecessor - When we find out that a basic block is a repeated
208 // predecessor in a PHI node, we have to refactor the function until there is at
209 // most a single instance of a basic block in any predecessor list.
210 //
211 static inline void RefactorPredecessor(BasicBlock *BB, BasicBlock *Pred) {
212   Function *M = BB->getParent();
213   assert(find(pred_begin(BB), pred_end(BB), Pred) != pred_end(BB) &&
214          "Pred is not a predecessor of BB!");
215
216   // Create a new basic block, adding it to the end of the function.
217   BasicBlock *NewBB = new BasicBlock("", M);
218
219   // Add an unconditional branch to BB to the new block.
220   NewBB->getInstList().push_back(new BranchInst(BB));
221
222   // Get the terminator that causes a branch to BB from Pred.
223   TerminatorInst *TI = Pred->getTerminator();
224
225   // Find the first use of BB in the terminator...
226   User::op_iterator OI = find(TI->op_begin(), TI->op_end(), BB);
227   assert(OI != TI->op_end() && "Pred does not branch to BB!!!");
228
229   // Change the use of BB to point to the new stub basic block
230   *OI = NewBB;
231
232   // Now we need to loop through all of the PHI nodes in BB and convert their
233   // first incoming value for Pred to reference the new basic block instead.
234   //
235   for (BasicBlock::iterator I = BB->begin(); 
236        PHINode *PN = dyn_cast<PHINode>(*I); ++I) {
237     int BBIdx = PN->getBasicBlockIndex(Pred);
238     assert(BBIdx != -1 && "PHI node doesn't have an entry for Pred!");
239
240     // The value that used to look like it came from Pred now comes from NewBB
241     PN->setIncomingBlock((unsigned)BBIdx, NewBB);
242   }
243 }
244
245
246 // runOnFunction - Loop through the function and fix problems with the PHI nodes
247 // in the current function.  The problem is that PHI nodes might exist with
248 // multiple entries for the same predecessor.  GCC sometimes generates code that
249 // looks like this:
250 //
251 //  bb7:  br bool %cond1004, label %bb8, label %bb8
252 //  bb8: %reg119 = phi uint [ 0, %bb7 ], [ 1, %bb7 ]
253 //     
254 //     which is completely illegal LLVM code.  To compensate for this, we insert
255 //     an extra basic block, and convert the code to look like this:
256 //
257 //  bb7: br bool %cond1004, label %bbX, label %bb8
258 //  bbX: br label bb8
259 //  bb8: %reg119 = phi uint [ 0, %bbX ], [ 1, %bb7 ]
260 //
261 //
262 bool CleanupGCCOutput::runOnFunction(Function *M) {
263   bool Changed = false;
264   // Don't use iterators because invalidation gets messy...
265   for (unsigned MI = 0; MI < M->size(); ++MI) {
266     BasicBlock *BB = M->getBasicBlocks()[MI];
267
268     Changed |= FixCastsAndPHIs(BB);
269
270     if (isa<PHINode>(BB->front())) {
271       const vector<BasicBlock*> Preds(pred_begin(BB), pred_end(BB));
272
273       // Handle the problem.  Sort the list of predecessors so that it is easy
274       // to decide whether or not duplicate predecessors exist.
275       vector<BasicBlock*> SortedPreds(Preds);
276       sort(SortedPreds.begin(), SortedPreds.end());
277
278       // Loop over the predecessors, looking for adjacent BB's that are equal.
279       BasicBlock *LastOne = 0;
280       for (unsigned i = 0; i < Preds.size(); ++i) {
281         if (SortedPreds[i] == LastOne) {   // Found a duplicate.
282           RefactorPredecessor(BB, SortedPreds[i]);
283           Changed = true;
284         }
285         LastOne = SortedPreds[i];
286       }
287     }
288   }
289   return Changed;
290 }
291
292 bool CleanupGCCOutput::doFinalization(Module *M) {
293   bool Changed = false;
294
295   if (M->hasSymbolTable()) {
296     SymbolTable *ST = M->getSymbolTable();
297     const std::set<const Type *> &UsedTypes =
298       getAnalysis<FindUsedTypes>().getTypes();
299
300     // Check the symbol table for superfluous type entries that aren't used in
301     // the program
302     //
303     // Grab the 'type' plane of the module symbol...
304     SymbolTable::iterator STI = ST->find(Type::TypeTy);
305     if (STI != ST->end()) {
306       // Loop over all entries in the type plane...
307       SymbolTable::VarMap &Plane = STI->second;
308       for (SymbolTable::VarMap::iterator PI = Plane.begin(); PI != Plane.end();)
309         if (!UsedTypes.count(cast<Type>(PI->second))) {
310 #if MAP_IS_NOT_BRAINDEAD
311           PI = Plane.erase(PI);     // STD C++ Map should support this!
312 #else
313           Plane.erase(PI);          // Alas, GCC 2.95.3 doesn't  *SIGH*
314           PI = Plane.begin();       // N^2 algorithms are fun.  :(
315 #endif
316           Changed = true;
317         } else {
318           ++PI;
319         }
320     }
321   }
322   return Changed;
323 }
324
325
326 //===----------------------------------------------------------------------===//
327 //
328 // FunctionResolvingPass - Go over the functions that are in the module and
329 // look for functions that have the same name.  More often than not, there will
330 // be things like:
331 //    void "foo"(...)
332 //    void "foo"(int, int)
333 // because of the way things are declared in C.  If this is the case, patch
334 // things up.
335 //
336 //===----------------------------------------------------------------------===//
337
338 namespace {
339   struct FunctionResolvingPass : public Pass {
340     bool run(Module *M);
341   };
342 }
343
344 // ConvertCallTo - Convert a call to a varargs function with no arg types
345 // specified to a concrete nonvarargs function.
346 //
347 static void ConvertCallTo(CallInst *CI, Function *Dest) {
348   const FunctionType::ParamTypes &ParamTys =
349     Dest->getFunctionType()->getParamTypes();
350   BasicBlock *BB = CI->getParent();
351
352   // Get an iterator to where we want to insert cast instructions if the
353   // argument types don't agree.
354   //
355   BasicBlock::iterator BBI = find(BB->begin(), BB->end(), CI);
356   assert(BBI != BB->end() && "CallInst not in parent block?");
357
358   assert(CI->getNumOperands()-1 == ParamTys.size()&&
359          "Function calls resolved funny somehow, incompatible number of args");
360
361   vector<Value*> Params;
362
363   // Convert all of the call arguments over... inserting cast instructions if
364   // the types are not compatible.
365   for (unsigned i = 1; i < CI->getNumOperands(); ++i) {
366     Value *V = CI->getOperand(i);
367
368     if (V->getType() != ParamTys[i-1]) { // Must insert a cast...
369       Instruction *Cast = new CastInst(V, ParamTys[i-1]);
370       BBI = BB->getInstList().insert(BBI, Cast)+1;
371       V = Cast;
372     }
373
374     Params.push_back(V);
375   }
376
377   // Replace the old call instruction with a new call instruction that calls
378   // the real function.
379   //
380   ReplaceInstWithInst(BB->getInstList(), BBI, new CallInst(Dest, Params));
381 }
382
383
384 bool FunctionResolvingPass::run(Module *M) {
385   SymbolTable *ST = M->getSymbolTable();
386   if (!ST) return false;
387
388   std::map<string, vector<Function*> > Functions;
389
390   // Loop over the entries in the symbol table. If an entry is a func pointer,
391   // then add it to the Functions map.  We do a two pass algorithm here to avoid
392   // problems with iterators getting invalidated if we did a one pass scheme.
393   //
394   for (SymbolTable::iterator I = ST->begin(), E = ST->end(); I != E; ++I)
395     if (const PointerType *PT = dyn_cast<PointerType>(I->first))
396       if (isa<FunctionType>(PT->getElementType())) {
397         SymbolTable::VarMap &Plane = I->second;
398         for (SymbolTable::type_iterator PI = Plane.begin(), PE = Plane.end();
399              PI != PE; ++PI) {
400           const string &Name = PI->first;
401           Functions[Name].push_back(cast<Function>(PI->second));          
402         }
403       }
404
405   bool Changed = false;
406
407   // Now we have a list of all functions with a particular name.  If there is
408   // more than one entry in a list, merge the functions together.
409   //
410   for (std::map<string, vector<Function*> >::iterator I = Functions.begin(), 
411          E = Functions.end(); I != E; ++I) {
412     vector<Function*> &Functions = I->second;
413     Function *Implementation = 0;     // Find the implementation
414     Function *Concrete = 0;
415     for (unsigned i = 0; i < Functions.size(); ) {
416       if (!Functions[i]->isExternal()) {  // Found an implementation
417         assert(Implementation == 0 && "Multiple definitions of the same"
418                " function. Case not handled yet!");
419         Implementation = Functions[i];
420       } else {
421         // Ignore functions that are never used so they don't cause spurious
422         // warnings... here we will actually DCE the function so that it isn't
423         // used later.
424         //
425         if (Functions[i]->use_size() == 0) {
426           M->getFunctionList().remove(Functions[i]);
427           delete Functions[i];
428           Functions.erase(Functions.begin()+i);
429           Changed = true;
430           continue;
431         }
432       }
433       
434       if (Functions[i] && (!Functions[i]->getFunctionType()->isVarArg())) {
435         if (Concrete) {  // Found two different functions types.  Can't choose
436           Concrete = 0;
437           break;
438         }
439         Concrete = Functions[i];
440       }
441       ++i;
442     }
443
444     if (Functions.size() > 1) {         // Found a multiply defined function...
445       // We should find exactly one non-vararg function definition, which is
446       // probably the implementation.  Change all of the function definitions
447       // and uses to use it instead.
448       //
449       if (!Concrete) {
450         cerr << "Warning: Found functions types that are not compatible:\n";
451         for (unsigned i = 0; i < Functions.size(); ++i) {
452           cerr << "\t" << Functions[i]->getType()->getDescription() << " %"
453                << Functions[i]->getName() << "\n";
454         }
455         cerr << "  No linkage of functions named '" << Functions[0]->getName()
456              << "' performed!\n";
457       } else {
458         for (unsigned i = 0; i < Functions.size(); ++i)
459           if (Functions[i] != Concrete) {
460             Function *Old = Functions[i];
461             const FunctionType *OldMT = Old->getFunctionType();
462             const FunctionType *ConcreteMT = Concrete->getFunctionType();
463             bool Broken = false;
464
465             assert(Old->getReturnType() == Concrete->getReturnType() &&
466                    "Differing return types not handled yet!");
467             assert(OldMT->getParamTypes().size() <=
468                    ConcreteMT->getParamTypes().size() &&
469                    "Concrete type must have more specified parameters!");
470
471             // Check to make sure that if there are specified types, that they
472             // match...
473             //
474             for (unsigned i = 0; i < OldMT->getParamTypes().size(); ++i)
475               if (OldMT->getParamTypes()[i] != ConcreteMT->getParamTypes()[i]) {
476                 cerr << "Parameter types conflict for" << OldMT
477                      << " and " << ConcreteMT;
478                 Broken = true;
479               }
480             if (Broken) break;  // Can't process this one!
481
482
483             // Attempt to convert all of the uses of the old function to the
484             // concrete form of the function.  If there is a use of the fn
485             // that we don't understand here we punt to avoid making a bad
486             // transformation.
487             //
488             // At this point, we know that the return values are the same for
489             // our two functions and that the Old function has no varargs fns
490             // specified.  In otherwords it's just <retty> (...)
491             //
492             for (unsigned i = 0; i < Old->use_size(); ) {
493               User *U = *(Old->use_begin()+i);
494               if (CastInst *CI = dyn_cast<CastInst>(U)) {
495                 // Convert casts directly
496                 assert(CI->getOperand(0) == Old);
497                 CI->setOperand(0, Concrete);
498                 Changed = true;
499               } else if (CallInst *CI = dyn_cast<CallInst>(U)) {
500                 // Can only fix up calls TO the argument, not args passed in.
501                 if (CI->getCalledValue() == Old) {
502                   ConvertCallTo(CI, Concrete);
503                   Changed = true;
504                 } else {
505                   cerr << "Couldn't cleanup this function call, must be an"
506                        << " argument or something!" << CI;
507                   ++i;
508                 }
509               } else {
510                 cerr << "Cannot convert use of function: " << U << "\n";
511                 ++i;
512               }
513             }
514           }
515         }
516     }
517   }
518
519   return Changed;
520 }
521
522 Pass *createFunctionResolvingPass() {
523   return new FunctionResolvingPass();
524 }