Add #include <iostream> since Value.h does not #include it any more.
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / Graph.cpp
1 //===-- Graph.cpp - Implements Graph class --------------------------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This implements Graph for helping in trace generation This graph gets used by
11 // "ProfilePaths" class.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "Graph.h"
16 #include "llvm/iTerminators.h"
17 #include "Support/Debug.h"
18 #include <algorithm>
19 #include <iostream>
20
21 using std::vector;
22
23 namespace llvm {
24
25 const graphListElement *findNodeInList(const Graph::nodeList &NL,
26                                               Node *N) {
27   for(Graph::nodeList::const_iterator NI = NL.begin(), NE=NL.end(); NI != NE; 
28       ++NI)
29     if (*NI->element== *N)
30       return &*NI;
31   return 0;
32 }
33
34 graphListElement *findNodeInList(Graph::nodeList &NL, Node *N) {
35   for(Graph::nodeList::iterator NI = NL.begin(), NE=NL.end(); NI != NE; ++NI)
36     if (*NI->element== *N)
37       return &*NI;
38   return 0;
39 }
40
41 //graph constructor with root and exit specified
42 Graph::Graph(std::vector<Node*> n, std::vector<Edge> e, 
43              Node *rt, Node *lt){
44   strt=rt;
45   ext=lt;
46   for(vector<Node* >::iterator x=n.begin(), en=n.end(); x!=en; ++x)
47     //nodes[*x] = list<graphListElement>();
48     nodes[*x] = vector<graphListElement>();
49
50   for(vector<Edge >::iterator x=e.begin(), en=e.end(); x!=en; ++x){
51     Edge ee=*x;
52     int w=ee.getWeight();
53     //nodes[ee.getFirst()].push_front(graphListElement(ee.getSecond(),w, ee.getRandId()));   
54     nodes[ee.getFirst()].push_back(graphListElement(ee.getSecond(),w, ee.getRandId()));
55   }
56   
57 }
58
59 //sorting edgelist, called by backEdgeVist ONLY!!!
60 Graph::nodeList &Graph::sortNodeList(Node *par, nodeList &nl, vector<Edge> &be){
61   assert(par && "null node pointer");
62   BasicBlock *bbPar = par->getElement();
63   
64   if(nl.size()<=1) return nl;
65   if(getExit() == par) return nl;
66
67   for(nodeList::iterator NLI = nl.begin(), NLE = nl.end()-1; NLI != NLE; ++NLI){
68     nodeList::iterator min = NLI;
69     for(nodeList::iterator LI = NLI+1, LE = nl.end(); LI!=LE; ++LI){
70       //if LI < min, min = LI
71       if(min->element->getElement() == LI->element->getElement() &&
72          min->element == getExit()){
73
74         //same successors: so might be exit???
75         //if it is exit, then see which is backedge
76         //check if LI is a left back edge!
77
78         TerminatorInst *tti = par->getElement()->getTerminator();
79         BranchInst *ti =  cast<BranchInst>(tti);
80
81         assert(ti && "not a branch");
82         assert(ti->getNumSuccessors()==2 && "less successors!");
83         
84         BasicBlock *tB = ti->getSuccessor(0);
85         BasicBlock *fB = ti->getSuccessor(1);
86         //so one of LI or min must be back edge!
87         //Algo: if succ(0)!=LI (and so !=min) then succ(0) is backedge
88         //and then see which of min or LI is backedge
89         //THEN if LI is in be, then min=LI
90         if(LI->element->getElement() != tB){//so backedge must be made min!
91           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
92               VBEI != VBEE; ++VBEI){
93             if(VBEI->getRandId() == LI->randId){
94               min = LI;
95               break;
96             }
97             else if(VBEI->getRandId() == min->randId)
98               break;
99           }
100         }
101         else{// if(LI->element->getElement() != fB)
102           for(vector<Edge>::iterator VBEI = be.begin(), VBEE = be.end();
103               VBEI != VBEE; ++VBEI){
104             if(VBEI->getRandId() == min->randId){
105               min = LI;
106               break;
107             }
108             else if(VBEI->getRandId() == LI->randId)
109               break;
110           }
111         }
112       }
113       
114       else if (min->element->getElement() != LI->element->getElement()){
115         TerminatorInst *tti = par->getElement()->getTerminator();
116         BranchInst *ti =  cast<BranchInst>(tti);
117         assert(ti && "not a branch");
118
119         if(ti->getNumSuccessors()<=1) continue;
120         
121         assert(ti->getNumSuccessors()==2 && "less successors!");
122         
123         BasicBlock *tB = ti->getSuccessor(0);
124         BasicBlock *fB = ti->getSuccessor(1);
125         
126         if(tB == LI->element->getElement() || fB == min->element->getElement())
127           min = LI;
128       }
129     }
130     
131     graphListElement tmpElmnt = *min;
132     *min = *NLI;
133     *NLI = tmpElmnt;
134   }
135   return nl;
136 }
137
138 //check whether graph has an edge
139 //having an edge simply means that there is an edge in the graph
140 //which has same endpoints as the given edge
141 bool Graph::hasEdge(Edge ed){
142   if(ed.isNull())
143     return false;
144
145   nodeList &nli= nodes[ed.getFirst()]; //getNodeList(ed.getFirst());
146   Node *nd2=ed.getSecond();
147
148   return (findNodeInList(nli,nd2)!=NULL);
149
150 }
151
152
153 //check whether graph has an edge, with a given wt
154 //having an edge simply means that there is an edge in the graph
155 //which has same endpoints as the given edge
156 //This function checks, moreover, that the wt of edge matches too
157 bool Graph::hasEdgeAndWt(Edge ed){
158   if(ed.isNull())
159     return false;
160
161   Node *nd2=ed.getSecond();
162   nodeList &nli = nodes[ed.getFirst()];//getNodeList(ed.getFirst());
163   
164   for(nodeList::iterator NI=nli.begin(), NE=nli.end(); NI!=NE; ++NI)
165     if(*NI->element == *nd2 && ed.getWeight()==NI->weight)
166       return true;
167   
168   return false;
169 }
170
171 //add a node
172 void Graph::addNode(Node *nd){
173   vector<Node *> lt=getAllNodes();
174
175   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE;++LI){
176     if(**LI==*nd)
177       return;
178   }
179   //chng
180   nodes[nd] =vector<graphListElement>(); //list<graphListElement>();
181 }
182
183 //add an edge
184 //this adds an edge ONLY when 
185 //the edge to be added does not already exist
186 //we "equate" two edges here only with their 
187 //end points
188 void Graph::addEdge(Edge ed, int w){
189   nodeList &ndList = nodes[ed.getFirst()];
190   Node *nd2=ed.getSecond();
191
192   if(findNodeInList(nodes[ed.getFirst()], nd2))
193     return;
194  
195   //ndList.push_front(graphListElement(nd2,w, ed.getRandId()));
196   ndList.push_back(graphListElement(nd2,w, ed.getRandId()));//chng
197   //sortNodeList(ed.getFirst(), ndList);
198
199   //sort(ndList.begin(), ndList.end(), NodeListSort());
200 }
201
202 //add an edge EVEN IF such an edge already exists
203 //this may make a multi-graph
204 //which does happen when we add dummy edges
205 //to the graph, for compensating for back-edges
206 void Graph::addEdgeForce(Edge ed){
207   //nodes[ed.getFirst()].push_front(graphListElement(ed.getSecond(),
208   //ed.getWeight(), ed.getRandId()));
209   nodes[ed.getFirst()].push_back
210     (graphListElement(ed.getSecond(), ed.getWeight(), ed.getRandId()));
211
212   //sortNodeList(ed.getFirst(), nodes[ed.getFirst()]);
213   //sort(nodes[ed.getFirst()].begin(), nodes[ed.getFirst()].end(), NodeListSort());
214 }
215
216 //remove an edge
217 //Note that it removes just one edge,
218 //the first edge that is encountered
219 void Graph::removeEdge(Edge ed){
220   nodeList &ndList = nodes[ed.getFirst()];
221   Node &nd2 = *ed.getSecond();
222
223   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
224     if(*NI->element == nd2) {
225       ndList.erase(NI);
226       break;
227     }
228   }
229 }
230
231 //remove an edge with a given wt
232 //Note that it removes just one edge,
233 //the first edge that is encountered
234 void Graph::removeEdgeWithWt(Edge ed){
235   nodeList &ndList = nodes[ed.getFirst()];
236   Node &nd2 = *ed.getSecond();
237
238   for(nodeList::iterator NI=ndList.begin(), NE=ndList.end(); NI!=NE ;++NI) {
239     if(*NI->element == nd2 && NI->weight==ed.getWeight()) {
240       ndList.erase(NI);
241       break;
242     }
243   }
244 }
245
246 //set the weight of an edge
247 void Graph::setWeight(Edge ed){
248   graphListElement *El = findNodeInList(nodes[ed.getFirst()], ed.getSecond());
249   if (El)
250     El->weight=ed.getWeight();
251 }
252
253
254
255 //get the list of successor nodes
256 vector<Node *> Graph::getSuccNodes(Node *nd){
257   nodeMapTy::const_iterator nli = nodes.find(nd);
258   assert(nli != nodes.end() && "Node must be in nodes map");
259   const nodeList &nl = getNodeList(nd);//getSortedNodeList(nd);
260
261   vector<Node *> lt;
262   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
263     lt.push_back(NI->element);
264
265   return lt;
266 }
267
268 //get the number of outgoing edges
269 int Graph::getNumberOfOutgoingEdges(Node *nd) const {
270   nodeMapTy::const_iterator nli = nodes.find(nd);
271   assert(nli != nodes.end() && "Node must be in nodes map");
272   const nodeList &nl = nli->second;
273
274   int count=0;
275   for(nodeList::const_iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI)
276     count++;
277
278   return count;
279 }
280
281 //get the list of predecessor nodes
282 vector<Node *> Graph::getPredNodes(Node *nd){
283   vector<Node *> lt;
284   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
285     Node *lnode=EI->first;
286     const nodeList &nl = getNodeList(lnode);
287
288     const graphListElement *N = findNodeInList(nl, nd);
289     if (N) lt.push_back(lnode);
290   }
291   return lt;
292 }
293
294 //get the number of predecessor nodes
295 int Graph::getNumberOfIncomingEdges(Node *nd){
296   int count=0;
297   for(nodeMapTy::const_iterator EI=nodes.begin(), EE=nodes.end(); EI!=EE ;++EI){
298     Node *lnode=EI->first;
299     const nodeList &nl = getNodeList(lnode);
300     for(Graph::nodeList::const_iterator NI = nl.begin(), NE=nl.end(); NI != NE; 
301         ++NI)
302       if (*NI->element== *nd)
303         count++;
304   }
305   return count;
306 }
307
308 //get the list of all the vertices in graph
309 vector<Node *> Graph::getAllNodes() const{
310   vector<Node *> lt;
311   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
312     lt.push_back(x->first);
313
314   return lt;
315 }
316
317 //get the list of all the vertices in graph
318 vector<Node *> Graph::getAllNodes(){
319   vector<Node *> lt;
320   for(nodeMapTy::const_iterator x=nodes.begin(), en=nodes.end(); x != en; ++x)
321     lt.push_back(x->first);
322
323   return lt;
324 }
325
326 //class to compare two nodes in graph
327 //based on their wt: this is used in
328 //finding the maximal spanning tree
329 struct compare_nodes {
330   bool operator()(Node *n1, Node *n2){
331     return n1->getWeight() < n2->getWeight();
332   }
333 };
334
335
336 static void printNode(Node *nd){
337   std::cerr<<"Node:"<<nd->getElement()->getName()<<"\n";
338 }
339
340 //Get the Maximal spanning tree (also a graph)
341 //of the graph
342 Graph* Graph::getMaxSpanningTree(){
343   //assume connected graph
344  
345   Graph *st=new Graph();//max spanning tree, undirected edges
346   int inf=9999999;//largest key
347   vector<Node *> lt = getAllNodes();
348   
349   //initially put all vertices in vector vt
350   //assign wt(root)=0
351   //wt(others)=infinity
352   //
353   //now:
354   //pull out u: a vertex frm vt of min wt
355   //for all vertices w in vt, 
356   //if wt(w) greater than 
357   //the wt(u->w), then assign
358   //wt(w) to be wt(u->w).
359   //
360   //make parent(u)=w in the spanning tree
361   //keep pulling out vertices from vt till it is empty
362
363   vector<Node *> vt;
364   
365   std::map<Node*, Node* > parent;
366   std::map<Node*, int > ed_weight;
367
368   //initialize: wt(root)=0, wt(others)=infinity
369   //parent(root)=NULL, parent(others) not defined (but not null)
370   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
371     Node *thisNode=*LI;
372     if(*thisNode == *getRoot()){
373       thisNode->setWeight(0);
374       parent[thisNode]=NULL;
375       ed_weight[thisNode]=0;
376     }
377     else{ 
378       thisNode->setWeight(inf);
379     }
380     st->addNode(thisNode);//add all nodes to spanning tree
381     //we later need to assign edges in the tree
382     vt.push_back(thisNode); //pushed all nodes in vt
383   }
384
385   //keep pulling out vertex of min wt from vt
386   while(!vt.empty()){
387     Node *u=*(min_element(vt.begin(), vt.end(), compare_nodes()));
388     DEBUG(std::cerr<<"popped wt"<<(u)->getWeight()<<"\n";
389           printNode(u));
390
391     if(parent[u]!=NULL){ //so not root
392       Edge edge(parent[u],u, ed_weight[u]); //assign edge in spanning tree
393       st->addEdge(edge,ed_weight[u]);
394
395       DEBUG(std::cerr<<"added:\n";
396             printEdge(edge));
397     }
398
399     //vt.erase(u);
400     
401     //remove u frm vt
402     for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
403       if(**VI==*u){
404         vt.erase(VI);
405         break;
406       }
407     }
408     
409     //assign wt(v) to all adjacent vertices v of u
410     //only if v is in vt
411     Graph::nodeList &nl = getNodeList(u);
412     for(nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
413       Node *v=NI->element;
414       int weight=-NI->weight;
415       //check if v is in vt
416       bool contains=false;
417       for(vector<Node *>::iterator VI=vt.begin(), VE=vt.end(); VI!=VE; ++VI){
418         if(**VI==*v){
419           contains=true;
420           break;
421         }
422       }
423       DEBUG(std::cerr<<"wt:v->wt"<<weight<<":"<<v->getWeight()<<"\n";
424             printNode(v);std::cerr<<"node wt:"<<(*v).weight<<"\n");
425
426       //so if v in in vt, change wt(v) to wt(u->v)
427       //only if wt(u->v)<wt(v)
428       if(contains && weight<v->getWeight()){
429         parent[v]=u;
430         ed_weight[v]=weight;
431         v->setWeight(weight);
432
433         DEBUG(std::cerr<<v->getWeight()<<":Set weight------\n";
434               printGraph();
435               printEdge(Edge(u,v,weight)));
436       }
437     }
438   }
439   return st;
440 }
441
442 //print the graph (for debugging)   
443 void Graph::printGraph(){
444    vector<Node *> lt=getAllNodes();
445    std::cerr<<"Graph---------------------\n";
446    for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
447      std::cerr<<((*LI)->getElement())->getName()<<"->";
448      Graph::nodeList &nl = getNodeList(*LI);
449      for(Graph::nodeList::iterator NI=nl.begin(), NE=nl.end(); NI!=NE; ++NI){
450        std::cerr<<":"<<"("<<(NI->element->getElement())
451          ->getName()<<":"<<NI->element->getWeight()<<","<<NI->weight<<")";
452      }
453      std::cerr<<"--------\n";
454    }
455 }
456
457
458 //get a list of nodes in the graph
459 //in r-topological sorted order
460 //note that we assumed graph to be connected
461 vector<Node *> Graph::reverseTopologicalSort(){
462   vector <Node *> toReturn;
463   vector<Node *> lt=getAllNodes();
464   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
465     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
466       DFS_Visit(*LI, toReturn);
467   }
468
469   return toReturn;
470 }
471
472 //a private method for doing DFS traversal of graph
473 //this is used in determining the reverse topological sort 
474 //of the graph
475 void Graph::DFS_Visit(Node *nd, vector<Node *> &toReturn){
476   nd->setWeight(GREY);
477   vector<Node *> lt=getSuccNodes(nd);
478   for(vector<Node *>::iterator LI=lt.begin(), LE=lt.end(); LI!=LE; ++LI){
479     if((*LI)->getWeight()!=GREY && (*LI)->getWeight()!=BLACK)
480       DFS_Visit(*LI, toReturn);
481   }
482   toReturn.push_back(nd);
483 }
484
485 //Ordinarily, the graph is directional
486 //this converts the graph into an 
487 //undirectional graph
488 //This is done by adding an edge
489 //v->u for all existing edges u->v
490 void Graph::makeUnDirectional(){
491   vector<Node* > allNodes=getAllNodes();
492   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
493       ++NI) {
494     nodeList &nl = getNodeList(*NI);
495     for(nodeList::iterator NLI=nl.begin(), NLE=nl.end(); NLI!=NLE; ++NLI){
496       Edge ed(NLI->element, *NI, NLI->weight);
497       if(!hasEdgeAndWt(ed)){
498         DEBUG(std::cerr<<"######doesn't hv\n";
499               printEdge(ed));
500         addEdgeForce(ed);
501       }
502     }
503   }
504 }
505
506 //reverse the sign of weights on edges
507 //this way, max-spanning tree could be obtained
508 //using min-spanning tree, and vice versa
509 void Graph::reverseWts(){
510   vector<Node *> allNodes=getAllNodes();
511   for(vector<Node *>::iterator NI=allNodes.begin(), NE=allNodes.end(); NI!=NE; 
512       ++NI) {
513     nodeList &node_list = getNodeList(*NI);
514     for(nodeList::iterator NLI=nodes[*NI].begin(), NLE=nodes[*NI].end(); 
515         NLI!=NLE; ++NLI)
516       NLI->weight=-NLI->weight;
517   }
518 }
519
520
521 //getting the backedges in a graph
522 //Its a variation of DFS to get the backedges in the graph
523 //We get back edges by associating a time
524 //and a color with each vertex.
525 //The time of a vertex is the time when it was first visited
526 //The color of a vertex is initially WHITE,
527 //Changes to GREY when it is first visited,
528 //and changes to BLACK when ALL its neighbors
529 //have been visited
530 //So we have a back edge when we meet a successor of
531 //a node with smaller time, and GREY color
532 void Graph::getBackEdges(vector<Edge > &be, std::map<Node *, int> &d){
533   std::map<Node *, Color > color;
534   int time=0;
535
536   getBackEdgesVisit(getRoot(), be, color, d, time);
537 }
538
539 //helper function to get back edges: it is called by 
540 //the "getBackEdges" function above
541 void Graph::getBackEdgesVisit(Node *u, vector<Edge > &be,
542                               std::map<Node *, Color > &color,
543                               std::map<Node *, int > &d, int &time) {
544   color[u]=GREY;
545   time++;
546   d[u]=time;
547
548   vector<graphListElement> &succ_list = getNodeList(u);
549   
550   for(vector<graphListElement>::iterator vl=succ_list.begin(), 
551         ve=succ_list.end(); vl!=ve; ++vl){
552     Node *v=vl->element;
553     if(color[v]!=GREY && color[v]!=BLACK){
554       getBackEdgesVisit(v, be, color, d, time);
555     }
556     
557     //now checking for d and f vals
558     if(color[v]==GREY){
559       //so v is ancestor of u if time of u > time of v
560       if(d[u] >= d[v]){
561         Edge *ed=new Edge(u, v,vl->weight, vl->randId);
562         if (!(*u == *getExit() && *v == *getRoot()))
563           be.push_back(*ed);      // choose the forward edges
564       }
565     }
566   }
567   color[u]=BLACK;//done with visiting the node and its neighbors
568 }
569
570 } // End llvm namespace