Adds writer test case for RCU
[folly.git] / folly / experimental / hazptr / test / HazptrTest.cpp
1 /*
2  * Copyright 2016-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define HAZPTR_DEBUG true
17 #define HAZPTR_STATS true
18 #define HAZPTR_SCAN_THRESHOLD 10
19
20 #include <folly/experimental/hazptr/debug.h>
21 #include <folly/experimental/hazptr/example/LockFreeLIFO.h>
22 #include <folly/experimental/hazptr/example/MWMRSet.h>
23 #include <folly/experimental/hazptr/example/SWMRList.h>
24 #include <folly/experimental/hazptr/example/WideCAS.h>
25 #include <folly/experimental/hazptr/hazptr.h>
26 #include <folly/experimental/hazptr/test/HazptrUse1.h>
27 #include <folly/experimental/hazptr/test/HazptrUse2.h>
28
29 #include <folly/portability/GFlags.h>
30 #include <folly/portability/GTest.h>
31
32 #include <thread>
33
34 DEFINE_int32(num_threads, 5, "Number of threads");
35 DEFINE_int64(num_reps, 1, "Number of test reps");
36 DEFINE_int64(num_ops, 10, "Number of ops or pairs of ops per rep");
37
38 using namespace folly::hazptr;
39
40 class HazptrTest : public testing::Test {
41  public:
42   HazptrTest() : Test() {
43     DEBUG_PRINT("========== start of test scope");
44   }
45   ~HazptrTest() override {
46     DEBUG_PRINT("========== end of test scope");
47   }
48 };
49
50 TEST_F(HazptrTest, Test1) {
51   DEBUG_PRINT("");
52   Node1* node0 = (Node1*)malloc(sizeof(Node1));
53   node0 = new (node0) Node1;
54   DEBUG_PRINT("=== malloc node0 " << node0 << " " << sizeof(*node0));
55   Node1* node1 = (Node1*)malloc(sizeof(Node1));
56   node1 = new (node1) Node1;
57   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
58   Node1* node2 = (Node1*)malloc(sizeof(Node1));
59   node2 = new (node2) Node1;
60   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
61   Node1* node3 = (Node1*)malloc(sizeof(Node1));
62   node3 = new (node3) Node1;
63   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
64
65   DEBUG_PRINT("");
66
67   std::atomic<Node1*> shared0 = {node0};
68   std::atomic<Node1*> shared1 = {node1};
69   std::atomic<Node1*> shared2 = {node2};
70   std::atomic<Node1*> shared3 = {node3};
71
72   MyMemoryResource myMr;
73   DEBUG_PRINT("=== myMr " << &myMr);
74   hazptr_domain myDomain0;
75   DEBUG_PRINT("=== myDomain0 " << &myDomain0);
76   hazptr_domain myDomain1(&myMr);
77   DEBUG_PRINT("=== myDomain1 " << &myDomain1);
78
79   DEBUG_PRINT("");
80
81   DEBUG_PRINT("=== hptr0");
82   hazptr_holder hptr0;
83   DEBUG_PRINT("=== hptr1");
84   hazptr_holder hptr1(myDomain0);
85   DEBUG_PRINT("=== hptr2");
86   hazptr_holder hptr2(myDomain1);
87   DEBUG_PRINT("=== hptr3");
88   hazptr_holder hptr3;
89
90   DEBUG_PRINT("");
91
92   Node1* n0 = shared0.load();
93   Node1* n1 = shared1.load();
94   Node1* n2 = shared2.load();
95   Node1* n3 = shared3.load();
96
97   CHECK(hptr0.try_protect(n0, shared0));
98   CHECK(hptr1.try_protect(n1, shared1));
99   hptr1.reset();
100   hptr1.reset(nullptr);
101   hptr1.reset(n2);
102   CHECK(hptr2.try_protect(n3, shared3));
103   swap(hptr1, hptr2);
104   hptr3.reset();
105
106   DEBUG_PRINT("");
107
108   DEBUG_PRINT("=== retire n0 " << n0);
109   n0->retire();
110   DEBUG_PRINT("=== retire n1 " << n1);
111   n1->retire(default_hazptr_domain());
112   DEBUG_PRINT("=== retire n2 " << n2);
113   n2->retire(myDomain0);
114   DEBUG_PRINT("=== retire n3 " << n3);
115   n3->retire(myDomain1);
116 }
117
118 TEST_F(HazptrTest, Test2) {
119   Node2* node0 = new Node2;
120   DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
121   Node2* node1 = (Node2*)malloc(sizeof(Node2));
122   node1 = new (node1) Node2;
123   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
124   Node2* node2 = (Node2*)malloc(sizeof(Node2));
125   node2 = new (node2) Node2;
126   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
127   Node2* node3 = (Node2*)malloc(sizeof(Node2));
128   node3 = new (node3) Node2;
129   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
130
131   DEBUG_PRINT("");
132
133   std::atomic<Node2*> shared0 = {node0};
134   std::atomic<Node2*> shared1 = {node1};
135   std::atomic<Node2*> shared2 = {node2};
136   std::atomic<Node2*> shared3 = {node3};
137
138   MineMemoryResource mineMr;
139   DEBUG_PRINT("=== mineMr " << &mineMr);
140   hazptr_domain mineDomain0;
141   DEBUG_PRINT("=== mineDomain0 " << &mineDomain0);
142   hazptr_domain mineDomain1(&mineMr);
143   DEBUG_PRINT("=== mineDomain1 " << &mineDomain1);
144
145   DEBUG_PRINT("");
146
147   DEBUG_PRINT("=== hptr0");
148   hazptr_holder hptr0;
149   DEBUG_PRINT("=== hptr1");
150   hazptr_holder hptr1(mineDomain0);
151   DEBUG_PRINT("=== hptr2");
152   hazptr_holder hptr2(mineDomain1);
153   DEBUG_PRINT("=== hptr3");
154   hazptr_holder hptr3;
155
156   DEBUG_PRINT("");
157
158   Node2* n0 = shared0.load();
159   Node2* n1 = shared1.load();
160   Node2* n2 = shared2.load();
161   Node2* n3 = shared3.load();
162
163   CHECK(hptr0.try_protect(n0, shared0));
164   CHECK(hptr1.try_protect(n1, shared1));
165   hptr1.reset();
166   hptr1.reset(n2);
167   CHECK(hptr2.try_protect(n3, shared3));
168   swap(hptr1, hptr2);
169   hptr3.reset();
170
171   DEBUG_PRINT("");
172
173   DEBUG_PRINT("=== retire n0 " << n0);
174   n0->retire(default_hazptr_domain(), &mineReclaimFnDelete);
175   DEBUG_PRINT("=== retire n1 " << n1);
176   n1->retire(default_hazptr_domain(), &mineReclaimFnFree);
177   DEBUG_PRINT("=== retire n2 " << n2);
178   n2->retire(mineDomain0, &mineReclaimFnFree);
179   DEBUG_PRINT("=== retire n3 " << n3);
180   n3->retire(mineDomain1, &mineReclaimFnFree);
181 }
182
183 TEST_F(HazptrTest, LIFO) {
184   using T = uint32_t;
185   CHECK_GT(FLAGS_num_threads, 0);
186   for (int i = 0; i < FLAGS_num_reps; ++i) {
187     DEBUG_PRINT("========== start of rep scope");
188     LockFreeLIFO<T> s;
189     std::vector<std::thread> threads(FLAGS_num_threads);
190     for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
191       threads[tid] = std::thread([&s, tid]() {
192         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
193           s.push(j);
194           T res;
195           while (!s.pop(res)) {
196             /* keep trying */
197           }
198         }
199       });
200     }
201     for (auto& t : threads) {
202       t.join();
203     }
204     DEBUG_PRINT("========== end of rep scope");
205   }
206 }
207
208 TEST_F(HazptrTest, SWMRLIST) {
209   using T = uint64_t;
210
211   CHECK_GT(FLAGS_num_threads, 0);
212   for (int i = 0; i < FLAGS_num_reps; ++i) {
213     DEBUG_PRINT("========== start of rep scope");
214     SWMRListSet<T> s;
215     std::vector<std::thread> threads(FLAGS_num_threads);
216     for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
217       threads[tid] = std::thread([&s, tid]() {
218         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
219           s.contains(j);
220         }
221       });
222     }
223     for (int j = 0; j < 10; ++j) {
224       s.add(j);
225     }
226     for (int j = 0; j < 10; ++j) {
227       s.remove(j);
228     }
229     for (auto& t : threads) {
230       t.join();
231     }
232     DEBUG_PRINT("========== end of rep scope");
233   }
234 }
235
236 TEST_F(HazptrTest, MWMRSet) {
237   using T = uint64_t;
238
239   CHECK_GT(FLAGS_num_threads, 0);
240   for (int i = 0; i < FLAGS_num_reps; ++i) {
241     DEBUG_PRINT("========== start of rep scope");
242     MWMRListSet<T> s;
243     std::vector<std::thread> threads(FLAGS_num_threads);
244     for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
245       threads[tid] = std::thread([&s, tid]() {
246         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
247           s.contains(j);
248           s.add(j);
249           s.remove(j);
250         }
251       });
252     }
253     for (int j = 0; j < 10; ++j) {
254       s.add(j);
255     }
256     for (int j = 0; j < 10; ++j) {
257       s.remove(j);
258     }
259     for (auto& t : threads) {
260       t.join();
261     }
262     DEBUG_PRINT("========== end of rep scope");
263   }
264 }
265
266 TEST_F(HazptrTest, WIDECAS) {
267   WideCAS s;
268   std::string u = "";
269   std::string v = "11112222";
270   auto ret = s.cas(u, v);
271   CHECK(ret);
272   u = "";
273   v = "11112222";
274   ret = s.cas(u, v);
275   CHECK(!ret);
276   u = "11112222";
277   v = "22223333";
278   ret = s.cas(u, v);
279   CHECK(ret);
280   u = "22223333";
281   v = "333344445555";
282   ret = s.cas(u, v);
283   CHECK(ret);
284 }
285
286 TEST_F(HazptrTest, VirtualTest) {
287   struct Thing : public hazptr_obj_base<Thing> {
288     virtual ~Thing() {
289       DEBUG_PRINT("this: " << this << " &a: " << &a << " a: " << a);
290     }
291     int a;
292   };
293   for (int i = 0; i < 100; i++) {
294     auto bar = new Thing;
295     bar->a = i;
296
297     hazptr_holder hptr;
298     hptr.reset(bar);
299     bar->retire();
300     EXPECT_EQ(bar->a, i);
301   }
302 }
303
304 void destructionTest(hazptr_domain& domain) {
305   struct Thing : public hazptr_obj_base<Thing> {
306     Thing* next;
307     hazptr_domain* domain;
308     int val;
309     Thing(int v, Thing* n, hazptr_domain* d) : next(n), domain(d), val(v) {}
310     ~Thing() {
311       DEBUG_PRINT("this: " << this << " val: " << val << " next: " << next);
312       if (next) {
313         next->retire(*domain);
314       }
315     }
316   };
317   Thing* last{nullptr};
318   for (int i = 0; i < 2000; i++) {
319     last = new Thing(i, last, &domain);
320   }
321   last->retire(domain);
322 }
323
324 TEST_F(HazptrTest, DestructionTest) {
325   {
326     hazptr_domain myDomain0;
327     destructionTest(myDomain0);
328   }
329   destructionTest(default_hazptr_domain());
330 }
331
332 TEST_F(HazptrTest, Move) {
333   struct Foo : hazptr_obj_base<Foo> {
334     int a;
335   };
336   for (int i = 0; i < 100; ++i) {
337     Foo* x = new Foo;
338     x->a = i;
339     hazptr_holder hptr0;
340     // Protect object
341     hptr0.reset(x);
342     // Retire object
343     x->retire();
344     // Move constructor - still protected
345     hazptr_holder hptr1(std::move(hptr0));
346     // Self move is no-op - still protected
347     hazptr_holder* phptr1 = &hptr1;
348     CHECK_EQ(phptr1, &hptr1);
349     hptr1 = std::move(*phptr1);
350     // Empty constructor
351     hazptr_holder hptr2(nullptr);
352     // Move assignment - still protected
353     hptr2 = std::move(hptr1);
354     // Access object
355     CHECK_EQ(x->a, i);
356     // Unprotect object - hptr2 is nonempty
357     hptr2.reset();
358   }
359 }
360
361 TEST_F(HazptrTest, Array) {
362   struct Foo : hazptr_obj_base<Foo> {
363     int a;
364   };
365   for (int i = 0; i < 100; ++i) {
366     Foo* x = new Foo;
367     x->a = i;
368     hazptr_array<10> hptr;
369     // Protect object
370     hptr[9].reset(x);
371     // Empty array
372     hazptr_array<10> h(nullptr);
373     // Move assignment
374     h = std::move(hptr);
375     // Retire object
376     x->retire();
377     // Unprotect object - hptr2 is nonempty
378     h[9].reset();
379   }
380   {
381     // Abnormal case
382     hazptr_array<HAZPTR_TC_SIZE + 1> h;
383     hazptr_array<HAZPTR_TC_SIZE + 1> h2(std::move(h));
384   }
385 }
386
387 TEST_F(HazptrTest, Local) {
388   struct Foo : hazptr_obj_base<Foo> {
389     int a;
390   };
391   for (int i = 0; i < 100; ++i) {
392     Foo* x = new Foo;
393     x->a = i;
394     hazptr_local<10> hptr;
395     // Protect object
396     hptr[9].reset(x);
397     // Retire object
398     x->retire();
399     // Unprotect object - hptr2 is nonempty
400     hptr[9].reset();
401   }
402   {
403     // Abnormal case
404     hazptr_local<HAZPTR_TC_SIZE + 1> h;
405   }
406 }
407
408 /* Test ref counting */
409
410 std::atomic<int> constructed;
411 std::atomic<int> destroyed;
412
413 struct Foo : hazptr_obj_base_refcounted<Foo> {
414   int val_;
415   bool marked_;
416   Foo* next_;
417   Foo(int v, Foo* n) : val_(v), marked_(false), next_(n) {
418     DEBUG_PRINT("");
419     ++constructed;
420   }
421   ~Foo() {
422     DEBUG_PRINT("");
423     ++destroyed;
424     if (marked_) {
425       return;
426     }
427     auto next = next_;
428     while (next) {
429       if (!next->release_ref()) {
430         return;
431       }
432       auto p = next;
433       next = p->next_;
434       p->marked_ = true;
435       delete p;
436     }
437   }
438 };
439
440 struct Dummy : hazptr_obj_base<Dummy> {};
441
442 TEST_F(HazptrTest, basic_refcount) {
443   constructed.store(0);
444   destroyed.store(0);
445
446   Foo* p = nullptr;
447   int num = 20;
448   for (int i = 0; i < num; ++i) {
449     p = new Foo(i, p);
450     if (i & 1) {
451       p->acquire_ref_safe();
452     } else {
453       p->acquire_ref();
454     }
455   }
456   hazptr_holder hptr;
457   hptr.reset(p);
458   for (auto q = p->next_; q; q = q->next_) {
459     q->retire();
460   }
461   int v = num;
462   for (auto q = p; q; q = q->next_) {
463     CHECK_GT(v, 0);
464     --v;
465     CHECK_EQ(q->val_, v);
466   }
467   CHECK(!p->release_ref());
468   CHECK_EQ(constructed.load(), num);
469   CHECK_EQ(destroyed.load(), 0);
470   p->retire();
471   CHECK_EQ(constructed.load(), num);
472   CHECK_EQ(destroyed.load(), 0);
473   hptr.reset();
474
475   /* retire enough objects to guarantee reclamation of Foo objects */
476   for (int i = 0; i < 100; ++i) {
477     auto a = new Dummy;
478     a->retire();
479   }
480
481   CHECK_EQ(constructed.load(), num);
482   CHECK_EQ(destroyed.load(), num);
483 }
484
485 TEST_F(HazptrTest, mt_refcount) {
486   constructed.store(0);
487   destroyed.store(0);
488
489   std::atomic<bool> ready(false);
490   std::atomic<int> setHazptrs(0);
491   std::atomic<Foo*> head;
492
493   int num = 20;
494   int nthr = 10;
495   std::vector<std::thread> thr(nthr);
496   for (int i = 0; i < nthr; ++i) {
497     thr[i] = std::thread([&] {
498       while (!ready.load()) {
499         /* spin */
500       }
501       hazptr_holder hptr;
502       auto p = hptr.get_protected(head);
503       ++setHazptrs;
504       /* Concurrent with removal */
505       int v = num;
506       for (auto q = p; q; q = q->next_) {
507         CHECK_GT(v, 0);
508         --v;
509         CHECK_EQ(q->val_, v);
510       }
511       CHECK_EQ(v, 0);
512     });
513   }
514
515   Foo* p = nullptr;
516   for (int i = 0; i < num; ++i) {
517     p = new Foo(i, p);
518     p->acquire_ref_safe();
519   }
520   head.store(p);
521
522   ready.store(true);
523
524   while (setHazptrs.load() < nthr) {
525     /* spin */
526   }
527
528   /* this is concurrent with traversal by reader */
529   head.store(nullptr);
530   for (auto q = p; q; q = q->next_) {
531     q->retire();
532   }
533   DEBUG_PRINT("Foo should not be destroyed");
534   CHECK_EQ(constructed.load(), num);
535   CHECK_EQ(destroyed.load(), 0);
536
537   DEBUG_PRINT("Foo may be destroyed after releasing the last reference");
538   if (p->release_ref()) {
539     delete p;
540   }
541
542   /* retire enough objects to guarantee reclamation of Foo objects */
543   for (int i = 0; i < 100; ++i) {
544     auto a = new Dummy;
545     a->retire();
546   }
547
548   for (int i = 0; i < nthr; ++i) {
549     thr[i].join();
550   }
551
552   CHECK_EQ(constructed.load(), num);
553   CHECK_EQ(destroyed.load(), num);
554 }
555
556 TEST_F(HazptrTest, FreeFunctionRetire) {
557   auto foo = new int;
558   hazptr_retire(foo);
559   auto foo2 = new int;
560   hazptr_retire(foo2, [](int* obj) { delete obj; });
561
562   bool retired = false;
563   {
564     hazptr_domain myDomain0;
565     struct delret {
566       bool* retired_;
567       delret(bool* retire) : retired_(retire) {}
568       ~delret() {
569         *retired_ = true;
570       }
571     };
572     auto foo3 = new delret(&retired);
573     myDomain0.retire(foo3);
574   }
575   EXPECT_TRUE(retired);
576 }