Add integrated reference counting
authorMaged Michael <magedmichael@fb.com>
Wed, 1 Nov 2017 14:47:41 +0000 (07:47 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Nov 2017 14:49:55 +0000 (07:49 -0700)
Summary:
Add support for reference counting integrated with the internal structures and operations of the hazard pointer library. The operations are wait-free.
The advantages of this approach over combining reference counting with hazard pointers externally are:
(1) A long list of linked objects that protected by one reference can all be reclaimed together instead of going through a potentially long series of alternating reclamation and calls to retire() for descendants.
(2) Support for iterative deletion as opposed to potential deep recursion of alternating calls to release reference count and object destructors.

Reviewed By: djwatson

Differential Revision: D6142066

fbshipit-source-id: 02bdfcbd5a2c2d5486d937bb2f9cfb6f192f5e1a

folly/experimental/hazptr/hazptr-impl.h
folly/experimental/hazptr/hazptr.h
folly/experimental/hazptr/test/HazptrTest.cpp

index a9f4c20..e287aec 100644 (file)
@@ -118,9 +118,7 @@ static_assert(
 struct hazptr_tc {
   hazptr_tc_entry entry_[HAZPTR_TC_SIZE];
   size_t count_;
-#ifndef NDEBUG
-  bool local_;
-#endif
+  bool local_; // for debug mode only
 
  public:
   hazptr_tc_entry& operator[](size_t i);
@@ -206,6 +204,63 @@ inline void hazptr_obj_base<T, D>::retire(hazptr_domain& domain, D deleter) {
   domain.objRetire(this);
 }
 
+/**
+ *  hazptr_obj_base_refcounted
+ */
+
+template <typename T, typename D>
+inline void hazptr_obj_base_refcounted<T, D>::retire(
+    hazptr_domain& domain,
+    D deleter) {
+  DEBUG_PRINT(this << " " << &domain);
+  deleter_ = std::move(deleter);
+  reclaim_ = [](hazptr_obj* p) {
+    auto hrobp = static_cast<hazptr_obj_base_refcounted*>(p);
+    if (hrobp->release_ref()) {
+      auto obj = static_cast<T*>(hrobp);
+      hrobp->deleter_(obj);
+    }
+  };
+  if (HAZPTR_PRIV &&
+      (HAZPTR_ONE_DOMAIN || (&domain == &default_hazptr_domain()))) {
+    if (hazptr_priv_try_retire(this)) {
+      return;
+    }
+  }
+  domain.objRetire(this);
+}
+
+template <typename T, typename D>
+inline void hazptr_obj_base_refcounted<T, D>::acquire_ref() {
+  DEBUG_PRINT(this);
+  auto oldval = refcount_.fetch_add(1);
+  DCHECK(oldval >= 0);
+}
+
+template <typename T, typename D>
+inline void hazptr_obj_base_refcounted<T, D>::acquire_ref_safe() {
+  DEBUG_PRINT(this);
+  auto oldval = refcount_.load(std::memory_order_acquire);
+  DCHECK(oldval >= 0);
+  refcount_.store(oldval + 1, std::memory_order_release);
+}
+
+template <typename T, typename D>
+inline bool hazptr_obj_base_refcounted<T, D>::release_ref() {
+  DEBUG_PRINT(this);
+  auto oldval = refcount_.load(std::memory_order_acquire);
+  if (oldval > 0) {
+    oldval = refcount_.fetch_sub(1);
+  } else {
+    if (kIsDebug) {
+      refcount_.store(-1);
+    }
+  }
+  DEBUG_PRINT(this << " " << oldval);
+  DCHECK(oldval >= 0);
+  return oldval == 0;
+}
+
 /**
  *  hazptr_rec
  */
@@ -481,10 +536,10 @@ FOLLY_ALWAYS_INLINE hazptr_local<M>::hazptr_local() {
       auto& tc = *ptc;
       auto count = tc.count();
       if (M <= count) {
-#ifndef NDEBUG
-        DCHECK(!tc.local_);
-        tc.local_ = true;
-#endif
+        if (kIsDebug) {
+          DCHECK(!tc.local_);
+          tc.local_ = true;
+        }
         // Fast path
         for (size_t i = 0; i < M; ++i) {
           auto hprec = tc[i].hprec_;
@@ -511,13 +566,13 @@ FOLLY_ALWAYS_INLINE hazptr_local<M>::hazptr_local() {
 template <size_t M>
 FOLLY_ALWAYS_INLINE hazptr_local<M>::~hazptr_local() {
   if (LIKELY(!need_destruct_)) {
-#ifndef NDEBUG
-    auto ptc = hazptr_tc_tls();
-    DCHECK(ptc != nullptr);
-    auto& tc = *ptc;
-    DCHECK(tc.local_);
-    tc.local_ = false;
-#endif
+    if (kIsDebug) {
+      auto ptc = hazptr_tc_tls();
+      DCHECK(ptc != nullptr);
+      auto& tc = *ptc;
+      DCHECK(tc.local_);
+      tc.local_ = false;
+    }
     return;
   }
   // Slow path
@@ -602,6 +657,7 @@ inline hazptr_domain::~hazptr_domain() {
     while (retired) {
       for (auto p = retired; p; p = next) {
         next = p->next_;
+        DEBUG_PRINT(this << " " << p << " " << p->reclaim_);
         (*(p->reclaim_))(p);
       }
       retired = retired_.exchange(nullptr);
@@ -866,9 +922,9 @@ inline void hazptr_tc_init() {
   auto& tc = tls_tc_data_;
   DEBUG_PRINT(&tc);
   tc.count_ = 0;
-#ifndef NDEBUG
-  tc.local_ = false;
-#endif
+  if (kIsDebug) {
+    tc.local_ = false;
+  }
 }
 
 inline void hazptr_tc_shutdown() {
index f1776e4..62d9651 100644 (file)
@@ -34,6 +34,12 @@ class hazptr_obj;
 template <typename T, typename Deleter>
 class hazptr_obj_base;
 
+/** hazptr_obj_base_refcounted:
+ *  Base template for reference counted objects protected by hazard pointers.
+ */
+template <typename T, typename Deleter>
+class hazptr_obj_base_refcounted;
+
 /** hazptr_local: Optimized template for bulk construction and destruction of
  *  hazard pointers */
 template <size_t M>
@@ -60,6 +66,8 @@ class hazptr_domain {
   friend class hazptr_holder;
   template <typename, typename>
   friend class hazptr_obj_base;
+  template <typename, typename>
+  friend class hazptr_obj_base_refcounted;
   friend struct hazptr_priv;
 
   memory_resource* mr_;
@@ -87,10 +95,13 @@ class hazptr_obj {
   friend class hazptr_domain;
   template <typename, typename>
   friend class hazptr_obj_base;
+  template <typename, typename>
+  friend class hazptr_obj_base_refcounted;
   friend struct hazptr_priv;
 
   void (*reclaim_)(hazptr_obj*);
   hazptr_obj* next_;
+
   const void* getObjPtr() const;
 };
 
@@ -106,6 +117,33 @@ class hazptr_obj_base : public hazptr_obj {
   D deleter_;
 };
 
+/** Definition of hazptr_recounted_obj_base */
+template <typename T, typename D = std::default_delete<T>>
+class hazptr_obj_base_refcounted : public hazptr_obj {
+ public:
+  /* Retire a removed object and pass the responsibility for
+   * reclaiming it to the hazptr library */
+  void retire(hazptr_domain& domain = default_hazptr_domain(), D reclaim = {});
+
+  /* aquire_ref() increments the reference count
+   *
+   * acquire_ref_safe() is the same as acquire_ref() except that in
+   * addition the caller guarantees that the call is made in a
+   * thread-safe context, e.g., the object is not yet shared. This is
+   * just an optimization to save an atomic operation.
+   *
+   * release_ref() decrements the reference count and returns true if
+   * the object is safe to reclaim.
+   */
+  void acquire_ref();
+  void acquire_ref_safe();
+  bool release_ref();
+
+ private:
+  std::atomic<uint32_t> refcount_{0};
+  D deleter_;
+};
+
 /** hazptr_holder: Class for automatic acquisition and release of
  *  hazard pointers, and interface for hazard pointer operations. */
 class hazptr_holder {
index 75d4669..9d71cda 100644 (file)
@@ -50,12 +50,16 @@ class HazptrTest : public testing::Test {
 TEST_F(HazptrTest, Test1) {
   DEBUG_PRINT("");
   Node1* node0 = (Node1*)malloc(sizeof(Node1));
-  DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
+  node0 = new (node0) Node1;
+  DEBUG_PRINT("=== malloc node0 " << node0 << " " << sizeof(*node0));
   Node1* node1 = (Node1*)malloc(sizeof(Node1));
+  node1 = new (node1) Node1;
   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
   Node1* node2 = (Node1*)malloc(sizeof(Node1));
+  node2 = new (node2) Node1;
   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
   Node1* node3 = (Node1*)malloc(sizeof(Node1));
+  node3 = new (node3) Node1;
   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
 
   DEBUG_PRINT("");
@@ -90,12 +94,12 @@ TEST_F(HazptrTest, Test1) {
   Node1* n2 = shared2.load();
   Node1* n3 = shared3.load();
 
-  if (hptr0.try_protect(n0, shared0)) {}
-  if (hptr1.try_protect(n1, shared1)) {}
+  CHECK(hptr0.try_protect(n0, shared0));
+  CHECK(hptr1.try_protect(n1, shared1));
   hptr1.reset();
   hptr1.reset(nullptr);
   hptr1.reset(n2);
-  if (hptr2.try_protect(n3, shared3)) {}
+  CHECK(hptr2.try_protect(n3, shared3));
   swap(hptr1, hptr2);
   hptr3.reset();
 
@@ -115,10 +119,13 @@ TEST_F(HazptrTest, Test2) {
   Node2* node0 = new Node2;
   DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
   Node2* node1 = (Node2*)malloc(sizeof(Node2));
+  node1 = new (node1) Node2;
   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
   Node2* node2 = (Node2*)malloc(sizeof(Node2));
+  node2 = new (node2) Node2;
   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
   Node2* node3 = (Node2*)malloc(sizeof(Node2));
+  node3 = new (node3) Node2;
   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
 
   DEBUG_PRINT("");
@@ -153,11 +160,11 @@ TEST_F(HazptrTest, Test2) {
   Node2* n2 = shared2.load();
   Node2* n3 = shared3.load();
 
-  if (hptr0.try_protect(n0, shared0)) {}
-  if (hptr1.try_protect(n1, shared1)) {}
+  CHECK(hptr0.try_protect(n0, shared0));
+  CHECK(hptr1.try_protect(n1, shared1));
   hptr1.reset();
   hptr1.reset(n2);
-  if (hptr2.try_protect(n3, shared3)) {}
+  CHECK(hptr2.try_protect(n3, shared3));
   swap(hptr1, hptr2);
   hptr3.reset();
 
@@ -185,7 +192,9 @@ TEST_F(HazptrTest, LIFO) {
         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
           s.push(j);
           T res;
-          while (!s.pop(res)) {}
+          while (!s.pop(res)) {
+            /* keep trying */
+          }
         }
       });
     }
@@ -394,3 +403,151 @@ TEST_F(HazptrTest, Local) {
     hazptr_local<HAZPTR_TC_SIZE + 1> h;
   }
 }
+
+/* Test ref counting */
+
+std::atomic<int> constructed;
+std::atomic<int> destroyed;
+
+struct Foo : hazptr_obj_base_refcounted<Foo> {
+  int val_;
+  bool marked_;
+  Foo* next_;
+  Foo(int v, Foo* n) : val_(v), marked_(false), next_(n) {
+    DEBUG_PRINT("");
+    ++constructed;
+  }
+  ~Foo() {
+    DEBUG_PRINT("");
+    ++destroyed;
+    if (marked_) {
+      return;
+    }
+    auto next = next_;
+    while (next) {
+      if (!next->release_ref()) {
+        return;
+      }
+      auto p = next;
+      next = p->next_;
+      p->marked_ = true;
+      delete p;
+    }
+  }
+};
+
+struct Dummy : hazptr_obj_base<Dummy> {};
+
+TEST_F(HazptrTest, basic_refcount) {
+  constructed.store(0);
+  destroyed.store(0);
+
+  Foo* p = nullptr;
+  int num = 20;
+  for (int i = 0; i < num; ++i) {
+    p = new Foo(i, p);
+    if (i & 1) {
+      p->acquire_ref_safe();
+    } else {
+      p->acquire_ref();
+    }
+  }
+  hazptr_holder hptr;
+  hptr.reset(p);
+  for (auto q = p->next_; q; q = q->next_) {
+    q->retire();
+  }
+  int v = num;
+  for (auto q = p; q; q = q->next_) {
+    CHECK_GT(v, 0);
+    --v;
+    CHECK_EQ(q->val_, v);
+  }
+  CHECK(!p->release_ref());
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), 0);
+  p->retire();
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), 0);
+  hptr.reset();
+
+  /* retire enough objects to guarantee reclamation of Foo objects */
+  for (int i = 0; i < 100; ++i) {
+    auto a = new Dummy;
+    a->retire();
+  }
+
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), num);
+}
+
+TEST_F(HazptrTest, mt_refcount) {
+  constructed.store(0);
+  destroyed.store(0);
+
+  std::atomic<bool> ready(false);
+  std::atomic<int> setHazptrs(0);
+  std::atomic<Foo*> head;
+
+  int num = 20;
+  int nthr = 10;
+  std::vector<std::thread> thr(nthr);
+  for (int i = 0; i < nthr; ++i) {
+    thr[i] = std::thread([&] {
+      while (!ready.load()) {
+        /* spin */
+      }
+      hazptr_holder hptr;
+      auto p = hptr.get_protected(head);
+      ++setHazptrs;
+      /* Concurrent with removal */
+      int v = num;
+      for (auto q = p; q; q = q->next_) {
+        CHECK_GT(v, 0);
+        --v;
+        CHECK_EQ(q->val_, v);
+      }
+      CHECK_EQ(v, 0);
+    });
+  }
+
+  Foo* p = nullptr;
+  for (int i = 0; i < num; ++i) {
+    p = new Foo(i, p);
+    p->acquire_ref_safe();
+  }
+  head.store(p);
+
+  ready.store(true);
+
+  while (setHazptrs.load() < nthr) {
+    /* spin */
+  }
+
+  /* this is concurrent with traversal by reader */
+  head.store(nullptr);
+  for (auto q = p; q; q = q->next_) {
+    q->retire();
+  }
+  DEBUG_PRINT("Foo should not be destroyed");
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), 0);
+
+  DEBUG_PRINT("Foo may be destroyed after releasing the last reference");
+  if (p->release_ref()) {
+    delete p;
+  }
+
+  /* retire enough objects to guarantee reclamation of Foo objects */
+  for (int i = 0; i < 100; ++i) {
+    auto a = new Dummy;
+    a->retire();
+  }
+
+  for (int i = 0; i < nthr; ++i) {
+    thr[i].join();
+  }
+
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), num);
+}