changes to make it compatible with 64bit gcc
[oota-llvm.git] / lib / Transforms / IPO / MutateStructTypes.cpp
1 //===- MutateStructTypes.cpp - Change struct defns --------------------------=//
2 //
3 // This pass is used to change structure accesses and type definitions in some
4 // way.  It can be used to arbitrarily permute structure fields, safely, without
5 // breaking code.  A transformation may only be done on a type if that type has
6 // been found to be "safe" by the 'FindUnsafePointerTypes' pass.  This pass will
7 // assert and die if you try to do an illegal transformation.
8 //
9 // This is an interprocedural pass that requires the entire program to do a
10 // transformation.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/IPO/MutateStructTypes.h"
15 #include "llvm/DerivedTypes.h"
16 #include "llvm/Module.h"
17 #include "llvm/Function.h"
18 #include "llvm/BasicBlock.h"
19 #include "llvm/GlobalVariable.h"
20 #include "llvm/SymbolTable.h"
21 #include "llvm/iPHINode.h"
22 #include "llvm/iMemory.h"
23 #include "llvm/iTerminators.h"
24 #include "llvm/iOther.h"
25 #include "llvm/Argument.h"
26 #include "llvm/Constants.h"
27 #include "Support/STLExtras.h"
28 #include "Support/StatisticReporter.h"
29 #include <algorithm>
30 #include <iostream>
31 using std::map;
32 using std::vector;
33
34 // ValuePlaceHolder - A stupid little marker value.  It appears as an
35 // instruction of type Instruction::UserOp1.
36 //
37 struct ValuePlaceHolder : public Instruction {
38   ValuePlaceHolder(const Type *Ty) : Instruction(Ty, UserOp1, "") {}
39
40   virtual Instruction *clone() const { abort(); return 0; }
41   virtual const char *getOpcodeName() const { return "placeholder"; }
42 };
43
44
45 // ConvertType - Convert from the old type system to the new one...
46 const Type *MutateStructTypes::ConvertType(const Type *Ty) {
47   if (Ty->isPrimitiveType() ||
48       isa<OpaqueType>(Ty)) return Ty;  // Don't convert primitives
49
50   map<const Type *, PATypeHolder>::iterator I = TypeMap.find(Ty);
51   if (I != TypeMap.end()) return I->second;
52
53   const Type *DestTy = 0;
54
55   PATypeHolder PlaceHolder = OpaqueType::get();
56   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
57
58   switch (Ty->getPrimitiveID()) {
59   case Type::FunctionTyID: {
60     const FunctionType *MT = cast<FunctionType>(Ty);
61     const Type *RetTy = ConvertType(MT->getReturnType());
62     vector<const Type*> ArgTypes;
63
64     for (FunctionType::ParamTypes::const_iterator I = MT->getParamTypes().begin(),
65            E = MT->getParamTypes().end(); I != E; ++I)
66       ArgTypes.push_back(ConvertType(*I));
67     
68     DestTy = FunctionType::get(RetTy, ArgTypes, MT->isVarArg());
69     break;
70   }
71   case Type::StructTyID: {
72     const StructType *ST = cast<StructType>(Ty);
73     const StructType::ElementTypes &El = ST->getElementTypes();
74     vector<const Type *> Types;
75
76     for (StructType::ElementTypes::const_iterator I = El.begin(), E = El.end();
77          I != E; ++I)
78       Types.push_back(ConvertType(*I));
79     DestTy = StructType::get(Types);
80     break;
81   }
82   case Type::ArrayTyID:
83     DestTy = ArrayType::get(ConvertType(cast<ArrayType>(Ty)->getElementType()),
84                             cast<ArrayType>(Ty)->getNumElements());
85     break;
86
87   case Type::PointerTyID:
88     DestTy = PointerType::get(
89                  ConvertType(cast<PointerType>(Ty)->getElementType()));
90     break;
91   default:
92     assert(0 && "Unknown type!");
93     return 0;
94   }
95
96   assert(DestTy && "Type didn't get created!?!?");
97
98   // Refine our little placeholder value into a real type...
99   ((DerivedType*)PlaceHolder.get())->refineAbstractTypeTo(DestTy);
100   TypeMap.insert(std::make_pair(Ty, PlaceHolder.get()));
101
102   return PlaceHolder.get();
103 }
104
105
106 // AdjustIndices - Convert the indexes specifed by Idx to the new changed form
107 // using the specified OldTy as the base type being indexed into.
108 //
109 void MutateStructTypes::AdjustIndices(const CompositeType *OldTy,
110                                       vector<Value*> &Idx,
111                                       unsigned i = 0) {
112   assert(i < Idx.size() && "i out of range!");
113   const CompositeType *NewCT = cast<CompositeType>(ConvertType(OldTy));
114   if (NewCT == OldTy) return;  // No adjustment unless type changes
115
116   if (const StructType *OldST = dyn_cast<StructType>(OldTy)) {
117     // Figure out what the current index is...
118     unsigned ElNum = cast<ConstantUInt>(Idx[i])->getValue();
119     assert(ElNum < OldST->getElementTypes().size());
120
121     map<const StructType*, TransformType>::iterator I = Transforms.find(OldST);
122     if (I != Transforms.end()) {
123       assert(ElNum < I->second.second.size());
124       // Apply the XForm specified by Transforms map...
125       unsigned NewElNum = I->second.second[ElNum];
126       Idx[i] = ConstantUInt::get(Type::UByteTy, NewElNum);
127     }
128   }
129
130   // Recursively process subtypes...
131   if (i+1 < Idx.size())
132     AdjustIndices(cast<CompositeType>(OldTy->getTypeAtIndex(Idx[i])), Idx, i+1);
133 }
134
135
136 // ConvertValue - Convert from the old value in the old type system to the new
137 // type system.
138 //
139 Value *MutateStructTypes::ConvertValue(const Value *V) {
140   // Ignore null values and simple constants..
141   if (V == 0) return 0;
142
143   if (const Constant *CPV = dyn_cast<Constant>(V)) {
144     if (V->getType()->isPrimitiveType())
145       return (Value*)CPV;
146
147     if (isa<ConstantPointerNull>(CPV))
148       return ConstantPointerNull::get(
149                       cast<PointerType>(ConvertType(V->getType())));
150     assert(0 && "Unable to convert constpool val of this type!");
151   }
152
153   // Check to see if this is an out of function reference first...
154   if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
155     // Check to see if the value is in the map...
156     map<const GlobalValue*, GlobalValue*>::iterator I = GlobalMap.find(GV);
157     if (I == GlobalMap.end())
158       return (Value*)GV;  // Not mapped, just return value itself
159     return I->second;
160   }
161   
162   map<const Value*, Value*>::iterator I = LocalValueMap.find(V);
163   if (I != LocalValueMap.end()) return I->second;
164
165   if (const BasicBlock *BB = dyn_cast<BasicBlock>(V)) {
166     // Create placeholder block to represent the basic block we haven't seen yet
167     // This will be used when the block gets created.
168     //
169     return LocalValueMap[V] = new BasicBlock(BB->getName());
170   }
171
172   DEBUG(std::cerr << "NPH: " << V << "\n");
173
174   // Otherwise make a constant to represent it
175   return LocalValueMap[V] = new ValuePlaceHolder(ConvertType(V->getType()));
176 }
177
178
179 // setTransforms - Take a map that specifies what transformation to do for each
180 // field of the specified structure types.  There is one element of the vector
181 // for each field of the structure.  The value specified indicates which slot of
182 // the destination structure the field should end up in.  A negative value 
183 // indicates that the field should be deleted entirely.
184 //
185 void MutateStructTypes::setTransforms(const TransformsType &XForm) {
186
187   // Loop over the types and insert dummy entries into the type map so that 
188   // recursive types are resolved properly...
189   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
190          E = XForm.end(); I != E; ++I) {
191     const StructType *OldTy = I->first;
192     TypeMap.insert(std::make_pair(OldTy, OpaqueType::get()));
193   }
194
195   // Loop over the type specified and figure out what types they should become
196   for (map<const StructType*, vector<int> >::const_iterator I = XForm.begin(),
197          E = XForm.end(); I != E; ++I) {
198     const StructType  *OldTy = I->first;
199     const vector<int> &InVec = I->second;
200
201     assert(OldTy->getElementTypes().size() == InVec.size() &&
202            "Action not specified for every element of structure type!");
203
204     vector<const Type *> NewType;
205
206     // Convert the elements of the type over, including the new position mapping
207     int Idx = 0;
208     vector<int>::const_iterator TI = find(InVec.begin(), InVec.end(), Idx);
209     while (TI != InVec.end()) {
210       unsigned Offset = TI-InVec.begin();
211       const Type *NewEl = ConvertType(OldTy->getContainedType(Offset));
212       assert(NewEl && "Element not found!");
213       NewType.push_back(NewEl);
214
215       TI = find(InVec.begin(), InVec.end(), ++Idx);
216     }
217
218     // Create a new type that corresponds to the destination type
219     PATypeHolder NSTy = StructType::get(NewType);
220
221     // Refine the old opaque type to the new type to properly handle recursive
222     // types...
223     //
224     const Type *OldTypeStub = TypeMap.find(OldTy)->second.get();
225     ((DerivedType*)OldTypeStub)->refineAbstractTypeTo(NSTy);
226
227     // Add the transformation to the Transforms map.
228     Transforms.insert(std::make_pair(OldTy,
229                        std::make_pair(cast<StructType>(NSTy.get()), InVec)));
230
231     DEBUG(std::cerr << "Mutate " << OldTy << "\nTo " << NSTy << "\n");
232   }
233 }
234
235 void MutateStructTypes::clearTransforms() {
236   Transforms.clear();
237   TypeMap.clear();
238   GlobalMap.clear();
239   assert(LocalValueMap.empty() &&
240          "Local Value Map should always be empty between transformations!");
241 }
242
243 // processGlobals - This loops over global constants defined in the
244 // module, converting them to their new type.
245 //
246 void MutateStructTypes::processGlobals(Module &M) {
247   // Loop through the functions in the module and create a new version of the
248   // function to contained the transformed code.  Also, be careful to not
249   // process the values that we add.
250   //
251   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
252     if (!I->isExternal()) {
253       const FunctionType *NewMTy = 
254         cast<FunctionType>(ConvertType(I->getFunctionType()));
255       
256       // Create a new function to put stuff into...
257       Function *NewMeth = new Function(NewMTy, I->hasInternalLinkage(),
258                                        I->getName());
259       if (I->hasName())
260         I->setName("OLD."+I->getName());
261
262       // Insert the new function into the function list... to be filled in later
263       M.getFunctionList().push_back(NewMeth);
264       
265       // Keep track of the association...
266       GlobalMap[I] = NewMeth;
267     }
268
269   // TODO: HANDLE GLOBAL VARIABLES
270
271   // Remap the symbol table to refer to the types in a nice way
272   //
273   if (SymbolTable *ST = M.getSymbolTable()) {    
274     SymbolTable::iterator I = ST->find(Type::TypeTy);
275     if (I != ST->end()) {    // Get the type plane for Type's
276       SymbolTable::VarMap &Plane = I->second;
277       for (SymbolTable::type_iterator TI = Plane.begin(), TE = Plane.end();
278            TI != TE; ++TI) {
279         // FIXME: This is gross, I'm reaching right into a symbol table and
280         // mucking around with it's internals... but oh well.
281         //
282         TI->second = (Value*)cast<Type>(ConvertType(cast<Type>(TI->second)));
283       }
284     }
285   }
286 }
287
288
289 // removeDeadGlobals - For this pass, all this does is remove the old versions
290 // of the functions and global variables that we no longer need.
291 void MutateStructTypes::removeDeadGlobals(Module &M) {
292   // Prepare for deletion of globals by dropping their interdependencies...
293   for(Module::iterator I = M.begin(); I != M.end(); ++I) {
294     if (GlobalMap.find(I) != GlobalMap.end())
295       I->dropAllReferences();
296   }
297
298   // Run through and delete the functions and global variables...
299 #if 0  // TODO: HANDLE GLOBAL VARIABLES
300   M->getGlobalList().delete_span(M.gbegin(), M.gbegin()+NumGVars/2);
301 #endif
302   for(Module::iterator I = M.begin(); I != M.end();) {
303     if (GlobalMap.find(I) != GlobalMap.end())
304       I = M.getFunctionList().erase(I);
305     else
306       ++I;
307   }
308 }
309
310
311
312 // transformFunction - This transforms the instructions of the function to use
313 // the new types.
314 //
315 void MutateStructTypes::transformFunction(Function *m) {
316   const Function *M = m;
317   map<const GlobalValue*, GlobalValue*>::iterator GMI = GlobalMap.find(M);
318   if (GMI == GlobalMap.end())
319     return;  // Do not affect one of our new functions that we are creating
320
321   Function *NewMeth = cast<Function>(GMI->second);
322
323   // Okay, first order of business, create the arguments...
324   for (Function::aiterator I = m->abegin(), E = m->aend(); I != E; ++I) {
325     Argument *NFA = new Argument(ConvertType(I->getType()), I->getName());
326     NewMeth->getArgumentList().push_back(NFA);
327     LocalValueMap[I] = NFA; // Keep track of value mapping
328   }
329
330
331   // Loop over all of the basic blocks copying instructions over...
332   for (Function::const_iterator BB = M->begin(), BBE = M->end(); BB != BBE;
333        ++BB) {
334     // Create a new basic block and establish a mapping between the old and new
335     BasicBlock *NewBB = cast<BasicBlock>(ConvertValue(BB));
336     NewMeth->getBasicBlockList().push_back(NewBB);  // Add block to function
337
338     // Copy over all of the instructions in the basic block...
339     for (BasicBlock::const_iterator II = BB->begin(), IE = BB->end();
340          II != IE; ++II) {
341
342       const Instruction &I = *II;   // Get the current instruction...
343       Instruction *NewI = 0;
344
345       switch (I.getOpcode()) {
346         // Terminator Instructions
347       case Instruction::Ret:
348         NewI = new ReturnInst(
349                    ConvertValue(cast<ReturnInst>(I).getReturnValue()));
350         break;
351       case Instruction::Br: {
352         const BranchInst &BI = cast<BranchInst>(I);
353         if (BI.isConditional()) {
354           NewI =
355               new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))),
356                              cast<BasicBlock>(ConvertValue(BI.getSuccessor(1))),
357                              ConvertValue(BI.getCondition()));
358         } else {
359           NewI = 
360             new BranchInst(cast<BasicBlock>(ConvertValue(BI.getSuccessor(0))));
361         }
362         break;
363       }
364       case Instruction::Switch:
365       case Instruction::Invoke:
366         assert(0 && "Insn not implemented!");
367
368         // Unary Instructions
369       case Instruction::Not:
370         NewI = UnaryOperator::create((Instruction::UnaryOps)I.getOpcode(),
371                                      ConvertValue(I.getOperand(0)));
372         break;
373
374         // Binary Instructions
375       case Instruction::Add:
376       case Instruction::Sub:
377       case Instruction::Mul:
378       case Instruction::Div:
379       case Instruction::Rem:
380         // Logical Operations
381       case Instruction::And:
382       case Instruction::Or:
383       case Instruction::Xor:
384
385         // Binary Comparison Instructions
386       case Instruction::SetEQ:
387       case Instruction::SetNE:
388       case Instruction::SetLE:
389       case Instruction::SetGE:
390       case Instruction::SetLT:
391       case Instruction::SetGT:
392         NewI = BinaryOperator::create((Instruction::BinaryOps)I.getOpcode(),
393                                       ConvertValue(I.getOperand(0)),
394                                       ConvertValue(I.getOperand(1)));
395         break;
396
397       case Instruction::Shr:
398       case Instruction::Shl:
399         NewI = new ShiftInst(cast<ShiftInst>(I).getOpcode(),
400                              ConvertValue(I.getOperand(0)),
401                              ConvertValue(I.getOperand(1)));
402         break;
403
404
405         // Memory Instructions
406       case Instruction::Alloca:
407         NewI = 
408           new AllocaInst(ConvertType(I.getType()),
409                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
410         break;
411       case Instruction::Malloc:
412         NewI = 
413           new MallocInst(ConvertType(I.getType()),
414                          I.getNumOperands() ? ConvertValue(I.getOperand(0)) :0);
415         break;
416
417       case Instruction::Free:
418         NewI = new FreeInst(ConvertValue(I.getOperand(0)));
419         break;
420
421       case Instruction::Load:
422       case Instruction::Store:
423       case Instruction::GetElementPtr: {
424         const MemAccessInst &MAI = cast<MemAccessInst>(I);
425         vector<Value*> Indices(MAI.idx_begin(), MAI.idx_end());
426         const Value *Ptr = MAI.getPointerOperand();
427         Value *NewPtr = ConvertValue(Ptr);
428         if (!Indices.empty()) {
429           const Type *PTy = cast<PointerType>(Ptr->getType())->getElementType();
430           AdjustIndices(cast<CompositeType>(PTy), Indices);
431         }
432
433         if (isa<LoadInst>(I)) {
434           NewI = new LoadInst(NewPtr, Indices);
435         } else if (isa<StoreInst>(I)) {
436           NewI = new StoreInst(ConvertValue(I.getOperand(0)), NewPtr, Indices);
437         } else if (isa<GetElementPtrInst>(I)) {
438           NewI = new GetElementPtrInst(NewPtr, Indices);
439         } else {
440           assert(0 && "Unknown memory access inst!!!");
441         }
442         break;
443       }
444
445         // Miscellaneous Instructions
446       case Instruction::PHINode: {
447         const PHINode &OldPN = cast<PHINode>(I);
448         PHINode *PN = new PHINode(ConvertType(OldPN.getType()));
449         for (unsigned i = 0; i < OldPN.getNumIncomingValues(); ++i)
450           PN->addIncoming(ConvertValue(OldPN.getIncomingValue(i)),
451                     cast<BasicBlock>(ConvertValue(OldPN.getIncomingBlock(i))));
452         NewI = PN;
453         break;
454       }
455       case Instruction::Cast:
456         NewI = new CastInst(ConvertValue(I.getOperand(0)),
457                             ConvertType(I.getType()));
458         break;
459       case Instruction::Call: {
460         Value *Meth = ConvertValue(I.getOperand(0));
461         vector<Value*> Operands;
462         for (unsigned i = 1; i < I.getNumOperands(); ++i)
463           Operands.push_back(ConvertValue(I.getOperand(i)));
464         NewI = new CallInst(Meth, Operands);
465         break;
466       }
467         
468       default:
469         assert(0 && "UNKNOWN INSTRUCTION ENCOUNTERED!\n");
470         break;
471       }
472
473       NewI->setName(I.getName());
474       NewBB->getInstList().push_back(NewI);
475
476       // Check to see if we had to make a placeholder for this value...
477       map<const Value*,Value*>::iterator LVMI = LocalValueMap.find(&I);
478       if (LVMI != LocalValueMap.end()) {
479         // Yup, make sure it's a placeholder...
480         Instruction *I = cast<Instruction>(LVMI->second);
481         assert(I->getOpcode() == Instruction::UserOp1 && "Not a placeholder!");
482
483         // Replace all uses of the place holder with the real deal...
484         I->replaceAllUsesWith(NewI);
485         delete I;                    // And free the placeholder memory
486       }
487
488       // Keep track of the fact the the local implementation of this instruction
489       // is NewI.
490       LocalValueMap[&I] = NewI;
491     }
492   }
493
494   LocalValueMap.clear();
495 }
496
497
498 bool MutateStructTypes::run(Module &M) {
499   processGlobals(M);
500
501   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
502     transformFunction(I);
503
504   removeDeadGlobals(M);
505   return true;
506 }
507