[PBQP] Use DenseSet rather than std::set for PBQP's PoolCostAllocator
[oota-llvm.git] / include / llvm / CodeGen / PBQP / CostAllocator.h
index ff62c09593448f1a9f4a0c4f59b77a47c2c16d10..8c86a700cf663f6124638e34c02fbdc50aad9e15 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_COSTALLOCATOR_H
-#define LLVM_COSTALLOCATOR_H
+#ifndef LLVM_CODEGEN_PBQP_COSTALLOCATOR_H
+#define LLVM_CODEGEN_PBQP_COSTALLOCATOR_H
 
-#include <set>
+#include "llvm/ADT/DenseSet.h"
+#include <memory>
 #include <type_traits>
 
+namespace llvm {
 namespace PBQP {
 
-template <typename CostT,
-          typename CostKeyTComparator>
+template <typename CostT>
 class CostPool {
 public:
+  typedef std::shared_ptr<CostT> PoolRef;
 
-  class PoolEntry {
+private:
+
+  class PoolEntry : public std::enable_shared_from_this<PoolEntry> {
   public:
     template <typename CostKeyT>
     PoolEntry(CostPool &pool, CostKeyT cost)
-      : pool(pool), cost(std::move(cost)), refCount(0) {}
+        : pool(pool), cost(std::move(cost)) {}
     ~PoolEntry() { pool.removeEntry(this); }
-    void incRef() { ++refCount; }
-    bool decRef() { --refCount; return (refCount == 0); }
     CostT& getCost() { return cost; }
     const CostT& getCost() const { return cost; }
   private:
     CostPool &pool;
     CostT cost;
-    std::size_t refCount;
   };
 
-  class PoolRef {
+  class PoolEntryDSInfo {
   public:
-    PoolRef(PoolEntry *entry) : entry(entry) {
-      this->entry->incRef();
+    static inline PoolEntry* getEmptyKey() { return nullptr; }
+
+    static inline PoolEntry* getTombstoneKey() {
+      return reinterpret_cast<PoolEntry*>(static_cast<uintptr_t>(1));
     }
-    PoolRef(const PoolRef &r) {
-      entry = r.entry;
-      entry->incRef();
+
+    template <typename CostKeyT>
+    static unsigned getHashValue(const CostKeyT &C) {
+      return hash_value(C);
     }
-    PoolRef& operator=(const PoolRef &r) {
-      assert(entry != nullptr && "entry should not be null.");
-      PoolEntry *temp = r.entry;
-      temp->incRef();
-      entry->decRef();
-      entry = temp;
-      return *this;
+
+    static unsigned getHashValue(PoolEntry *P) {
+      return getHashValue(P->getCost());
     }
 
-    ~PoolRef() {
-      if (entry->decRef())
-        delete entry;
+    static unsigned getHashValue(const PoolEntry *P) {
+      return getHashValue(P->getCost());
     }
-    void reset(PoolEntry *entry) {
-      entry->incRef();
-      this->entry->decRef();
-      this->entry = entry;
+
+    template <typename CostKeyT1, typename CostKeyT2>
+    static
+    bool isEqual(const CostKeyT1 &C1, const CostKeyT2 &C2) {
+      return C1 == C2;
     }
-    CostT& operator*() { return entry->getCost(); }
-    const CostT& operator*() const { return entry->getCost(); }
-    CostT* operator->() { return &entry->getCost(); }
-    const CostT* operator->() const { return &entry->getCost(); }
-  private:
-    PoolEntry *entry;
-  };
 
-private:
-  class EntryComparator {
-  public:
     template <typename CostKeyT>
-    typename std::enable_if<
-               !std::is_same<PoolEntry*,
-                             typename std::remove_const<CostKeyT>::type>::value,
-               bool>::type
-    operator()(const PoolEntry* a, const CostKeyT &b) {
-      return compare(a->getCost(), b);
+    static bool isEqual(const CostKeyT &C, PoolEntry *P) {
+      if (P == getEmptyKey() || P == getTombstoneKey())
+        return false;
+      return isEqual(C, P->getCost());
     }
-    bool operator()(const PoolEntry* a, const PoolEntry* b) {
-      return compare(a->getCost(), b->getCost());
+
+    static bool isEqual(PoolEntry *P1, PoolEntry *P2) {
+      if (P1 == getEmptyKey() || P1 == getTombstoneKey())
+        return P1 == P2;
+      return isEqual(P1->getCost(), P2);
     }
-  private:
-    CostKeyTComparator compare;
+
   };
 
-  typedef std::set<PoolEntry*, EntryComparator> EntrySet;
+  typedef DenseSet<PoolEntry*, PoolEntryDSInfo> EntrySet;
 
   EntrySet entrySet;
 
   void removeEntry(PoolEntry *p) { entrySet.erase(p); }
 
 public:
+  template <typename CostKeyT> PoolRef getCost(CostKeyT costKey) {
+    typename EntrySet::iterator itr = entrySet.find_as(costKey);
 
-  template <typename CostKeyT>
-  PoolRef getCost(CostKeyT costKey) {
-    typename EntrySet::iterator itr =
-      std::lower_bound(entrySet.begin(), entrySet.end(), costKey,
-                       EntryComparator());
-
-    if (itr != entrySet.end() && costKey == (*itr)->getCost())
-      return PoolRef(*itr);
+    if (itr != entrySet.end())
+      return PoolRef((*itr)->shared_from_this(), &(*itr)->getCost());
 
-    PoolEntry *p = new PoolEntry(*this, std::move(costKey));
-    entrySet.insert(itr, p);
-    return PoolRef(p);
+    auto p = std::make_shared<PoolEntry>(*this, std::move(costKey));
+    entrySet.insert(p.get());
+    return PoolRef(std::move(p), &p->getCost());
   }
 };
 
-template <typename VectorT, typename VectorTComparator,
-          typename MatrixT, typename MatrixTComparator>
+template <typename VectorT, typename MatrixT>
 class PoolCostAllocator {
 private:
-  typedef CostPool<VectorT, VectorTComparator> VectorCostPool;
-  typedef CostPool<MatrixT, MatrixTComparator> MatrixCostPool;
+  typedef CostPool<VectorT> VectorCostPool;
+  typedef CostPool<MatrixT> MatrixCostPool;
 public:
   typedef VectorT Vector;
   typedef MatrixT Matrix;
@@ -142,6 +127,7 @@ private:
   MatrixCostPool matrixPool;
 };
 
-}
+} // namespace PBQP
+} // namespace llvm
 
-#endif // LLVM_COSTALLOCATOR_H
+#endif