Converted MaximumSpanningTree algorithm to a generic template, this could go
[oota-llvm.git] / lib / Transforms / Instrumentation / MaximumSpanningTree.h
index 2343985f23d43d7d8539d804c8645acfaf3aae79..2951dbcea9a185a39caa0271ba235e341147582d 100644 (file)
@@ -7,43 +7,87 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This module privides means for calculating a maximum spanning tree for the
-// CFG of a function according to a given profile.
+// This module privides means for calculating a maximum spanning tree for a
+// given set of weighted edges. The type parameter T is the type of a node.
 //
 //===----------------------------------------------------------------------===//
 
 #ifndef LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H
 #define LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H
 
-#include "llvm/Analysis/ProfileInfo.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/ADT/EquivalenceClasses.h"
 #include <vector>
+#include <algorithm>
 
 namespace llvm {
-  class Function;
 
+  /// MaximumSpanningTree - A MST implementation.
+  /// The type parameter T determines the type of the nodes of the graph.
+  template <typename T>
   class MaximumSpanningTree {
-  public:
-    typedef std::vector<ProfileInfo::Edge> MaxSpanTree;
 
+    // A comparing class for comparing weighted edges.
+    template <typename CT>
+    struct EdgeWeightCompare {
+      bool operator()(typename MaximumSpanningTree<CT>::EdgeWeight X, 
+                      typename MaximumSpanningTree<CT>::EdgeWeight Y) const {
+        if (X.second > Y.second) return true;
+        if (X.second < Y.second) return false;
+        return false;
+      }
+    };
+
+  public:
+    typedef std::pair<const T*, const T*> Edge;
+    typedef std::pair<Edge, double> EdgeWeight;
+    typedef std::vector<EdgeWeight> EdgeWeights;
   protected:
+    typedef std::vector<Edge> MaxSpanTree;
+
     MaxSpanTree MST;
 
   public:
     static char ID; // Class identification, replacement for typeinfo
 
-    // MaxSpanTree() - Calculates a MST for a function according to a profile.
-    // If inverted is true, all the edges *not* in the MST are returned. As a
-    // special also all leaf edges of the MST are not included, this makes it
-    // easier for the OptimalEdgeProfileInstrumentation to use this MST to do
-    // an optimal profiling.
-    MaximumSpanningTree(std::vector<ProfileInfo::EdgeWeight>&);
-    virtual ~MaximumSpanningTree() {}
+    /// MaximumSpanningTree() - Takes a vector of weighted edges and returns a
+    /// spanning tree.
+    MaximumSpanningTree(EdgeWeights &EdgeVector) {
+
+      std::stable_sort(EdgeVector.begin(), EdgeVector.end(), EdgeWeightCompare<T>());
+
+      // Create spanning tree, Forest contains a special data structure
+      // that makes checking if two nodes are already in a common (sub-)tree
+      // fast and cheap.
+      EquivalenceClasses<const T*> Forest;
+      for (typename EdgeWeights::iterator EWi = EdgeVector.begin(),
+           EWe = EdgeVector.end(); EWi != EWe; ++EWi) {
+        Edge e = (*EWi).first;
+
+        Forest.insert(e.first);
+        Forest.insert(e.second);
+      }
+
+      // Iterate over the sorted edges, biggest first.
+      for (typename EdgeWeights::iterator EWi = EdgeVector.begin(),
+           EWe = EdgeVector.end(); EWi != EWe; ++EWi) {
+        Edge e = (*EWi).first;
+
+        if (Forest.findLeader(e.first) != Forest.findLeader(e.second)) {
+          Forest.unionSets(e.first, e.second);
+          // So we know now that the edge is not already in a subtree, so we push
+          // the edge to the MST.
+          MST.push_back(e);
+        }
+      }
+    }
 
-    virtual MaxSpanTree::iterator begin();
-    virtual MaxSpanTree::iterator end();
+    typename MaxSpanTree::iterator begin() {
+      return MST.begin();
+    }
 
-    virtual void dump();
+    typename MaxSpanTree::iterator end() {
+      return MST.end();
+    }
   };
 
 } // End llvm namespace