Revert of r212265.
[oota-llvm.git] / lib / Transforms / Instrumentation / MaximumSpanningTree.h
index fcfb3f5b9c58d35e5792ac1678d168993f6052e4..363539b2886f3998bbfa56b8c06baae79c85ffca 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 //
-// This module privides means for calculating a maximum spanning tree for the
-// CFG of a function according to a given profile.
+// This module provides means for calculating a maximum spanning tree for a
+// given set of weighted edges. The type parameter T is the type of a node.
 //
 //===----------------------------------------------------------------------===//
 
 #ifndef LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H
 #define LLVM_ANALYSIS_MAXIMUMSPANNINGTREE_H
 
-#include "llvm/Analysis/ProfileInfo.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/IR/BasicBlock.h"
+#include <algorithm>
 #include <vector>
 
 namespace llvm {
-  class Function;
 
+  /// MaximumSpanningTree - A MST implementation.
+  /// The type parameter T determines the type of the nodes of the graph.
+  template <typename T>
   class MaximumSpanningTree {
   public:
-    typedef std::vector<ProfileInfo::Edge> MaxSpanTree;
-
+    typedef std::pair<const T*, const T*> Edge;
+    typedef std::pair<Edge, double> EdgeWeight;
+    typedef std::vector<EdgeWeight> EdgeWeights;
   protected:
+    typedef std::vector<Edge> MaxSpanTree;
+
     MaxSpanTree MST;
 
+  private:
+    // A comparing class for comparing weighted edges.
+    struct EdgeWeightCompare {
+      static bool getBlockSize(const T *X) {
+        const BasicBlock *BB = dyn_cast_or_null<BasicBlock>(X);
+        return BB ? BB->size() : 0;
+      }
+
+      bool operator()(EdgeWeight X, EdgeWeight Y) const {
+        if (X.second > Y.second) return true;
+        if (X.second < Y.second) return false;
+
+        // Equal edge weights: break ties by comparing block sizes.
+        size_t XSizeA = getBlockSize(X.first.first);
+        size_t YSizeA = getBlockSize(Y.first.first);
+        if (XSizeA > YSizeA) return true;
+        if (XSizeA < YSizeA) return false;
+
+        size_t XSizeB = getBlockSize(X.first.second);
+        size_t YSizeB = getBlockSize(Y.first.second);
+        if (XSizeB > YSizeB) return true;
+        if (XSizeB < YSizeB) return false;
+
+        return false;
+      }
+    };
+
   public:
     static char ID; // Class identification, replacement for typeinfo
 
-    // MaxSpanTree() - Calculates a MST for a function according to a profile.
-    // If inverted is true, all the edges *not* in the MST are returned. As a
-    // special also all leaf edges of the MST are not included, this makes it
-    // easier for the OptimalEdgeProfileInstrumentation to use this MST to do
-    // an optimal profiling.
-    MaximumSpanningTree(Function *F, ProfileInfo *PI, bool invert);
+    /// MaximumSpanningTree() - Takes a vector of weighted edges and returns a
+    /// spanning tree.
+    MaximumSpanningTree(EdgeWeights &EdgeVector) {
+
+      std::stable_sort(EdgeVector.begin(), EdgeVector.end(), EdgeWeightCompare());
+
+      // Create spanning tree, Forest contains a special data structure
+      // that makes checking if two nodes are already in a common (sub-)tree
+      // fast and cheap.
+      EquivalenceClasses<const T*> Forest;
+      for (typename EdgeWeights::iterator EWi = EdgeVector.begin(),
+           EWe = EdgeVector.end(); EWi != EWe; ++EWi) {
+        Edge e = (*EWi).first;
+
+        Forest.insert(e.first);
+        Forest.insert(e.second);
+      }
+
+      // Iterate over the sorted edges, biggest first.
+      for (typename EdgeWeights::iterator EWi = EdgeVector.begin(),
+           EWe = EdgeVector.end(); EWi != EWe; ++EWi) {
+        Edge e = (*EWi).first;
+
+        if (Forest.findLeader(e.first) != Forest.findLeader(e.second)) {
+          Forest.unionSets(e.first, e.second);
+          // So we know now that the edge is not already in a subtree, so we push
+          // the edge to the MST.
+          MST.push_back(e);
+        }
+      }
+    }
 
-    virtual MaxSpanTree::iterator begin();
-    virtual MaxSpanTree::iterator end();
+    typename MaxSpanTree::iterator begin() {
+      return MST.begin();
+    }
 
-    virtual void dump();
+    typename MaxSpanTree::iterator end() {
+      return MST.end();
+    }
   };
 
 } // End llvm namespace