Bug fix in operator==() and in method fini().
[oota-llvm.git] / include / llvm / ADT / SCCIterator.h
1 //===-- Support/TarjanSCCIterator.h -Generic Tarjan SCC iterator -*- C++ -*--=//
2 //
3 // This builds on the Support/GraphTraits.h file to find the strongly 
4 // connected components (SCCs) of a graph in O(N+E) time using
5 // Tarjan's DFS algorithm.
6 //
7 // The SCC iterator has the important property that if a node in SCC S1
8 // has an edge to a node in SCC S2, then it visits S1 *after* S2.
9 // 
10 // To visit S1 *before* S2, use the TarjanSCCIterator on the Inverse graph.
11 // (NOTE: This requires some simple wrappers and is not supported yet.)
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_SUPPORT_TARJANSCC_ITERATOR_H
15 #define LLVM_SUPPORT_TARJANSCC_ITERATOR_H
16
17 #include "Support/GraphTraits.h"
18 #include <Support/Statistic.h>
19 #include <Support/iterator>
20 #include <vector>
21 #include <stack>
22 #include <map>
23
24
25 //--------------------------------------------------------------------------
26 // class SCC : A simple representation of an SCC in a generic Graph.
27 //--------------------------------------------------------------------------
28
29 template<class GraphT, class GT = GraphTraits<GraphT> >
30 struct SCC: public std::vector<typename GT::NodeType*> {
31
32   typedef typename GT::NodeType NodeType;
33   typedef typename GT::ChildIteratorType ChildItTy;
34
35   typedef std::vector<typename GT::NodeType*> super;
36   typedef typename super::iterator               iterator;
37   typedef typename super::const_iterator         const_iterator;
38   typedef typename super::reverse_iterator       reverse_iterator;
39   typedef typename super::const_reverse_iterator const_reverse_iterator;
40
41   // HasLoop() -- Test if this SCC has a loop.  If it has more than one
42   // node, this is trivially true.  If not, it may still contain a loop
43   // if the node has an edge back to itself.
44   bool HasLoop() const {
45     if (size() > 1) return true;
46     NodeType* N = front();
47     for (ChildItTy CI=GT::child_begin(N), CE=GT::child_end(N); CI != CE; ++CI)
48       if (*CI == N)
49         return true;
50     return false;
51   }
52 };
53
54 //--------------------------------------------------------------------------
55 // class TarjanSCC_iterator: Enumerate the SCCs of a directed graph, in
56 // reverse topological order of the SCC DAG.
57 //--------------------------------------------------------------------------
58
59 namespace {
60   Statistic<> NumSCCs("NumSCCs", "Number of Strongly Connected Components");
61   Statistic<> MaxSCCSize("MaxSCCSize", "Size of largest Strongly Connected Component");
62 }
63
64 template<class GraphT, class GT = GraphTraits<GraphT> >
65 class TarjanSCC_iterator : public forward_iterator<SCC<GraphT, GT>, ptrdiff_t>
66 {
67   typedef SCC<GraphT, GT> SccTy;
68   typedef forward_iterator<SccTy, ptrdiff_t> super;
69   typedef typename super::reference reference;
70   typedef typename super::pointer pointer;
71   typedef typename GT::NodeType          NodeType;
72   typedef typename GT::ChildIteratorType ChildItTy;
73
74   // The visit counters used to detect when a complete SCC is on the stack.
75   // visitNum is the global counter.
76   // nodeVisitNumbers are per-node visit numbers, also used as DFS flags.
77   unsigned long visitNum;
78   std::map<NodeType *, unsigned long> nodeVisitNumbers;
79
80   // SCCNodeStack - Stack holding nodes of the SCC.
81   std::stack<NodeType *> SCCNodeStack;
82
83   // CurrentSCC - The current SCC, retrieved using operator*().
84   SccTy CurrentSCC;
85
86   // VisitStack - Used to maintain the ordering.  Top = current block
87   // First element is basic block pointer, second is the 'next child' to visit
88   std::stack<std::pair<NodeType *, ChildItTy> > VisitStack;
89
90   // MinVistNumStack - Stack holding the "min" values for each node in the DFS.
91   // This is used to track the minimum uplink values for all children of
92   // the corresponding node on the VisitStack.
93   std::stack<unsigned long> MinVisitNumStack;
94
95   // A single "visit" within the non-recursive DFS traversal.
96   void DFSVisitOne(NodeType* N) {
97     ++visitNum;                         // Global counter for the visit order
98     nodeVisitNumbers[N] = visitNum;
99     SCCNodeStack.push(N);
100     MinVisitNumStack.push(visitNum);
101     VisitStack.push(make_pair(N, GT::child_begin(N)));
102     DEBUG(std::cerr << "TarjanSCC: Node " << N <<
103           " : visitNum = " << visitNum << "\n");
104   }
105
106   // The stack-based DFS traversal; defined below.
107   void DFSVisitChildren() {
108     assert(!VisitStack.empty());
109     while (VisitStack.top().second != GT::child_end(VisitStack.top().first))
110       { // TOS has at least one more child so continue DFS
111         NodeType *childN = *VisitStack.top().second++;
112         if (nodeVisitNumbers.find(childN) == nodeVisitNumbers.end())
113           { // this node has never been seen
114             DFSVisitOne(childN);
115           }
116         else
117           {
118             unsigned long childNum = nodeVisitNumbers[childN];
119             if (MinVisitNumStack.top() > childNum)
120               MinVisitNumStack.top() = childNum;
121           }
122       }
123   }
124
125   // Compute the next SCC using the DFS traversal.
126   void GetNextSCC() {
127     assert(VisitStack.size() == MinVisitNumStack.size());
128     CurrentSCC.clear();                 // Prepare to compute the next SCC
129     while (! VisitStack.empty())
130       {
131         DFSVisitChildren();
132
133         assert(VisitStack.top().second==GT::child_end(VisitStack.top().first));
134         NodeType* visitingN = VisitStack.top().first;
135         unsigned long minVisitNum = MinVisitNumStack.top();
136         VisitStack.pop();
137         MinVisitNumStack.pop();
138         if (! MinVisitNumStack.empty() && MinVisitNumStack.top() > minVisitNum)
139           MinVisitNumStack.top() = minVisitNum;
140
141         DEBUG(std::cerr << "TarjanSCC: Popped node " << visitingN <<
142               " : minVisitNum = " << minVisitNum << "; Node visit num = " <<
143               nodeVisitNumbers[visitingN] << "\n");
144
145         if (minVisitNum == nodeVisitNumbers[visitingN])
146           { // A full SCC is on the SCCNodeStack!  It includes all nodes below
147             // visitingN on the stack.  Copy those nodes to CurrentSCC,
148             // reset their minVisit values, and return (this suspends
149             // the DFS traversal till the next ++).
150             do {
151               CurrentSCC.push_back(SCCNodeStack.top());
152               SCCNodeStack.pop();
153               nodeVisitNumbers[CurrentSCC.back()] = ~0UL; 
154             } while (CurrentSCC.back() != visitingN);
155
156             ++NumSCCs;
157             if (CurrentSCC.size() > MaxSCCSize) MaxSCCSize = CurrentSCC.size();
158             
159             return;
160           }
161       }
162   }
163
164   inline TarjanSCC_iterator(NodeType *entryN) : visitNum(0) {
165     DFSVisitOne(entryN);
166     GetNextSCC();
167   }
168   inline TarjanSCC_iterator() { /* End is when DFS stack is empty */ }
169
170 public:
171   typedef TarjanSCC_iterator<GraphT, GT> _Self;
172
173   // Provide static "constructors"...
174   static inline _Self begin(GraphT& G) { return _Self(GT::getEntryNode(G)); }
175   static inline _Self end  (GraphT& G) { return _Self(); }
176
177   // Direct loop termination test (I.fini() is more efficient than I == end())
178   inline bool fini() const {
179     assert(!CurrentSCC.empty() || VisitStack.empty());
180     return CurrentSCC.empty();
181   }
182
183   inline bool operator==(const _Self& x) const { 
184     return VisitStack == x.VisitStack && CurrentSCC == x.CurrentSCC;
185   }
186   inline bool operator!=(const _Self& x) const { return !operator==(x); }
187
188   // Iterator traversal: forward iteration only
189   inline _Self& operator++() {          // Preincrement
190     GetNextSCC();
191     return *this; 
192   }
193   inline _Self operator++(int) {        // Postincrement
194     _Self tmp = *this; ++*this; return tmp; 
195   }
196
197   // Retrieve a pointer to the current SCC.  Returns NULL when done.
198   inline const SccTy* operator*() const { 
199     assert(!CurrentSCC.empty() || VisitStack.empty());
200     return CurrentSCC.empty()? NULL : &CurrentSCC;
201   }
202   inline SccTy* operator*() { 
203     assert(!CurrentSCC.empty() || VisitStack.empty());
204     return CurrentSCC.empty()? NULL : &CurrentSCC;
205   }
206 };
207
208
209 // Global constructor for the Tarjan SCC iterator.  Use *I == NULL or I.fini()
210 // to test termination efficiently, instead of I == the "end" iterator.
211 template <class T>
212 TarjanSCC_iterator<T> tarj_begin(T G)
213 {
214   return TarjanSCC_iterator<T>::begin(G);
215 }
216
217 template <class T>
218 TarjanSCC_iterator<T> tarj_end(T G)
219 {
220   return TarjanSCC_iterator<T>::end(G);
221 }
222
223 //===----------------------------------------------------------------------===//
224
225 #endif