Add free-function retire
[folly.git] / folly / experimental / hazptr / test / HazptrTest.cpp
index 301aaa805d53cf53f75ac4497d674f6af0452560..89ddc023bea2a75361c0a6d2b62d7eff60bb1cbb 100644 (file)
@@ -50,12 +50,16 @@ class HazptrTest : public testing::Test {
 TEST_F(HazptrTest, Test1) {
   DEBUG_PRINT("");
   Node1* node0 = (Node1*)malloc(sizeof(Node1));
-  DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
+  node0 = new (node0) Node1;
+  DEBUG_PRINT("=== malloc node0 " << node0 << " " << sizeof(*node0));
   Node1* node1 = (Node1*)malloc(sizeof(Node1));
+  node1 = new (node1) Node1;
   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
   Node1* node2 = (Node1*)malloc(sizeof(Node1));
+  node2 = new (node2) Node1;
   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
   Node1* node3 = (Node1*)malloc(sizeof(Node1));
+  node3 = new (node3) Node1;
   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
 
   DEBUG_PRINT("");
@@ -90,12 +94,12 @@ TEST_F(HazptrTest, Test1) {
   Node1* n2 = shared2.load();
   Node1* n3 = shared3.load();
 
-  if (hptr0.try_protect(n0, shared0)) {}
-  if (hptr1.try_protect(n1, shared1)) {}
+  CHECK(hptr0.try_protect(n0, shared0));
+  CHECK(hptr1.try_protect(n1, shared1));
   hptr1.reset();
   hptr1.reset(nullptr);
   hptr1.reset(n2);
-  if (hptr2.try_protect(n3, shared3)) {}
+  CHECK(hptr2.try_protect(n3, shared3));
   swap(hptr1, hptr2);
   hptr3.reset();
 
@@ -115,10 +119,13 @@ TEST_F(HazptrTest, Test2) {
   Node2* node0 = new Node2;
   DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
   Node2* node1 = (Node2*)malloc(sizeof(Node2));
+  node1 = new (node1) Node2;
   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
   Node2* node2 = (Node2*)malloc(sizeof(Node2));
+  node2 = new (node2) Node2;
   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
   Node2* node3 = (Node2*)malloc(sizeof(Node2));
+  node3 = new (node3) Node2;
   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
 
   DEBUG_PRINT("");
@@ -153,11 +160,11 @@ TEST_F(HazptrTest, Test2) {
   Node2* n2 = shared2.load();
   Node2* n3 = shared3.load();
 
-  if (hptr0.try_protect(n0, shared0)) {}
-  if (hptr1.try_protect(n1, shared1)) {}
+  CHECK(hptr0.try_protect(n0, shared0));
+  CHECK(hptr1.try_protect(n1, shared1));
   hptr1.reset();
   hptr1.reset(n2);
-  if (hptr2.try_protect(n3, shared3)) {}
+  CHECK(hptr2.try_protect(n3, shared3));
   swap(hptr1, hptr2);
   hptr3.reset();
 
@@ -185,7 +192,9 @@ TEST_F(HazptrTest, LIFO) {
         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
           s.push(j);
           T res;
-          while (!s.pop(res)) {}
+          while (!s.pop(res)) {
+            /* keep trying */
+          }
         }
       });
     }
@@ -198,12 +207,11 @@ TEST_F(HazptrTest, LIFO) {
 
 TEST_F(HazptrTest, SWMRLIST) {
   using T = uint64_t;
-  hazptr_domain custom_domain;
 
   CHECK_GT(FLAGS_num_threads, 0);
   for (int i = 0; i < FLAGS_num_reps; ++i) {
     DEBUG_PRINT("========== start of rep scope");
-    SWMRListSet<T> s(custom_domain);
+    SWMRListSet<T> s;
     std::vector<std::thread> threads(FLAGS_num_threads);
     for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
       threads[tid] = std::thread([&s, tid]() {
@@ -349,3 +357,219 @@ TEST_F(HazptrTest, Move) {
     hptr2.reset();
   }
 }
+
+TEST_F(HazptrTest, Array) {
+  struct Foo : hazptr_obj_base<Foo> {
+    int a;
+  };
+  for (int i = 0; i < 100; ++i) {
+    Foo* x = new Foo;
+    x->a = i;
+    hazptr_array<10> hptr;
+    // Protect object
+    hptr[9].reset(x);
+    // Empty array
+    hazptr_array<10> h(nullptr);
+    // Move assignment
+    h = std::move(hptr);
+    // Retire object
+    x->retire();
+    // Unprotect object - hptr2 is nonempty
+    h[9].reset();
+  }
+  {
+    // Abnormal case
+    hazptr_array<HAZPTR_TC_SIZE + 1> h;
+  }
+}
+
+TEST_F(HazptrTest, Local) {
+  struct Foo : hazptr_obj_base<Foo> {
+    int a;
+  };
+  for (int i = 0; i < 100; ++i) {
+    Foo* x = new Foo;
+    x->a = i;
+    hazptr_local<10> hptr;
+    // Protect object
+    hptr[9].reset(x);
+    // Retire object
+    x->retire();
+    // Unprotect object - hptr2 is nonempty
+    hptr[9].reset();
+  }
+  {
+    // Abnormal case
+    hazptr_local<HAZPTR_TC_SIZE + 1> h;
+  }
+}
+
+/* Test ref counting */
+
+std::atomic<int> constructed;
+std::atomic<int> destroyed;
+
+struct Foo : hazptr_obj_base_refcounted<Foo> {
+  int val_;
+  bool marked_;
+  Foo* next_;
+  Foo(int v, Foo* n) : val_(v), marked_(false), next_(n) {
+    DEBUG_PRINT("");
+    ++constructed;
+  }
+  ~Foo() {
+    DEBUG_PRINT("");
+    ++destroyed;
+    if (marked_) {
+      return;
+    }
+    auto next = next_;
+    while (next) {
+      if (!next->release_ref()) {
+        return;
+      }
+      auto p = next;
+      next = p->next_;
+      p->marked_ = true;
+      delete p;
+    }
+  }
+};
+
+struct Dummy : hazptr_obj_base<Dummy> {};
+
+TEST_F(HazptrTest, basic_refcount) {
+  constructed.store(0);
+  destroyed.store(0);
+
+  Foo* p = nullptr;
+  int num = 20;
+  for (int i = 0; i < num; ++i) {
+    p = new Foo(i, p);
+    if (i & 1) {
+      p->acquire_ref_safe();
+    } else {
+      p->acquire_ref();
+    }
+  }
+  hazptr_holder hptr;
+  hptr.reset(p);
+  for (auto q = p->next_; q; q = q->next_) {
+    q->retire();
+  }
+  int v = num;
+  for (auto q = p; q; q = q->next_) {
+    CHECK_GT(v, 0);
+    --v;
+    CHECK_EQ(q->val_, v);
+  }
+  CHECK(!p->release_ref());
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), 0);
+  p->retire();
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), 0);
+  hptr.reset();
+
+  /* retire enough objects to guarantee reclamation of Foo objects */
+  for (int i = 0; i < 100; ++i) {
+    auto a = new Dummy;
+    a->retire();
+  }
+
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), num);
+}
+
+TEST_F(HazptrTest, mt_refcount) {
+  constructed.store(0);
+  destroyed.store(0);
+
+  std::atomic<bool> ready(false);
+  std::atomic<int> setHazptrs(0);
+  std::atomic<Foo*> head;
+
+  int num = 20;
+  int nthr = 10;
+  std::vector<std::thread> thr(nthr);
+  for (int i = 0; i < nthr; ++i) {
+    thr[i] = std::thread([&] {
+      while (!ready.load()) {
+        /* spin */
+      }
+      hazptr_holder hptr;
+      auto p = hptr.get_protected(head);
+      ++setHazptrs;
+      /* Concurrent with removal */
+      int v = num;
+      for (auto q = p; q; q = q->next_) {
+        CHECK_GT(v, 0);
+        --v;
+        CHECK_EQ(q->val_, v);
+      }
+      CHECK_EQ(v, 0);
+    });
+  }
+
+  Foo* p = nullptr;
+  for (int i = 0; i < num; ++i) {
+    p = new Foo(i, p);
+    p->acquire_ref_safe();
+  }
+  head.store(p);
+
+  ready.store(true);
+
+  while (setHazptrs.load() < nthr) {
+    /* spin */
+  }
+
+  /* this is concurrent with traversal by reader */
+  head.store(nullptr);
+  for (auto q = p; q; q = q->next_) {
+    q->retire();
+  }
+  DEBUG_PRINT("Foo should not be destroyed");
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), 0);
+
+  DEBUG_PRINT("Foo may be destroyed after releasing the last reference");
+  if (p->release_ref()) {
+    delete p;
+  }
+
+  /* retire enough objects to guarantee reclamation of Foo objects */
+  for (int i = 0; i < 100; ++i) {
+    auto a = new Dummy;
+    a->retire();
+  }
+
+  for (int i = 0; i < nthr; ++i) {
+    thr[i].join();
+  }
+
+  CHECK_EQ(constructed.load(), num);
+  CHECK_EQ(destroyed.load(), num);
+}
+
+TEST_F(HazptrTest, FreeFunctionRetire) {
+  auto foo = new int;
+  hazptr_retire(foo);
+  auto foo2 = new int;
+  hazptr_retire(foo2, [](int* obj) { delete obj; });
+
+  bool retired = false;
+  {
+    hazptr_domain myDomain0;
+    struct delret {
+      bool* retired_;
+      delret(bool* retire) : retired_(retire) {}
+      ~delret() {
+        *retired_ = true;
+      }
+    };
+    auto foo3 = new delret(&retired);
+    myDomain0.retire(foo3);
+  }
+  EXPECT_TRUE(retired);
+}