Changes For Bug 352
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / Graph.cpp
1 //===-- Graph.cpp - Implements Graph class --------------------------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This implements Graph for helping in trace generation This graph gets used by
11 // "ProfilePaths" class.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "Graph.h"
16 #include "llvm/Instructions.h"
17 #include "llvm/Support/Debug.h"
18 #include <algorithm>
19
20 using std::vector;
21
22 namespace llvm {
23
24 const graphListElement *findNodeInList(const Graph::nodeList &NL,
25                                               Node *N) {
26   for(Graph::nodeList::const_iterator NI = NL.begin(), NE=NL.end(); NI != NE; 
27       ++NI)
28     if (*NI->element== *N)
29       return &*NI;
30   return 0;
31 }
32
33 graphListElement *findNodeInList(Graph::nodeList &NL, Node *N) {
34   for(Graph::nodeList::iterator NI = NL.begin(), NE=NL.end(); NI != NE; ++NI)
35     if (*NI->element== *N)
36       return &*NI;
37   return 0;
38 }
39
40 //graph constructor with root and exit specified
41 Graph::Graph(std::vector<Node*> n, std::vector<Edge> e, 
42              Node *rt, Node *lt){
43   strt=rt;
44   ext=lt;
45   for(vector<Node* >::iterator x=n.begin(), en=n.end(); x!=en; ++x)
46     //nodes[*x] = list<graphListElement>();
47     nodes[*x] = vector<graphListElement>();
48
49   for(vector<Edge >::iterator x=e.begin(), en=e.end(); x!=en; ++x){
50     Edge ee=*x;
51     int w=ee.getWeight();
52     //nodes[ee.getFirst()].push_front(graphListElement(ee.getSecond(),w, ee.getRandId()));   
53     nodes[ee.getFirst()].push_back(graphListElement(ee.getSecond(),w, ee.getRandId()));
54   }
55   
56 }
57
58 //sorting edgelist, called by backEdgeVist ONLY!!!
59 Graph::nodeList &Graph::sortNodeList(Node *par, nodeList &nl, vector<Edge> &be){
60   assert(par && "null node pointer");
61   BasicBlock *bbPar = par->getElement();
62   
63   if(nl.size()<=1) return nl;
64   if(getExit() == par) return nl;
65
66   for(nodeList::iterator NLI = nl.begin(), NLE = nl.end()-1; NLI != NLE; ++NLI){
67     nodeList::iterator min = NLI;
68     for(nodeList::iterator LI = NLI+1, LE = nl.end(); LI!=LE; ++LI){
69       //if LI < min, min = LI
70       if(min->element->getElement() == LI->element->getElement() &&
71          min->element == getExit()){
72
73         //same successors: so might be exit???
74         //if it is exit, then see which is backedge
75         //check if LI is a left back edge!
76
77         TerminatorInst *tti = par->getElement()->getTerminator();
78         BranchInst *ti =  cast<BranchInst>(tti);
79
80         assert(ti && "not a branch");
81         assert(ti->getNumSuccessors()==2 && "less successors!");
82         
83         BasicBlock *tB = ti->getSuccessor(0);
84         BasicBlock *fB = ti->getSuccessor(1);
85         //so one of LI or min must be back edge!
86         //Algo: if succ(0)!=LI (and so !=min) then succ(0) is backedge
87         //and then see which of min or LI is backedge
88         //THEN if LI is in be, then min=LI
89         if(LI->element->getElement() != tB){//so backedge must be made min!
90           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
91               VBEI != VBEE; ++VBEI){
92             if(VBEI->getRandId() == LI->randId){
93               min = LI;
94               break;
95             }
96             else if(VBEI->getRandId() == min->randId)
97               break;
98           }
99         }
100         else{// if(LI->element->getElement() != fB)
101           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
102               VBEI != VBEE; ++VBEI){
103             if(VBEI->getRandId() == min->randId){
104               min = LI;
105               break;
106             }
107             else if(VBEI->getRandId() == LI->randId)
108               break;
109           }
110         }
111       }
112       
113       else if (min->element->getElement() != LI->element->getElement()){
114         TerminatorInst *tti = par->getElement()->getTerminator();
115         BranchInst *ti =  cast<BranchInst>(tti);
116         assert(ti && "not a branch");
117
118         if(ti->getNumSuccessors()<=1) continue;
119         
120         assert(ti->getNumSuccessors()==2 && "less successors!");
121         
122         BasicBlock *tB = ti->getSuccessor(0);
123         BasicBlock *fB = ti->getSuccessor(1);
124         
125         if(tB == LI->element->getElement() || fB == min->element->getElement())
126           min = LI;
127       }
128     }
129     
130     graphListElement tmpElmnt = *min;
131     *min = *NLI;
132     *NLI = tmpElmnt;
133   }
134   return nl;
135 }
136
137 //check whether graph has an edge
138 //having an edge simply means that there is an edge in the graph
139 //which has same endpoints as the given edge
140 bool Graph::hasEdge(Edge ed){
141   if(ed.isNull())
142     return false;
143
144   nodeList &nli= nodes[ed.getFirst()]; //getNodeList(ed.getFirst());
145   Node *nd2=ed.getSecond();
146
147   return (findNodeInList(nli,nd2)!=NULL);
148
149 }
150
151
152 //check whether graph has an edge, with a given wt
153 //having an edge simply means that there is an edge in the graph
154 //which has same endpoints as the given edge
155 //This function checks, moreover, that the wt of edge matches too
156 bool Graph::hasEdgeAndWt(Edge ed){
157   if(ed.isNull())
158     return false;
159
160   Node *nd2=ed.getSecond();
161   nodeList &nli = nodes[ed.getFirst()];//getNodeList(ed.getFirst());
162   
163   for(nodeList::iterator NI=nli.begin(), NE=nli.end(); NI!=NE; ++NI)
164     if(*NI->element == *nd2 && ed.getWeight()==NI->weight)
165       return true;
166   
167   return false;
168 }
169
170 //add a node
171 void Graph::addNode(Node *nd){
172   vector<Node *> lt=getAllNodes();
173
174   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE;++LI){
175     if(**LI==*nd)
176       return;
177   }
178   //chng
179   nodes[nd] =vector<graphListElement>(); //list<graphListElement>();
180 }
181
182 //add an edge
183 //this adds an edge ONLY when 
184 //the edge to be added does not already exist
185 //we "equate" two edges here only with their 
186 //end points
187 void Graph::addEdge(Edge ed, int w){
188   nodeList &ndList = nodes[ed.getFirst()];
189   Node *nd2=ed.getSecond();
190
191   if(findNodeInList(nodes[ed.getFirst()], nd2))
192     return;
193  
194   //ndList.push_front(graphListElement(nd2,w, ed.getRandId()));
195   ndList.push_back(graphListElement(nd2,w, ed.getRandId()));//chng
196   //sortNodeList(ed.getFirst(), ndList);
197
198   //sort(ndList.begin(), ndList.end(), NodeListSort());
199 }
200
201 //add an edge EVEN IF such an edge already exists
202 //this may make a multi-graph
203 //which does happen when we add dummy edges
204 //to the graph, for compensating for back-edges
205 void Graph::addEdgeForce(Edge ed){
206   //nodes[ed.getFirst()].push_front(graphListElement(ed.getSecond(),
207   //ed.getWeight(), ed.getRandId()));
208   nodes[ed.getFirst()].push_back
209     (graphListElement(ed.getSecond(), ed.getWeight(), ed.getRandId()));
210
211   //sortNodeList(ed.getFirst(), nodes[ed.getFirst()]);
212   //sort(nodes[ed.getFirst()].begin(), nodes[ed.getFirst()].end(), NodeListSort());
213 }
214
215 //remove an edge
216 //Note that it removes just one edge,
217 //the first edge that is encountered
218 void Graph::removeEdge(Edge ed){
219   nodeList &ndList = nodes[ed.getFirst()];
220   Node &nd2 = *ed.getSecond();
221
222   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
223     if(*NI->element == nd2) {
224       ndList.erase(NI);
225       break;
226     }
227   }
228 }
229
230 //remove an edge with a given wt
231 //Note that it removes just one edge,
232 //the first edge that is encountered
233 void Graph::removeEdgeWithWt(Edge ed){
234   nodeList &ndList = nodes[ed.getFirst()];
235   Node &nd2 = *ed.getSecond();
236
237   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
238     if(*NI->element == nd2 && NI->weight==ed.getWeight()) {
239       ndList.erase(NI);
240       break;
241     }
242   }
243 }
244
245 //set the weight of an edge
246 void Graph::setWeight(Edge ed){
247   graphListElement *El = findNodeInList(nodes[ed.getFirst()], ed.getSecond());
248   if (El)
249     El->weight=ed.getWeight();
250 }
251
252
253
254 //get the list of successor nodes
255 vector<Node *> Graph::getSuccNodes(Node *nd){
256   nodeMapTy::const_iterator nli = nodes.find(nd);
257   assert(nli != nodes.end() && "Node must be in nodes map");
258   const nodeList &nl = getNodeList(nd);//getSortedNodeList(nd);
259
260   vector<Node *> lt;
261   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
262     lt.push_back(NI->element);
263
264   return lt;
265 }
266
267 //get the number of outgoing edges
268 int Graph::getNumberOfOutgoingEdges(Node *nd) const {
269   nodeMapTy::const_iterator nli = nodes.find(nd);
270   assert(nli != nodes.end() && "Node must be in nodes map");
271   const nodeList &nl = nli->second;
272
273   int count=0;
274   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
275     count++;
276
277   return count;
278 }
279
280 //get the list of predecessor nodes
281 vector<Node *> Graph::getPredNodes(Node *nd){
282   vector<Node *> lt;
283   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
284     Node *lnode=EI->first;
285     const nodeList &nl = getNodeList(lnode);
286
287     const graphListElement *N = findNodeInList(nl, nd);
288     if (N) lt.push_back(lnode);
289   }
290   return lt;
291 }
292
293 //get the number of predecessor nodes
294 int Graph::getNumberOfIncomingEdges(Node *nd){
295   int count=0;
296   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
297     Node *lnode=EI->first;
298     const nodeList &nl = getNodeList(lnode);
299     for(Graph::nodeList::const_iterator NI = nl.begin(), NE=nl.end(); NI != NE; 
300         ++NI)
301       if (*NI->element== *nd)
302         count++;
303   }
304   return count;
305 }
306
307 //get the list of all the vertices in graph
308 vector<Node *> Graph::getAllNodes() const{
309   vector<Node *> lt;
310   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
311     lt.push_back(x->first);
312
313   return lt;
314 }
315
316 //get the list of all the vertices in graph
317 vector<Node *> Graph::getAllNodes(){
318   vector<Node *> lt;
319   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
320     lt.push_back(x->first);
321
322   return lt;
323 }
324
325 //class to compare two nodes in graph
326 //based on their wt: this is used in
327 //finding the maximal spanning tree
328 struct compare_nodes {
329   bool operator()(Node *n1, Node *n2){
330     return n1->getWeight() < n2->getWeight();
331   }
332 };
333
334
335 static void printNode(Node *nd){
336   std::cerr<<"Node:"<<nd->getElement()->getName()<<"\n";
337 }
338
339 //Get the Maximal spanning tree (also a graph)
340 //of the graph
341 Graph* Graph::getMaxSpanningTree(){
342   //assume connected graph
343  
344   Graph *st=new Graph();//max spanning tree, undirected edges
345   int inf=9999999;//largest key
346   vector<Node *> lt = getAllNodes();
347   
348   //initially put all vertices in vector vt
349   //assign wt(root)=0
350   //wt(others)=infinity
351   //
352   //now:
353   //pull out u: a vertex frm vt of min wt
354   //for all vertices w in vt, 
355   //if wt(w) greater than 
356   //the wt(u->w), then assign
357   //wt(w) to be wt(u->w).
358   //
359   //make parent(u)=w in the spanning tree
360   //keep pulling out vertices from vt till it is empty
361
362   vector<Node *> vt;
363   
364   std::map<Node*, Node* > parent;
365   std::map<Node*, int > ed_weight;
366
367   //initialize: wt(root)=0, wt(others)=infinity
368   //parent(root)=NULL, parent(others) not defined (but not null)
369   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
370     Node *thisNode=*LI;
371     if(*thisNode == *getRoot()){
372       thisNode->setWeight(0);
373       parent[thisNode]=NULL;
374       ed_weight[thisNode]=0;
375     }
376     else{ 
377       thisNode->setWeight(inf);
378     }
379     st->addNode(thisNode);//add all nodes to spanning tree
380     //we later need to assign edges in the tree
381     vt.push_back(thisNode); //pushed all nodes in vt
382   }
383
384   //keep pulling out vertex of min wt from vt
385   while(!vt.empty()){
386     Node *u=*(min_element(vt.begin(), vt.end(), compare_nodes()));
387     DEBUG(std::cerr<<"popped wt"<<(u)->getWeight()<<"\n";
388           printNode(u));
389
390     if(parent[u]!=NULL){ //so not root
391       Edge edge(parent[u],u, ed_weight[u]); //assign edge in spanning tree
392       st->addEdge(edge,ed_weight[u]);
393
394       DEBUG(std::cerr<<"added:\n";
395             printEdge(edge));
396     }
397
398     //vt.erase(u);
399     
400     //remove u frm vt
401     for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
402       if(**VI==*u){
403         vt.erase(VI);
404         break;
405       }
406     }
407     
408     //assign wt(v) to all adjacent vertices v of u
409     //only if v is in vt
410     Graph::nodeList &nl = getNodeList(u);
411     for(nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
412       Node *v=NI->element;
413       int weight=-NI->weight;
414       //check if v is in vt
415       bool contains=false;
416       for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
417         if(**VI==*v){
418           contains=true;
419           break;
420         }
421       }
422       DEBUG(std::cerr<<"wt:v->wt"<<weight<<":"<<v->getWeight()<<"\n";
423             printNode(v);std::cerr<<"node wt:"<<(*v).weight<<"\n");
424
425       //so if v in in vt, change wt(v) to wt(u->v)
426       //only if wt(u->v)<wt(v)
427       if(contains && weight<v->getWeight()){
428         parent[v]=u;
429         ed_weight[v]=weight;
430         v->setWeight(weight);
431
432         DEBUG(std::cerr<<v->getWeight()<<":Set weight------\n";
433               printGraph();
434               printEdge(Edge(u,v,weight)));
435       }
436     }
437   }
438   return st;
439 }
440
441 //print the graph (for debugging)   
442 void Graph::printGraph(){
443    vector<Node *> lt=getAllNodes();
444    std::cerr<<"Graph---------------------\n";
445    for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
446      std::cerr<<((*LI)->getElement())->getName()<<"->";
447      Graph::nodeList &nl = getNodeList(*LI);
448      for(Graph::nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
449        std::cerr<<":"<<"("<<(NI->element->getElement())
450          ->getName()<<":"<<NI->element->getWeight()<<","<<NI->weight<<")";
451      }
452      std::cerr<<"--------\n";
453    }
454 }
455
456
457 //get a list of nodes in the graph
458 //in r-topological sorted order
459 //note that we assumed graph to be connected
460 vector<Node *> Graph::reverseTopologicalSort(){
461   vector <Node *> toReturn;
462   vector<Node *> lt=getAllNodes();
463   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
464     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
465       DFS_Visit(*LI, toReturn);
466   }
467
468   return toReturn;
469 }
470
471 //a private method for doing DFS traversal of graph
472 //this is used in determining the reverse topological sort 
473 //of the graph
474 void Graph::DFS_Visit(Node *nd, vector<Node *> &toReturn){
475   nd->setWeight(GREY);
476   vector<Node *> lt=getSuccNodes(nd);
477   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
478     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
479       DFS_Visit(*LI, toReturn);
480   }
481   toReturn.push_back(nd);
482 }
483
484 //Ordinarily, the graph is directional
485 //this converts the graph into an 
486 //undirectional graph
487 //This is done by adding an edge
488 //v->u for all existing edges u->v
489 void Graph::makeUnDirectional(){
490   vector<Node* > allNodes=getAllNodes();
491   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
492       ++NI) {
493     nodeList &nl = getNodeList(*NI);
494     for(nodeList::iterator NLI=nl.begin(), NLE=nl.end(); NLI!=NLE; ++NLI){
495       Edge ed(NLI->element, *NI, NLI->weight);
496       if(!hasEdgeAndWt(ed)){
497         DEBUG(std::cerr<<"######doesn't hv\n";
498               printEdge(ed));
499         addEdgeForce(ed);
500       }
501     }
502   }
503 }
504
505 //reverse the sign of weights on edges
506 //this way, max-spanning tree could be obtained
507 //using min-spanning tree, and vice versa
508 void Graph::reverseWts(){
509   vector<Node *> allNodes=getAllNodes();
510   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
511       ++NI) {
512     nodeList &node_list = getNodeList(*NI);
513     for(nodeList::iterator NLI=nodes[*NI].begin(), NLE=nodes[*NI].end(); 
514         NLI!=NLE; ++NLI)
515       NLI->weight=-NLI->weight;
516   }
517 }
518
519
520 //getting the backedges in a graph
521 //Its a variation of DFS to get the backedges in the graph
522 //We get back edges by associating a time
523 //and a color with each vertex.
524 //The time of a vertex is the time when it was first visited
525 //The color of a vertex is initially WHITE,
526 //Changes to GREY when it is first visited,
527 //and changes to BLACK when ALL its neighbors
528 //have been visited
529 //So we have a back edge when we meet a successor of
530 //a node with smaller time, and GREY color
531 void Graph::getBackEdges(vector<Edge > &be, std::map<Node *, int> &d){
532   std::map<Node *, Color > color;
533   int time=0;
534
535   getBackEdgesVisit(getRoot(), be, color, d, time);
536 }
537
538 //helper function to get back edges: it is called by 
539 //the "getBackEdges" function above
540 void Graph::getBackEdgesVisit(Node *u, vector<Edge > &be,
541                               std::map<Node *, Color > &color,
542                               std::map<Node *, int > &d, int &time) {
543   color[u]=GREY;
544   time++;
545   d[u]=time;
546
547   vector<graphListElement> &succ_list = getNodeList(u);
548   
549   for(vector<graphListElement>::iterator vl=succ_list.begin(), 
550         ve=succ_list.end(); vl!=ve; ++vl){
551     Node *v=vl->element;
552     if(color[v]!=GREY && color[v]!=BLACK){
553       getBackEdgesVisit(v, be, color, d, time);
554     }
555     
556     //now checking for d and f vals
557     if(color[v]==GREY){
558       //so v is ancestor of u if time of u > time of v
559       if(d[u] >= d[v]){
560         Edge *ed=new Edge(u, v,vl->weight, vl->randId);
561         if (!(*u == *getExit() && *v == *getRoot()))
562           be.push_back(*ed);      // choose the forward edges
563       }
564     }
565   }
566   color[u]=BLACK;//done with visiting the node and its neighbors
567 }
568
569 } // End llvm namespace