Start using the nicer terminator auto-insertion API
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / EdgeCode.cpp
1 //===-- EdgeCode.cpp - generate LLVM instrumentation code -----------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //It implements the class EdgeCode: which provides 
10 //support for inserting "appropriate" instrumentation at
11 //designated points in the graph
12 //
13 //It also has methods to insert initialization code in 
14 //top block of cfg
15 //===----------------------------------------------------------------------===//
16
17 #include "Graph.h"
18 #include "llvm/Constants.h"
19 #include "llvm/DerivedTypes.h"
20 #include "llvm/iMemory.h"
21 #include "llvm/iTerminators.h"
22 #include "llvm/iOther.h"
23 #include "llvm/iOperators.h"
24 #include "llvm/iPHINode.h"
25 #include "llvm/Module.h"
26
27 #define INSERT_LOAD_COUNT
28 #define INSERT_STORE
29
30
31 using std::vector;
32
33 namespace llvm {
34
35 static void getTriggerCode(Module *M, BasicBlock *BB, int MethNo, Value *pathNo,
36                            Value *cnt, Instruction *rInst){ 
37   
38   vector<Value *> tmpVec;
39   tmpVec.push_back(Constant::getNullValue(Type::LongTy));
40   tmpVec.push_back(Constant::getNullValue(Type::LongTy));
41   Instruction *Idx = new GetElementPtrInst(cnt, tmpVec, "");//,
42   BB->getInstList().push_back(Idx);
43
44   const Type *PIntTy = PointerType::get(Type::IntTy);
45   Function *trigMeth = M->getOrInsertFunction("trigger", Type::VoidTy, 
46                                               Type::IntTy, Type::IntTy,
47                                               PIntTy, PIntTy, 0);
48   assert(trigMeth && "trigger method could not be inserted!");
49
50   vector<Value *> trargs;
51
52   trargs.push_back(ConstantSInt::get(Type::IntTy,MethNo));
53   trargs.push_back(pathNo);
54   trargs.push_back(Idx);
55   trargs.push_back(rInst);
56
57   Instruction *callInst=new CallInst(trigMeth, trargs, "");//, BB->begin());
58   BB->getInstList().push_back(callInst);
59   //triggerInst = new CallInst(trigMeth, trargs, "");//, InsertPos);
60 }
61
62
63 //get the code to be inserted on the edge
64 //This is determined from cond (1-6)
65 void getEdgeCode::getCode(Instruction *rInst, Value *countInst, 
66                           Function *M, BasicBlock *BB, 
67                           vector<Value *> &retVec){
68   
69   //Instruction *InsertPos = BB->getInstList().begin();
70   
71   //now check for cdIn and cdOut
72   //first put cdOut
73   if(cdOut!=NULL){
74     cdOut->getCode(rInst, countInst, M, BB, retVec);
75   }
76   
77   if(cdIn!=NULL){
78     cdIn->getCode(rInst, countInst, M, BB, retVec);
79   }
80
81   //case: r=k code to be inserted
82   switch(cond){
83   case 1:{
84     Value *val=ConstantSInt::get(Type::IntTy,inc);
85 #ifdef INSERT_STORE
86     Instruction *stInst=new StoreInst(val, rInst);//, InsertPos);
87     BB->getInstList().push_back(stInst);
88 #endif
89     break;
90     }
91
92   //case: r=0 to be inserted
93   case 2:{
94 #ifdef INSERT_STORE
95     Instruction *stInst = new StoreInst(ConstantSInt::getNullValue(Type::IntTy), rInst);//, InsertPos);
96      BB->getInstList().push_back(stInst);
97 #endif
98     break;
99   }
100     
101   //r+=k
102   case 3:{
103     Instruction *ldInst = new LoadInst(rInst, "ti1");//, InsertPos);
104     BB->getInstList().push_back(ldInst);
105     Value *val = ConstantSInt::get(Type::IntTy,inc);
106     Instruction *addIn = BinaryOperator::create(Instruction::Add, ldInst, val,
107                                           "ti2");//, InsertPos);
108     BB->getInstList().push_back(addIn);
109 #ifdef INSERT_STORE
110     Instruction *stInst = new StoreInst(addIn, rInst);//, InsertPos);
111     BB->getInstList().push_back(stInst);
112 #endif
113     break;
114   }
115
116   //count[inc]++
117   case 4:{
118     vector<Value *> tmpVec;
119     tmpVec.push_back(Constant::getNullValue(Type::LongTy));
120     tmpVec.push_back(ConstantSInt::get(Type::LongTy, inc));
121     Instruction *Idx = new GetElementPtrInst(countInst, tmpVec, "");//,
122
123     //Instruction *Idx = new GetElementPtrInst(countInst, 
124     //           vector<Value*>(1,ConstantSInt::get(Type::LongTy, inc)),
125     //                                       "");//, InsertPos);
126     BB->getInstList().push_back(Idx);
127
128     Instruction *ldInst=new LoadInst(Idx, "ti1");//, InsertPos);
129     BB->getInstList().push_back(ldInst);
130  
131     Value *val = ConstantSInt::get(Type::IntTy, 1);
132     //Instruction *addIn =
133     Instruction *newCount =
134       BinaryOperator::create(Instruction::Add, ldInst, val,"ti2");
135     BB->getInstList().push_back(newCount);
136     
137
138 #ifdef INSERT_STORE
139     //Instruction *stInst=new StoreInst(addIn, Idx, InsertPos);
140     Instruction *stInst=new StoreInst(newCount, Idx);//, InsertPos);
141     BB->getInstList().push_back(stInst);
142 #endif
143     
144     Value *trAddIndex = ConstantSInt::get(Type::IntTy,inc);
145
146     retVec.push_back(newCount);
147     retVec.push_back(trAddIndex);
148     //insert trigger
149     //getTriggerCode(M->getParent(), BB, MethNo, 
150     //     ConstantSInt::get(Type::IntTy,inc), newCount, triggerInst);
151     //end trigger code
152
153     assert(inc>=0 && "IT MUST BE POSITIVE NOW");
154     break;
155   }
156
157   //case: count[r+inc]++
158   case 5:{
159    
160     //ti1=inc+r
161     Instruction *ldIndex=new LoadInst(rInst, "ti1");//, InsertPos);
162     BB->getInstList().push_back(ldIndex);
163
164     Value *val=ConstantSInt::get(Type::IntTy,inc);
165     Instruction *addIndex=BinaryOperator::
166       create(Instruction::Add, ldIndex, val,"ti2");//, InsertPos);
167     BB->getInstList().push_back(addIndex);
168     
169     //now load count[addIndex]
170     Instruction *castInst=new CastInst(addIndex, 
171                                        Type::LongTy,"ctin");//, InsertPos);
172     BB->getInstList().push_back(castInst);
173
174     vector<Value *> tmpVec;
175     tmpVec.push_back(Constant::getNullValue(Type::LongTy));
176     tmpVec.push_back(castInst);
177     Instruction *Idx = new GetElementPtrInst(countInst, tmpVec, "");//,
178     //                                             InsertPos);
179     BB->getInstList().push_back(Idx);
180
181     Instruction *ldInst=new LoadInst(Idx, "ti3");//, InsertPos);
182     BB->getInstList().push_back(ldInst);
183
184     Value *cons=ConstantSInt::get(Type::IntTy,1);
185     //count[addIndex]++
186     //std::cerr<<"Type ldInst:"<<ldInst->getType()<<"\t cons:"<<cons->getType()<<"\n";
187     Instruction *newCount = BinaryOperator::create(Instruction::Add, ldInst, 
188                                                    cons,"");
189     BB->getInstList().push_back(newCount);
190     
191 #ifdef INSERT_STORE
192     Instruction *stInst = new StoreInst(newCount, Idx);//, InsertPos);
193     BB->getInstList().push_back(stInst);
194 #endif
195
196     retVec.push_back(newCount);
197     retVec.push_back(addIndex);
198     //insert trigger
199     //getTriggerCode(M->getParent(), BB, MethNo, addIndex, newCount, triggerInst);
200     //end trigger code
201
202     break;
203   }
204
205     //case: count[r]+
206   case 6:{
207     //ti1=inc+r
208     Instruction *ldIndex=new LoadInst(rInst, "ti1");//, InsertPos);
209     BB->getInstList().push_back(ldIndex);
210
211     //now load count[addIndex]
212     Instruction *castInst2=new CastInst(ldIndex, Type::LongTy,"ctin");
213     BB->getInstList().push_back(castInst2);
214
215     vector<Value *> tmpVec;
216     tmpVec.push_back(Constant::getNullValue(Type::LongTy));
217     tmpVec.push_back(castInst2);
218     Instruction *Idx = new GetElementPtrInst(countInst, tmpVec, "");//,
219
220     //Instruction *Idx = new GetElementPtrInst(countInst, 
221     //                                       vector<Value*>(1,castInst2), "");
222     
223     BB->getInstList().push_back(Idx);
224     
225     Instruction *ldInst=new LoadInst(Idx, "ti2");//, InsertPos);
226     BB->getInstList().push_back(ldInst);
227
228     Value *cons=ConstantSInt::get(Type::IntTy,1);
229
230     //count[addIndex]++
231     Instruction *newCount = BinaryOperator::create(Instruction::Add, ldInst,
232                                                    cons,"ti3");
233     BB->getInstList().push_back(newCount);
234
235 #ifdef INSERT_STORE
236     Instruction *stInst = new StoreInst(newCount, Idx);//, InsertPos);
237     BB->getInstList().push_back(stInst);
238 #endif
239
240     retVec.push_back(newCount);
241     retVec.push_back(ldIndex);
242     break;
243   }
244     
245   }
246 }
247
248
249
250 //Insert the initialization code in the top BB
251 //this includes initializing r, and count
252 //r is like an accumulator, that 
253 //keeps on adding increments as we traverse along a path
254 //and at the end of the path, r contains the path
255 //number of that path
256 //Count is an array, where Count[k] represents
257 //the number of executions of path k
258 void insertInTopBB(BasicBlock *front, 
259                    int k, 
260                    Instruction *rVar, Value *threshold){
261   //rVar is variable r, 
262   //countVar is count[]
263
264   Value *Int0 = ConstantInt::get(Type::IntTy, 0);
265   
266   //now push all instructions in front of the BB
267   BasicBlock::iterator here=front->begin();
268   front->getInstList().insert(here, rVar);
269   //front->getInstList().insert(here,countVar);
270   
271   //Initialize Count[...] with 0
272
273   //for (int i=0;i<k; i++){
274   //Value *GEP2 = new GetElementPtrInst(countVar,
275   //                      vector<Value *>(1,ConstantSInt::get(Type::LongTy, i)),
276   //                                    "", here);
277   //new StoreInst(Int0, GEP2, here);
278   //}
279
280   //store uint 0, uint *%R
281   new StoreInst(Int0, rVar, here);
282 }
283
284
285 //insert a basic block with appropriate code
286 //along a given edge
287 void insertBB(Edge ed,
288               getEdgeCode *edgeCode, 
289               Instruction *rInst, 
290               Value *countInst, 
291               int numPaths, int Methno, Value *threshold){
292
293   BasicBlock* BB1=ed.getFirst()->getElement();
294   BasicBlock* BB2=ed.getSecond()->getElement();
295   
296 #ifdef DEBUG_PATH_PROFILES
297   //debugging info
298   cerr<<"Edges with codes ######################\n";
299   cerr<<BB1->getName()<<"->"<<BB2->getName()<<"\n";
300   cerr<<"########################\n";
301 #endif
302   
303   //We need to insert a BB between BB1 and BB2 
304   TerminatorInst *TI=BB1->getTerminator();
305   BasicBlock *newBB=new BasicBlock("counter", BB1->getParent());
306
307   //get code for the new BB
308   vector<Value *> retVec;
309
310   edgeCode->getCode(rInst, countInst, BB1->getParent(), newBB, retVec);
311
312   BranchInst *BI =  cast<BranchInst>(TI);
313
314   //Is terminator a branch instruction?
315   //then we need to change branch destinations to include new BB
316
317   if(BI->isUnconditional()){
318     BI->setUnconditionalDest(newBB);
319   }
320   else{
321       if(BI->getSuccessor(0)==BB2)
322       BI->setSuccessor(0, newBB);
323     
324     if(BI->getSuccessor(1)==BB2)
325       BI->setSuccessor(1, newBB);
326   }
327
328   BasicBlock *triggerBB = NULL;
329   if(retVec.size()>0){
330     triggerBB = new BasicBlock("trigger", BB1->getParent());
331     getTriggerCode(BB1->getParent()->getParent(), triggerBB, Methno, 
332                    retVec[1], countInst, rInst);//retVec[0]);
333
334     //Instruction *castInst = new CastInst(retVec[0], Type::IntTy, "");
335     Instruction *etr = new LoadInst(threshold, "threshold");
336     
337     //std::cerr<<"type1: "<<etr->getType()<<" type2: "<<retVec[0]->getType()<<"\n"; 
338     Instruction *cmpInst = new SetCondInst(Instruction::SetLE, etr, 
339                                            retVec[0], "");
340     Instruction *newBI2 = new BranchInst(triggerBB, BB2, cmpInst);
341     //newBB->getInstList().push_back(castInst);
342     newBB->getInstList().push_back(etr);
343     newBB->getInstList().push_back(cmpInst);
344     newBB->getInstList().push_back(newBI2);
345     
346     //triggerBB->getInstList().push_back(triggerInst);
347     new BranchInst(BB2, 0, 0, triggerBB);
348   }
349   else{
350     new BranchInst(BB2, 0, 0, newBB);
351   }
352
353   //now iterate over BB2, and set its Phi nodes right
354   for(BasicBlock::iterator BB2Inst = BB2->begin(), BBend = BB2->end(); 
355       BB2Inst != BBend; ++BB2Inst){
356    
357     if(PHINode *phiInst=dyn_cast<PHINode>(BB2Inst)){
358       int bbIndex=phiInst->getBasicBlockIndex(BB1);
359       assert(bbIndex>=0);
360       phiInst->setIncomingBlock(bbIndex, newBB);
361
362       ///check if trigger!=null, then add value corresponding to it too!
363       if(retVec.size()>0){
364         assert(triggerBB && "BasicBlock with trigger should not be null!");
365         Value *vl = phiInst->getIncomingValue((unsigned int)bbIndex);
366         phiInst->addIncoming(vl, triggerBB);
367       }
368     }
369   }
370 }
371
372 } // End llvm namespace