4f46e5b45c072cf5bb80e1e693c3bfc62bfd0708
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / Graph.cpp
1 //===--Graph.cpp--- implements Graph class ---------------- ------*- C++ -*--=//
2 //
3 // This implements Graph for helping in trace generation
4 // This graph gets used by "ProfilePaths" class
5 //
6 //===----------------------------------------------------------------------===//
7
8 #include "Graph.h"
9 #include "llvm/iTerminators.h"
10 #include "Support/Debug.h"
11 #include <algorithm>
12
13 using std::map;
14 using std::vector;
15 using std::cerr;
16
17 const graphListElement *findNodeInList(const Graph::nodeList &NL,
18                                               Node *N) {
19   for(Graph::nodeList::const_iterator NI = NL.begin(), NE=NL.end(); NI != NE; 
20       ++NI)
21     if (*NI->element== *N)
22       return &*NI;
23   return 0;
24 }
25
26 graphListElement *findNodeInList(Graph::nodeList &NL, Node *N) {
27   for(Graph::nodeList::iterator NI = NL.begin(), NE=NL.end(); NI != NE; ++NI)
28     if (*NI->element== *N)
29       return &*NI;
30   return 0;
31 }
32
33 //graph constructor with root and exit specified
34 Graph::Graph(std::vector<Node*> n, std::vector<Edge> e, 
35              Node *rt, Node *lt){
36   strt=rt;
37   ext=lt;
38   for(vector<Node* >::iterator x=n.begin(), en=n.end(); x!=en; ++x)
39     //nodes[*x] = list<graphListElement>();
40     nodes[*x] = vector<graphListElement>();
41
42   for(vector<Edge >::iterator x=e.begin(), en=e.end(); x!=en; ++x){
43     Edge ee=*x;
44     int w=ee.getWeight();
45     //nodes[ee.getFirst()].push_front(graphListElement(ee.getSecond(),w, ee.getRandId()));   
46     nodes[ee.getFirst()].push_back(graphListElement(ee.getSecond(),w, ee.getRandId()));
47   }
48   
49 }
50
51 //sorting edgelist, called by backEdgeVist ONLY!!!
52 Graph::nodeList &Graph::sortNodeList(Node *par, nodeList &nl, vector<Edge> &be){
53   assert(par && "null node pointer");
54   BasicBlock *bbPar = par->getElement();
55   
56   if(nl.size()<=1) return nl;
57   if(getExit() == par) return nl;
58
59   for(nodeList::iterator NLI = nl.begin(), NLE = nl.end()-1; NLI != NLE; ++NLI){
60     nodeList::iterator min = NLI;
61     for(nodeList::iterator LI = NLI+1, LE = nl.end(); LI!=LE; ++LI){
62       //if LI < min, min = LI
63       if(min->element->getElement() == LI->element->getElement() &&
64          min->element == getExit()){
65
66         //same successors: so might be exit???
67         //if it is exit, then see which is backedge
68         //check if LI is a left back edge!
69
70         TerminatorInst *tti = par->getElement()->getTerminator();
71         BranchInst *ti =  cast<BranchInst>(tti);
72
73         assert(ti && "not a branch");
74         assert(ti->getNumSuccessors()==2 && "less successors!");
75         
76         BasicBlock *tB = ti->getSuccessor(0);
77         BasicBlock *fB = ti->getSuccessor(1);
78         //so one of LI or min must be back edge!
79         //Algo: if succ(0)!=LI (and so !=min) then succ(0) is backedge
80         //and then see which of min or LI is backedge
81         //THEN if LI is in be, then min=LI
82         if(LI->element->getElement() != tB){//so backedge must be made min!
83           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
84               VBEI != VBEE; ++VBEI){
85             if(VBEI->getRandId() == LI->randId){
86               min = LI;
87               break;
88             }
89             else if(VBEI->getRandId() == min->randId)
90               break;
91           }
92         }
93         else{// if(LI->element->getElement() != fB)
94           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
95               VBEI != VBEE; ++VBEI){
96             if(VBEI->getRandId() == min->randId){
97               min = LI;
98               break;
99             }
100             else if(VBEI->getRandId() == LI->randId)
101               break;
102           }
103         }
104       }
105       
106       else if (min->element->getElement() != LI->element->getElement()){
107         TerminatorInst *tti = par->getElement()->getTerminator();
108         BranchInst *ti =  cast<BranchInst>(tti);
109         assert(ti && "not a branch");
110
111         if(ti->getNumSuccessors()<=1) continue;
112         
113         assert(ti->getNumSuccessors()==2 && "less successors!");
114         
115         BasicBlock *tB = ti->getSuccessor(0);
116         BasicBlock *fB = ti->getSuccessor(1);
117         
118         if(tB == LI->element->getElement() || fB == min->element->getElement())
119           min = LI;
120       }
121     }
122     
123     graphListElement tmpElmnt = *min;
124     *min = *NLI;
125     *NLI = tmpElmnt;
126   }
127   return nl;
128 }
129
130 //check whether graph has an edge
131 //having an edge simply means that there is an edge in the graph
132 //which has same endpoints as the given edge
133 bool Graph::hasEdge(Edge ed){
134   if(ed.isNull())
135     return false;
136
137   nodeList &nli= nodes[ed.getFirst()]; //getNodeList(ed.getFirst());
138   Node *nd2=ed.getSecond();
139
140   return (findNodeInList(nli,nd2)!=NULL);
141
142 }
143
144
145 //check whether graph has an edge, with a given wt
146 //having an edge simply means that there is an edge in the graph
147 //which has same endpoints as the given edge
148 //This function checks, moreover, that the wt of edge matches too
149 bool Graph::hasEdgeAndWt(Edge ed){
150   if(ed.isNull())
151     return false;
152
153   Node *nd2=ed.getSecond();
154   nodeList &nli = nodes[ed.getFirst()];//getNodeList(ed.getFirst());
155   
156   for(nodeList::iterator NI=nli.begin(), NE=nli.end(); NI!=NE; ++NI)
157     if(*NI->element == *nd2 && ed.getWeight()==NI->weight)
158       return true;
159   
160   return false;
161 }
162
163 //add a node
164 void Graph::addNode(Node *nd){
165   vector<Node *> lt=getAllNodes();
166
167   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE;++LI){
168     if(**LI==*nd)
169       return;
170   }
171   //chng
172   nodes[nd] =vector<graphListElement>(); //list<graphListElement>();
173 }
174
175 //add an edge
176 //this adds an edge ONLY when 
177 //the edge to be added doesn not already exist
178 //we "equate" two edges here only with their 
179 //end points
180 void Graph::addEdge(Edge ed, int w){
181   nodeList &ndList = nodes[ed.getFirst()];
182   Node *nd2=ed.getSecond();
183
184   if(findNodeInList(nodes[ed.getFirst()], nd2))
185     return;
186  
187   //ndList.push_front(graphListElement(nd2,w, ed.getRandId()));
188   ndList.push_back(graphListElement(nd2,w, ed.getRandId()));//chng
189   //sortNodeList(ed.getFirst(), ndList);
190
191   //sort(ndList.begin(), ndList.end(), NodeListSort());
192 }
193
194 //add an edge EVEN IF such an edge already exists
195 //this may make a multi-graph
196 //which does happen when we add dummy edges
197 //to the graph, for compensating for back-edges
198 void Graph::addEdgeForce(Edge ed){
199   //nodes[ed.getFirst()].push_front(graphListElement(ed.getSecond(),
200   //ed.getWeight(), ed.getRandId()));
201   nodes[ed.getFirst()].push_back
202     (graphListElement(ed.getSecond(), ed.getWeight(), ed.getRandId()));
203
204   //sortNodeList(ed.getFirst(), nodes[ed.getFirst()]);
205   //sort(nodes[ed.getFirst()].begin(), nodes[ed.getFirst()].end(), NodeListSort());
206 }
207
208 //remove an edge
209 //Note that it removes just one edge,
210 //the first edge that is encountered
211 void Graph::removeEdge(Edge ed){
212   nodeList &ndList = nodes[ed.getFirst()];
213   Node &nd2 = *ed.getSecond();
214
215   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
216     if(*NI->element == nd2) {
217       ndList.erase(NI);
218       break;
219     }
220   }
221 }
222
223 //remove an edge with a given wt
224 //Note that it removes just one edge,
225 //the first edge that is encountered
226 void Graph::removeEdgeWithWt(Edge ed){
227   nodeList &ndList = nodes[ed.getFirst()];
228   Node &nd2 = *ed.getSecond();
229
230   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
231     if(*NI->element == nd2 && NI->weight==ed.getWeight()) {
232       ndList.erase(NI);
233       break;
234     }
235   }
236 }
237
238 //set the weight of an edge
239 void Graph::setWeight(Edge ed){
240   graphListElement *El = findNodeInList(nodes[ed.getFirst()], ed.getSecond());
241   if (El)
242     El->weight=ed.getWeight();
243 }
244
245
246
247 //get the list of successor nodes
248 vector<Node *> Graph::getSuccNodes(Node *nd){
249   nodeMapTy::const_iterator nli = nodes.find(nd);
250   assert(nli != nodes.end() && "Node must be in nodes map");
251   const nodeList &nl = getNodeList(nd);//getSortedNodeList(nd);
252
253   vector<Node *> lt;
254   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
255     lt.push_back(NI->element);
256
257   return lt;
258 }
259
260 //get the number of outgoing edges
261 int Graph::getNumberOfOutgoingEdges(Node *nd) const {
262   nodeMapTy::const_iterator nli = nodes.find(nd);
263   assert(nli != nodes.end() && "Node must be in nodes map");
264   const nodeList &nl = nli->second;
265
266   int count=0;
267   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
268     count++;
269
270   return count;
271 }
272
273 //get the list of predecessor nodes
274 vector<Node *> Graph::getPredNodes(Node *nd){
275   vector<Node *> lt;
276   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
277     Node *lnode=EI->first;
278     const nodeList &nl = getNodeList(lnode);
279
280     const graphListElement *N = findNodeInList(nl, nd);
281     if (N) lt.push_back(lnode);
282   }
283   return lt;
284 }
285
286 //get the number of predecessor nodes
287 int Graph::getNumberOfIncomingEdges(Node *nd){
288   int count=0;
289   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
290     Node *lnode=EI->first;
291     const nodeList &nl = getNodeList(lnode);
292     for(Graph::nodeList::const_iterator NI = nl.begin(), NE=nl.end(); NI != NE; 
293         ++NI)
294       if (*NI->element== *nd)
295         count++;
296   }
297   return count;
298 }
299
300 //get the list of all the vertices in graph
301 vector<Node *> Graph::getAllNodes() const{
302   vector<Node *> lt;
303   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
304     lt.push_back(x->first);
305
306   return lt;
307 }
308
309 //get the list of all the vertices in graph
310 vector<Node *> Graph::getAllNodes(){
311   vector<Node *> lt;
312   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
313     lt.push_back(x->first);
314
315   return lt;
316 }
317
318 //class to compare two nodes in graph
319 //based on their wt: this is used in
320 //finding the maximal spanning tree
321 struct compare_nodes {
322   bool operator()(Node *n1, Node *n2){
323     return n1->getWeight() < n2->getWeight();
324   }
325 };
326
327
328 static void printNode(Node *nd){
329   cerr<<"Node:"<<nd->getElement()->getName()<<"\n";
330 }
331
332 //Get the Maximal spanning tree (also a graph)
333 //of the graph
334 Graph* Graph::getMaxSpanningTree(){
335   //assume connected graph
336  
337   Graph *st=new Graph();//max spanning tree, undirected edges
338   int inf=9999999;//largest key
339   vector<Node *> lt = getAllNodes();
340   
341   //initially put all vertices in vector vt
342   //assign wt(root)=0
343   //wt(others)=infinity
344   //
345   //now:
346   //pull out u: a vertex frm vt of min wt
347   //for all vertices w in vt, 
348   //if wt(w) greater than 
349   //the wt(u->w), then assign
350   //wt(w) to be wt(u->w).
351   //
352   //make parent(u)=w in the spanning tree
353   //keep pulling out vertices from vt till it is empty
354
355   vector<Node *> vt;
356   
357   map<Node*, Node* > parent;
358   map<Node*, int > ed_weight;
359
360   //initialize: wt(root)=0, wt(others)=infinity
361   //parent(root)=NULL, parent(others) not defined (but not null)
362   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
363     Node *thisNode=*LI;
364     if(*thisNode == *getRoot()){
365       thisNode->setWeight(0);
366       parent[thisNode]=NULL;
367       ed_weight[thisNode]=0;
368     }
369     else{ 
370       thisNode->setWeight(inf);
371     }
372     st->addNode(thisNode);//add all nodes to spanning tree
373     //we later need to assign edges in the tree
374     vt.push_back(thisNode); //pushed all nodes in vt
375   }
376
377   //keep pulling out vertex of min wt from vt
378   while(!vt.empty()){
379     Node *u=*(min_element(vt.begin(), vt.end(), compare_nodes()));
380     DEBUG(cerr<<"popped wt"<<(u)->getWeight()<<"\n";
381           printNode(u));
382
383     if(parent[u]!=NULL){ //so not root
384       Edge edge(parent[u],u, ed_weight[u]); //assign edge in spanning tree
385       st->addEdge(edge,ed_weight[u]);
386
387       DEBUG(cerr<<"added:\n";
388             printEdge(edge));
389     }
390
391     //vt.erase(u);
392     
393     //remove u frm vt
394     for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
395       if(**VI==*u){
396         vt.erase(VI);
397         break;
398       }
399     }
400     
401     //assign wt(v) to all adjacent vertices v of u
402     //only if v is in vt
403     Graph::nodeList &nl = getNodeList(u);
404     for(nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
405       Node *v=NI->element;
406       int weight=-NI->weight;
407       //check if v is in vt
408       bool contains=false;
409       for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
410         if(**VI==*v){
411           contains=true;
412           break;
413         }
414       }
415       DEBUG(cerr<<"wt:v->wt"<<weight<<":"<<v->getWeight()<<"\n";
416             printNode(v);cerr<<"node wt:"<<(*v).weight<<"\n");
417
418       //so if v in in vt, change wt(v) to wt(u->v)
419       //only if wt(u->v)<wt(v)
420       if(contains && weight<v->getWeight()){
421         parent[v]=u;
422         ed_weight[v]=weight;
423         v->setWeight(weight);
424
425         DEBUG(cerr<<v->getWeight()<<":Set weight------\n";
426               printGraph();
427               printEdge(Edge(u,v,weight)));
428       }
429     }
430   }
431   return st;
432 }
433
434 //print the graph (for debugging)   
435 void Graph::printGraph(){
436    vector<Node *> lt=getAllNodes();
437    cerr<<"Graph---------------------\n";
438    for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
439      cerr<<((*LI)->getElement())->getName()<<"->";
440      Graph::nodeList &nl = getNodeList(*LI);
441      for(Graph::nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
442        cerr<<":"<<"("<<(NI->element->getElement())
443          ->getName()<<":"<<NI->element->getWeight()<<","<<NI->weight<<")";
444      }
445      cerr<<"--------\n";
446    }
447 }
448
449
450 //get a list of nodes in the graph
451 //in r-topological sorted order
452 //note that we assumed graph to be connected
453 vector<Node *> Graph::reverseTopologicalSort(){
454   vector <Node *> toReturn;
455   vector<Node *> lt=getAllNodes();
456   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
457     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
458       DFS_Visit(*LI, toReturn);
459   }
460
461   return toReturn;
462 }
463
464 //a private method for doing DFS traversal of graph
465 //this is used in determining the reverse topological sort 
466 //of the graph
467 void Graph::DFS_Visit(Node *nd, vector<Node *> &toReturn){
468   nd->setWeight(GREY);
469   vector<Node *> lt=getSuccNodes(nd);
470   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
471     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
472       DFS_Visit(*LI, toReturn);
473   }
474   toReturn.push_back(nd);
475 }
476
477 //Ordinarily, the graph is directional
478 //this converts the graph into an 
479 //undirectional graph
480 //This is done by adding an edge
481 //v->u for all existing edges u->v
482 void Graph::makeUnDirectional(){
483   vector<Node* > allNodes=getAllNodes();
484   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
485       ++NI) {
486     nodeList &nl = getNodeList(*NI);
487     for(nodeList::iterator NLI=nl.begin(), NLE=nl.end(); NLI!=NLE; ++NLI){
488       Edge ed(NLI->element, *NI, NLI->weight);
489       if(!hasEdgeAndWt(ed)){
490         DEBUG(cerr<<"######doesn't hv\n";
491               printEdge(ed));
492         addEdgeForce(ed);
493       }
494     }
495   }
496 }
497
498 //reverse the sign of weights on edges
499 //this way, max-spanning tree could be obtained
500 //usin min-spanning tree, and vice versa
501 void Graph::reverseWts(){
502   vector<Node *> allNodes=getAllNodes();
503   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
504       ++NI) {
505     nodeList &node_list = getNodeList(*NI);
506     for(nodeList::iterator NLI=nodes[*NI].begin(), NLE=nodes[*NI].end(); 
507         NLI!=NLE; ++NLI)
508       NLI->weight=-NLI->weight;
509   }
510 }
511
512
513 //getting the backedges in a graph
514 //Its a variation of DFS to get the backedges in the graph
515 //We get back edges by associating a time
516 //and a color with each vertex.
517 //The time of a vertex is the time when it was first visited
518 //The color of a vertex is initially WHITE,
519 //Changes to GREY when it is first visited,
520 //and changes to BLACK when ALL its neighbors
521 //have been visited
522 //So we have a back edge when we meet a successor of
523 //a node with smaller time, and GREY color
524 void Graph::getBackEdges(vector<Edge > &be, map<Node *, int> &d){
525   map<Node *, Color > color;
526   int time=0;
527
528   getBackEdgesVisit(getRoot(), be, color, d, time);
529 }
530
531 //helper function to get back edges: it is called by 
532 //the "getBackEdges" function above
533 void Graph::getBackEdgesVisit(Node *u, vector<Edge > &be,
534                               map<Node *, Color > &color,
535                               map<Node *, int > &d, int &time) {
536   color[u]=GREY;
537   time++;
538   d[u]=time;
539
540   vector<graphListElement> &succ_list = getNodeList(u);
541   
542   for(vector<graphListElement>::iterator vl=succ_list.begin(), 
543         ve=succ_list.end(); vl!=ve; ++vl){
544     Node *v=vl->element;
545     if(color[v]!=GREY && color[v]!=BLACK){
546       getBackEdgesVisit(v, be, color, d, time);
547     }
548     
549     //now checking for d and f vals
550     if(color[v]==GREY){
551       //so v is ancestor of u if time of u > time of v
552       if(d[u] >= d[v]){
553         Edge *ed=new Edge(u, v,vl->weight, vl->randId);
554         if (!(*u == *getExit() && *v == *getRoot()))
555           be.push_back(*ed);      // choose the forward edges
556       }
557     }
558   }
559   color[u]=BLACK;//done with visiting the node and its neighbors
560 }
561
562