c67a2e006d29baaebea82c876b7825d55208834a
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
1 //===- MutateStructTypes.cpp - Change struct defns --------------------------=//
2 //
3 // This pass is used to change structure accesses and type definitions in some
4 // way.  It can be used to arbitrarily permute structure fields, safely, without
5 // breaking code.  A transformation may only be done on a type if that type has
6 // been found to be "safe" by the 'FindUnsafePointerTypes' pass.  This pass will
7 // assert and die if you try to do an illegal transformation.
8 //
9 // This is an interprocedural pass that requires the entire program to do a
10 // transformation.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/IPO/MutateStructTypes.h"
15 #include "llvm/DerivedTypes.h"
16 #include "llvm/Module.h"
17 #include "llvm/Method.h"
18 #include "llvm/GlobalVariable.h"
19 #include "llvm/SymbolTable.h"
20 #include "llvm/iPHINode.h"
21 #include "llvm/iMemory.h"
22 #include "llvm/iTerminators.h"
23 #include "llvm/iOther.h"
24 #include "Support/STLExtras.h"
25 #include <algorithm>
26 using std::map;
27 using std::vector;
28
29 //FIXME: These headers are only included because the analyses are killed!!!
30 #include "llvm/Analysis/CallGraph.h"
31 #include "llvm/Analysis/FindUsedTypes.h"
32 #include "llvm/Analysis/FindUnsafePointerTypes.h"
33 //FIXME end
34
35 // To enable debugging, uncomment this...
36 //#define DEBUG_MST(x) x
37
38 #ifdef DEBUG_MST
39 #include "llvm/Assembly/Writer.h"
40 #else
41 #define DEBUG_MST(x)   // Disable debug code
42 #endif
43
44 // ValuePlaceHolder - A stupid little marker value.  It appears as an
45 // instruction of type Instruction::UserOp1.
46 //
47 struct ValuePlaceHolder : public Instruction {
48   ValuePlaceHolder(const Type *Ty) : Instruction(Ty, UserOp1, "") {}
49
50   virtual Instruction *clone() const { abort(); return 0; }
51   virtual const char *getOpcodeName() const { return "placeholder"; }
52 };
53
54
55 // ConvertType - Convert from the old type system to the new one...
56 const Type *MutateStructTypes::ConvertType(const Type *Ty) {
57   if (Ty->isPrimitiveType() ||
58       isa<OpaqueType>(Ty)) return Ty;  // Don't convert primitives
59
60   map<const Type *, PATypeHolder<Type> >::iterator I = TypeMap.find(Ty);
61   if (I != TypeMap.end()) return I->second;
62
63   const Type *DestTy = 0;
64
65   PATypeHolder<Type> PlaceHolder = OpaqueType::get();
66   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
67
68   switch (Ty->getPrimitiveID()) {
69   case Type::MethodTyID: {
70     const MethodType *MT = cast<MethodType>(Ty);
71     const Type *RetTy = ConvertType(MT->getReturnType());
72     vector<const Type*> ArgTypes;
73
74     for (MethodType::ParamTypes::const_iterator I = MT->getParamTypes().begin(),
75            E = MT->getParamTypes().end(); I != E; ++I)
76       ArgTypes.push_back(ConvertType(*I));
77     
78     DestTy = MethodType::get(RetTy, ArgTypes, MT->isVarArg());
79     break;
80   }
81   case Type::StructTyID: {
82     const StructType *ST = cast<StructType>(Ty);
83     const StructType::ElementTypes &El = ST->getElementTypes();
84     vector<const Type *> Types;
85
86     for (StructType::ElementTypes::const_iterator I = El.begin(), E = El.end();
87          I != E; ++I)
88       Types.push_back(ConvertType(*I));
89     DestTy = StructType::get(Types);
90     break;
91   }
92   case Type::ArrayTyID:
93     DestTy = ArrayType::get(ConvertType(cast<ArrayType>(Ty)->getElementType()),
94                             cast<ArrayType>(Ty)->getNumElements());
95     break;
96
97   case Type::PointerTyID:
98     DestTy = PointerType::get(
99                  ConvertType(cast<PointerType>(Ty)->getElementType()));
100     break;
101   default:
102     assert(0 && "Unknown type!");
103     return 0;
104   }
105
106   assert(DestTy && "Type didn't get created!?!?");
107
108   // Refine our little placeholder value into a real type...
109   cast<DerivedType>(PlaceHolder.get())->refineAbstractTypeTo(DestTy);
110   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
111
112   return PlaceHolder.get();
113 }
114
115
116 // AdjustIndices - Convert the indexes specifed by Idx to the new changed form
117 // using the specified OldTy as the base type being indexed into.
118 //
119 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
120                                       vector<Value*> &Idx,
121                                       unsigned i = 0) {
122   assert(i < Idx.size() && "i out of range!");
123   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
124   if (NewCT == OldTy) return;  // No adjustment unless type changes
125
126   if (const StructType *OldST = dyn_cast<StructType>(OldTy)) {
127     // Figure out what the current index is...
128     unsigned ElNum = cast<ConstantUInt>(Idx[i])->getValue();
129     assert(ElNum < OldST->getElementTypes().size());
130
131     map<const StructType*, TransformType>::iterator I = Transforms.find(OldST);
132     if (I != Transforms.end()) {
133       assert(ElNum < I->second.second.size());
134       // Apply the XForm specified by Transforms map...
135       unsigned NewElNum = I->second.second[ElNum];
136       Idx[i] = ConstantUInt::get(Type::UByteTy, NewElNum);
137     }
138   }
139
140   // Recursively process subtypes...
141   if (i+1 < Idx.size())
142     AdjustIndices(cast<CompositeType>(OldTy->getTypeAtIndex(Idx[i])), Idx, i+1);
143 }
144
145
146 // ConvertValue - Convert from the old value in the old type system to the new
147 // type system.
148 //
149 Value *MutateStructTypes::ConvertValue(const Value *V) {
150   // Ignore null values and simple constants..
151   if (V == 0) return 0;
152
153   if (Constant *CPV = dyn_cast<Constant>(V)) {
154     if (V->getType()->isPrimitiveType())
155       return CPV;
156
157     if (isa<ConstantPointerNull>(CPV))
158       return ConstantPointerNull::get(
159                       cast<PointerType>(ConvertType(V->getType())));
160     assert(0 && "Unable to convert constpool val of this type!");
161   }
162
163   // Check to see if this is an out of method reference first...
164   if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
165     // Check to see if the value is in the map...
166     map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
167     if (I == GlobalMap.end())
168       return GV;  // Not mapped, just return value itself
169     return I->second;
170   }
171   
172   map<const Value*, Value*>::iterator I = LocalValueMap.find(V);
173   if (I != LocalValueMap.end()) return I->second;
174
175   if (const BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
176     // Create placeholder block to represent the basic block we haven't seen yet
177     // This will be used when the block gets created.
178     //
179     return LocalValueMap[V] = new BasicBlock(BB->getName());
180   }
181
182   DEBUG_MST(cerr << "NPH: " << V << endl);
183
184   // Otherwise make a constant to represent it
185   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
186 }
187
188
189 // setTransforms - Take a map that specifies what transformation to do for each
190 // field of the specified structure types.  There is one element of the vector
191 // for each field of the structure.  The value specified indicates which slot of
192 // the destination structure the field should end up in.  A negative value 
193 // indicates that the field should be deleted entirely.
194 //
195 void MutateStructTypes::setTransforms(const TransformsType &XForm) {
196
197   // Loop over the types and insert dummy entries into the type map so that 
198   // recursive types are resolved properly...
199   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
200          E = XForm.end(); I != E; ++I) {
201     const StructType *OldTy = I->first;
202     TypeMap.insert(std::make_pair(OldTy, OpaqueType::get()));
203   }
204
205   // Loop over the type specified and figure out what types they should become
206   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
207          E = XForm.end(); I != E; ++I) {
208     const StructType  *OldTy = I->first;
209     const vector<int> &InVec = I->second;
210
211     assert(OldTy->getElementTypes().size() == InVec.size() &&
212            "Action not specified for every element of structure type!");
213
214     vector<const Type *> NewType;
215
216     // Convert the elements of the type over, including the new position mapping
217     int Idx = 0;
218     vector<int>::const_iterator TI = find(InVec.begin(), InVec.end(), Idx);
219     while (TI != InVec.end()) {
220       unsigned Offset = TI-InVec.begin();
221       const Type *NewEl = ConvertType(OldTy->getContainedType(Offset));
222       assert(NewEl && "Element not found!");
223       NewType.push_back(NewEl);
224
225       TI = find(InVec.begin(), InVec.end(), ++Idx);
226     }
227
228     // Create a new type that corresponds to the destination type
229     PATypeHolder<StructType> NSTy = StructType::get(NewType);
230
231     // Refine the old opaque type to the new type to properly handle recursive
232     // types...
233     //
234     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
235     cast<DerivedType>(OldTypeStub)->refineAbstractTypeTo(NSTy);
236
237     // Add the transformation to the Transforms map.
238     Transforms.insert(std::make_pair(OldTy, std::make_pair(NSTy, InVec)));
239
240     DEBUG_MST(cerr << "Mutate " << OldTy << "\nTo " << NSTy << endl);
241   }
242 }
243
244 void MutateStructTypes::clearTransforms() {
245   Transforms.clear();
246   TypeMap.clear();
247   GlobalMap.clear();
248   assert(LocalValueMap.empty() &&
249          "Local Value Map should always be empty between transformations!");
250 }
251
252 // doInitialization - This loops over global constants defined in the
253 // module, converting them to their new type.
254 //
255 void MutateStructTypes::processGlobals(Module *M) {
256   // Loop through the methods in the module and create a new version of the
257   // method to contained the transformed code.  Don't use an iterator, because
258   // we will be adding values to the end of the vector, and it could be
259   // reallocated.  Also, we don't want to process the values that we add.
260   //
261   unsigned NumMethods = M->size();
262   for (unsigned i = 0; i < NumMethods; ++i) {
263     Method *Meth = M->begin()[i];
264
265     if (!Meth->isExternal()) {
266       const MethodType *NewMTy = 
267         cast<MethodType>(ConvertType(Meth->getMethodType()));
268       
269       // Create a new method to put stuff into...
270       Method *NewMeth = new Method(NewMTy, Meth->hasInternalLinkage(),
271                                    Meth->getName());
272       if (Meth->hasName())
273         Meth->setName("OLD."+Meth->getName());
274
275       // Insert the new method into the method list... to be filled in later...
276       M->getMethodList().push_back(NewMeth);
277       
278       // Keep track of the association...
279       GlobalMap[Meth] = NewMeth;
280     }
281   }
282
283   // TODO: HANDLE GLOBAL VARIABLES
284
285   // Remap the symbol table to refer to the types in a nice way
286   //
287   if (M->hasSymbolTable()) {
288     SymbolTable *ST = M->getSymbolTable();
289     SymbolTable::iterator I = ST->find(Type::TypeTy);
290     if (I != ST->end()) {    // Get the type plane for Type's
291       SymbolTable::VarMap &Plane = I->second;
292       for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
293            TI != TE; ++TI) {
294         // This is gross, I'm reaching right into a symbol table and mucking
295         // around with it's internals... but oh well.
296         //
297         TI->second = cast<Type>(ConvertType(cast<Type>(TI->second)));
298       }
299     }
300   }
301 }
302
303
304 // removeDeadGlobals - For this pass, all this does is remove the old versions
305 // of the methods and global variables that we no longer need.
306 void MutateStructTypes::removeDeadGlobals(Module *M) {
307   // The first half of the methods in the module have to go.
308   //unsigned NumMethods = M->size();
309   //unsigned NumGVars   = M->gsize();
310
311   // Prepare for deletion of globals by dropping their interdependencies...
312   for(Module::iterator I = M->begin(); I != M->end(); ++I) {
313     if (GlobalMap.find(*I) != GlobalMap.end())
314       (*I)->Method::dropAllReferences();
315   }
316
317   // Run through and delete the methods and global variables...
318 #if 0  // TODO: HANDLE GLOBAL VARIABLES
319   M->getGlobalList().delete_span(M->gbegin(), M->gbegin()+NumGVars/2);
320 #endif
321   for(Module::iterator I = M->begin(); I != M->end();) {
322     if (GlobalMap.find(*I) != GlobalMap.end())
323       delete M->getMethodList().remove(I);
324     else
325       ++I;
326   }
327 }
328
329
330
331 // transformMethod - This transforms the instructions of the method to use the
332 // new types.
333 //
334 void MutateStructTypes::transformMethod(Method *m) {
335   const Method *M = m;
336   map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
337   if (GMI == GlobalMap.end())
338     return;  // Do not affect one of our new methods that we are creating
339
340   Method *NewMeth = cast<Method>(GMI->second);
341
342   // Okay, first order of business, create the arguments...
343   for (unsigned i = 0; i < M->getArgumentList().size(); ++i) {
344     const MethodArgument *OMA = M->getArgumentList()[i];
345     MethodArgument *NMA = new MethodArgument(ConvertType(OMA->getType()),
346                                              OMA->getName());
347     NewMeth->getArgumentList().push_back(NMA);
348     LocalValueMap[OMA] = NMA; // Keep track of value mapping
349   }
350
351
352   // Loop over all of the basic blocks copying instructions over...
353   for (Method::const_iterator BBI = M->begin(), BBE = M->end(); BBI != BBE;
354        ++BBI) {
355
356     // Create a new basic block and establish a mapping between the old and new
357     const BasicBlock *BB = *BBI;
358     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
359     NewMeth->getBasicBlocks().push_back(NewBB);  // Add block to method
360
361     // Copy over all of the instructions in the basic block...
362     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
363          II != IE; ++II) {
364
365       const Instruction *I = *II;   // Get the current instruction...
366       Instruction *NewI = 0;
367
368       switch (I->getOpcode()) {
369         // Terminator Instructions
370       case Instruction::Ret:
371         NewI = new ReturnInst(
372                    ConvertValue(cast<ReturnInst>(I)->getReturnValue()));
373         break;
374       case Instruction::Br: {
375         const BranchInst *BI = cast<BranchInst>(I);
376         NewI = new BranchInst(
377                            cast<BasicBlock>(ConvertValue(BI->getSuccessor(0))),
378                     cast_or_null<BasicBlock>(ConvertValue(BI->getSuccessor(1))),
379                               ConvertValue(BI->getCondition()));
380         break;
381       }
382       case Instruction::Switch:
383       case Instruction::Invoke:
384         assert(0 && "Insn not implemented!");
385
386         // Unary Instructions
387       case Instruction::Not:
388         NewI = UnaryOperator::create((Instruction::UnaryOps)I->getOpcode(),
389                                      ConvertValue(I->getOperand(0)));
390         break;
391
392         // Binary Instructions
393       case Instruction::Add:
394       case Instruction::Sub:
395       case Instruction::Mul:
396       case Instruction::Div:
397       case Instruction::Rem:
398         // Logical Operations
399       case Instruction::And:
400       case Instruction::Or:
401       case Instruction::Xor:
402
403         // Binary Comparison Instructions
404       case Instruction::SetEQ:
405       case Instruction::SetNE:
406       case Instruction::SetLE:
407       case Instruction::SetGE:
408       case Instruction::SetLT:
409       case Instruction::SetGT:
410         NewI = BinaryOperator::create((Instruction::BinaryOps)I->getOpcode(),
411                                       ConvertValue(I->getOperand(0)),
412                                       ConvertValue(I->getOperand(1)));
413         break;
414
415       case Instruction::Shr:
416       case Instruction::Shl:
417         NewI = new ShiftInst(cast<ShiftInst>(I)->getOpcode(),
418                              ConvertValue(I->getOperand(0)),
419                              ConvertValue(I->getOperand(1)));
420         break;
421
422
423         // Memory Instructions
424       case Instruction::Alloca:
425         NewI = 
426           new AllocaInst(ConvertType(I->getType()),
427                          I->getNumOperands()?ConvertValue(I->getOperand(0)):0);
428         break;
429       case Instruction::Malloc:
430         NewI = 
431           new MallocInst(ConvertType(I->getType()),
432                          I->getNumOperands()?ConvertValue(I->getOperand(0)):0);
433         break;
434
435       case Instruction::Free:
436         NewI = new FreeInst(ConvertValue(I->getOperand(0)));
437         break;
438
439       case Instruction::Load:
440       case Instruction::Store:
441       case Instruction::GetElementPtr: {
442         const MemAccessInst *MAI = cast<MemAccessInst>(I);
443         vector<Value*> Indices(MAI->idx_begin(), MAI->idx_end());
444         const Value *Ptr = MAI->getPointerOperand();
445         Value *NewPtr = ConvertValue(Ptr);
446         if (!Indices.empty()) {
447           const Type *PTy = cast<PointerType>(Ptr->getType())->getElementType();
448           AdjustIndices(cast<CompositeType>(PTy), Indices);
449         }
450
451         if (isa<LoadInst>(I)) {
452           NewI = new LoadInst(NewPtr, Indices);
453         } else if (isa<StoreInst>(I)) {
454           NewI = new StoreInst(ConvertValue(I->getOperand(0)), NewPtr, Indices);
455         } else if (isa<GetElementPtrInst>(I)) {
456           NewI = new GetElementPtrInst(NewPtr, Indices);
457         } else {
458           assert(0 && "Unknown memory access inst!!!");
459         }
460         break;
461       }
462
463         // Miscellaneous Instructions
464       case Instruction::PHINode: {
465         const PHINode *OldPN = cast<PHINode>(I);
466         PHINode *PN = new PHINode(ConvertType(I->getType()));
467         for (unsigned i = 0; i < OldPN->getNumIncomingValues(); ++i)
468           PN->addIncoming(ConvertValue(OldPN->getIncomingValue(i)),
469                     cast<BasicBlock>(ConvertValue(OldPN->getIncomingBlock(i))));
470         NewI = PN;
471         break;
472       }
473       case Instruction::Cast:
474         NewI = new CastInst(ConvertValue(I->getOperand(0)),
475                             ConvertType(I->getType()));
476         break;
477       case Instruction::Call: {
478         Value *Meth = ConvertValue(I->getOperand(0));
479         vector<Value*> Operands;
480         for (unsigned i = 1; i < I->getNumOperands(); ++i)
481           Operands.push_back(ConvertValue(I->getOperand(i)));
482         NewI = new CallInst(Meth, Operands);
483         break;
484       }
485         
486       default:
487         assert(0 && "UNKNOWN INSTRUCTION ENCOUNTERED!\n");
488         break;
489       }
490
491       NewI->setName(I->getName());
492       NewBB->getInstList().push_back(NewI);
493
494       // Check to see if we had to make a placeholder for this value...
495       map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(I);
496       if (LVMI != LocalValueMap.end()) {
497         // Yup, make sure it's a placeholder...
498         Instruction *I = cast<Instruction>(LVMI->second);
499         assert(I->getOpcode() == Instruction::UserOp1 && "Not a placeholder!");
500
501         // Replace all uses of the place holder with the real deal...
502         I->replaceAllUsesWith(NewI);
503         delete I;                    // And free the placeholder memory
504       }
505
506       // Keep track of the fact the the local implementation of this instruction
507       // is NewI.
508       LocalValueMap[I] = NewI;
509     }
510   }
511
512   LocalValueMap.clear();
513 }
514
515
516 bool MutateStructTypes::run(Module *M) {
517   processGlobals(M);
518
519   for_each(M->begin(), M->end(),
520            bind_obj(this, &MutateStructTypes::transformMethod));
521
522   removeDeadGlobals(M);
523   return true;
524 }
525
526 // getAnalysisUsageInfo - This function needs the results of the
527 // FindUsedTypes and FindUnsafePointerTypes analysis passes...
528 //
529 void MutateStructTypes::getAnalysisUsageInfo(Pass::AnalysisSet &Required,
530                                              Pass::AnalysisSet &Destroyed,
531                                              Pass::AnalysisSet &Provided) {
532   Destroyed.push_back(FindUsedTypes::ID);
533   Destroyed.push_back(FindUnsafePointerTypes::ID);
534   Destroyed.push_back(CallGraph::ID);
535 }