7de1b3c52e632fef3c1490137eefe6f350dddeed
[oota-llvm.git] / lib / Transforms / Instrumentation / ProfilePaths / CombineBranch.cpp
1 //===-- CombineBranch.cpp -------------------------------------------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // Combine branches
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Support/CFG.h"
15 #include "llvm/iTerminators.h"
16 #include "llvm/iPHINode.h"
17 #include "llvm/Function.h"
18 #include "llvm/Pass.h"
19
20 namespace llvm {
21
22 namespace {
23   struct CombineBranches : public FunctionPass {
24   private:
25     /// Possible colors that a vertex can have during depth-first search for
26     /// back-edges.
27     ///
28     enum Color { WHITE, GREY, BLACK };
29
30     void getBackEdgesVisit(BasicBlock *u,
31                            std::map<BasicBlock *, Color > &color,
32                            std::map<BasicBlock *, int > &d, 
33                            int &time,
34                            std::map<BasicBlock *, BasicBlock *> &be);
35     void removeRedundant(std::map<BasicBlock *, BasicBlock *> &be);
36   public:
37     bool runOnFunction(Function &F);
38   };
39   
40   RegisterOpt<CombineBranches>
41   X("branch-combine", "Multiple backedges going to same target are merged");
42 }
43
44 /// getBackEdgesVisit - Get the back-edges of the control-flow graph for this
45 /// function.  We proceed recursively using depth-first search.  We get
46 /// back-edges by associating a time and a color with each vertex.  The time of a
47 /// vertex is the time when it was first visited.  The color of a vertex is
48 /// initially WHITE, changes to GREY when it is first visited, and changes to
49 /// BLACK when ALL its neighbors have been visited.  So we have a back edge when
50 /// we meet a successor of a node with smaller time, and GREY color.
51 ///
52 void CombineBranches::getBackEdgesVisit(BasicBlock *u,
53                        std::map<BasicBlock *, Color > &color,
54                        std::map<BasicBlock *, int > &d, 
55                        int &time,
56                        std::map<BasicBlock *, BasicBlock *> &be) {
57   
58   color[u]=GREY;
59   time++;
60   d[u]=time;
61
62   for (succ_iterator vl = succ_begin(u), ve = succ_end(u); vl != ve; ++vl){
63     BasicBlock *BB = *vl;
64
65     if(color[BB]!=GREY && color[BB]!=BLACK)
66       getBackEdgesVisit(BB, color, d, time, be);
67     
68     //now checking for d and f vals
69     else if(color[BB]==GREY){
70       //so v is ancestor of u if time of u > time of v
71       if(d[u] >= d[BB]) // u->BB is a backedge
72         be[u] = BB;
73     }
74   }
75   color[u]=BLACK;//done with visiting the node and its neighbors
76 }
77
78 /// removeRedundant - Remove all back-edges that are dominated by other
79 /// back-edges in the set.
80 ///
81 void CombineBranches::removeRedundant(std::map<BasicBlock *, BasicBlock *> &be){
82   std::vector<BasicBlock *> toDelete;
83   std::map<BasicBlock *, int> seenBB;
84   
85   for(std::map<BasicBlock *, BasicBlock *>::iterator MI = be.begin(), 
86         ME = be.end(); MI != ME; ++MI){
87     
88     if(seenBB[MI->second])
89       continue;
90     
91     seenBB[MI->second] = 1;
92
93     std::vector<BasicBlock *> sameTarget;
94     sameTarget.clear();
95     
96     for(std::map<BasicBlock *, BasicBlock *>::iterator MMI = be.begin(), 
97           MME = be.end(); MMI != MME; ++MMI){
98       
99       if(MMI->first == MI->first)
100         continue;
101       
102       if(MMI->second == MI->second)
103         sameTarget.push_back(MMI->first);
104       
105     }
106     
107     //so more than one branch to same target
108     if(sameTarget.size()){
109
110       sameTarget.push_back(MI->first);
111
112       BasicBlock *newBB = new BasicBlock("newCommon", MI->first->getParent());
113       BranchInst *newBranch = new BranchInst(MI->second, 0, 0, newBB);
114
115       std::map<PHINode *, std::vector<unsigned int> > phiMap;
116
117       for(std::vector<BasicBlock *>::iterator VBI = sameTarget.begin(),
118             VBE = sameTarget.end(); VBI != VBE; ++VBI){
119
120         BranchInst *ti = cast<BranchInst>((*VBI)->getTerminator());
121         unsigned char index = 1;
122         if(ti->getSuccessor(0) == MI->second)
123           index = 0;
124
125         ti->setSuccessor(index, newBB);
126
127         for(BasicBlock::iterator BB2Inst = MI->second->begin(), 
128               BBend = MI->second->end(); BB2Inst != BBend; ++BB2Inst){
129           
130           if (PHINode *phiInst = dyn_cast<PHINode>(BB2Inst)){
131             int bbIndex;
132             bbIndex = phiInst->getBasicBlockIndex(*VBI);
133             if(bbIndex>=0)
134               phiMap[phiInst].push_back(bbIndex);
135           }
136         }
137       }
138
139       for(std::map<PHINode *, std::vector<unsigned int> >::iterator
140             PI = phiMap.begin(), PE = phiMap.end(); PI != PE; ++PI){
141         
142         PHINode *phiNode = new PHINode(PI->first->getType(), "phi", newBranch);
143         for(std::vector<unsigned int>::iterator II = PI->second.begin(),
144               IE = PI->second.end(); II != IE; ++II){
145           phiNode->addIncoming(PI->first->getIncomingValue(*II),
146                                PI->first->getIncomingBlock(*II));
147         }
148
149         std::vector<BasicBlock *> tempBB;
150         for(std::vector<unsigned int>::iterator II = PI->second.begin(),
151               IE = PI->second.end(); II != IE; ++II){
152           tempBB.push_back(PI->first->getIncomingBlock(*II));
153         }
154
155         for(std::vector<BasicBlock *>::iterator II = tempBB.begin(),
156               IE = tempBB.end(); II != IE; ++II){
157           PI->first->removeIncomingValue(*II);
158         }
159
160         PI->first->addIncoming(phiNode, newBB);
161       }
162     }
163   }
164 }
165
166 /// runOnFunction - Per function pass for combining branches.
167 ///
168 bool CombineBranches::runOnFunction(Function &F){
169   if (F.isExternal ())
170     return false;
171
172   // Find and remove "redundant" back-edges.
173   std::map<BasicBlock *, Color> color;
174   std::map<BasicBlock *, int> d;
175   std::map<BasicBlock *, BasicBlock *> be;
176   int time = 0;
177   getBackEdgesVisit (F.begin (), color, d, time, be);
178   removeRedundant (be);
179   
180   return true; // FIXME: assumes a modification was always made.
181 }
182
183 } // End llvm namespace