Update hazard pointers interface and implementation
[folly.git] / folly / experimental / hazptr / test / HazptrTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/experimental/hazptr/test/HazptrUse1.h>
17 #include <folly/experimental/hazptr/test/HazptrUse2.h>
18 #include <folly/experimental/hazptr/example/LockFreeLIFO.h>
19 #include <folly/experimental/hazptr/example/SWMRList.h>
20 #include <folly/experimental/hazptr/example/WideCAS.h>
21 #include <folly/experimental/hazptr/debug.h>
22 #include <folly/experimental/hazptr/hazptr.h>
23
24 #include <gflags/gflags.h>
25 #include <folly/portability/GTest.h>
26
27 #include <thread>
28
29 DEFINE_int32(num_threads, 1, "Number of threads");
30 DEFINE_int64(num_reps, 1, "Number of test reps");
31 DEFINE_int64(num_ops, 10, "Number of ops or pairs of ops per rep");
32
33 using namespace folly::hazptr;
34
35 TEST(Hazptr, Test1) {
36   DEBUG_PRINT("========== start of scope");
37   DEBUG_PRINT("");
38   Node1* node0 = (Node1*)malloc(sizeof(Node1));
39   DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
40   Node1* node1 = (Node1*)malloc(sizeof(Node1));
41   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
42   Node1* node2 = (Node1*)malloc(sizeof(Node1));
43   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
44   Node1* node3 = (Node1*)malloc(sizeof(Node1));
45   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
46
47   DEBUG_PRINT("");
48
49   std::atomic<Node1*> shared0 = {node0};
50   std::atomic<Node1*> shared1 = {node1};
51   std::atomic<Node1*> shared2 = {node2};
52   std::atomic<Node1*> shared3 = {node3};
53
54   MyMemoryResource myMr;
55   DEBUG_PRINT("=== myMr " << &myMr);
56   hazptr_domain myDomain0;
57   DEBUG_PRINT("=== myDomain0 " << &myDomain0);
58   hazptr_domain myDomain1(&myMr);
59   DEBUG_PRINT("=== myDomain1 " << &myDomain1);
60
61   DEBUG_PRINT("");
62
63   DEBUG_PRINT("=== hptr0");
64   hazptr_owner<Node1> hptr0;
65   DEBUG_PRINT("=== hptr1");
66   hazptr_owner<Node1> hptr1(myDomain0);
67   DEBUG_PRINT("=== hptr2");
68   hazptr_owner<Node1> hptr2(myDomain1);
69   DEBUG_PRINT("=== hptr3");
70   hazptr_owner<Node1> hptr3;
71
72   DEBUG_PRINT("");
73
74   Node1* n0 = shared0.load();
75   Node1* n1 = shared1.load();
76   Node1* n2 = shared2.load();
77   Node1* n3 = shared3.load();
78
79   if (hptr0.try_protect(n0, shared0)) {}
80   if (hptr1.try_protect(n1, shared1)) {}
81   hptr1.clear();
82   hptr1.set(n2);
83   if (hptr2.try_protect(n3, shared3)) {}
84   swap(hptr1, hptr2);
85   hptr3.clear();
86
87   DEBUG_PRINT("");
88
89   DEBUG_PRINT("=== retire n0 " << n0);
90   n0->retire();
91   DEBUG_PRINT("=== retire n1 " << n1);
92   n1->retire(default_hazptr_domain());
93   DEBUG_PRINT("=== retire n2 " << n2);
94   n2->retire(myDomain0);
95   DEBUG_PRINT("=== retire n3 " << n3);
96   n3->retire(myDomain1);
97
98   DEBUG_PRINT("========== end of scope");
99 }
100
101 TEST(Hazptr, Test2) {
102   DEBUG_PRINT("========== start of scope");
103   Node2* node0 = new Node2;
104   DEBUG_PRINT("=== new    node0 " << node0 << " " << sizeof(*node0));
105   Node2* node1 = (Node2*)malloc(sizeof(Node2));
106   DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
107   Node2* node2 = (Node2*)malloc(sizeof(Node2));
108   DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
109   Node2* node3 = (Node2*)malloc(sizeof(Node2));
110   DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
111
112   DEBUG_PRINT("");
113
114   std::atomic<Node2*> shared0 = {node0};
115   std::atomic<Node2*> shared1 = {node1};
116   std::atomic<Node2*> shared2 = {node2};
117   std::atomic<Node2*> shared3 = {node3};
118
119   MineMemoryResource mineMr;
120   DEBUG_PRINT("=== mineMr " << &mineMr);
121   hazptr_domain mineDomain0;
122   DEBUG_PRINT("=== mineDomain0 " << &mineDomain0);
123   hazptr_domain mineDomain1(&mineMr);
124   DEBUG_PRINT("=== mineDomain1 " << &mineDomain1);
125
126   DEBUG_PRINT("");
127
128   DEBUG_PRINT("=== hptr0");
129   hazptr_owner<Node2> hptr0;
130   DEBUG_PRINT("=== hptr1");
131   hazptr_owner<Node2> hptr1(mineDomain0);
132   DEBUG_PRINT("=== hptr2");
133   hazptr_owner<Node2> hptr2(mineDomain1);
134   DEBUG_PRINT("=== hptr3");
135   hazptr_owner<Node2> hptr3;
136
137   DEBUG_PRINT("");
138
139   Node2* n0 = shared0.load();
140   Node2* n1 = shared1.load();
141   Node2* n2 = shared2.load();
142   Node2* n3 = shared3.load();
143
144   if (hptr0.try_protect(n0, shared0)) {}
145   if (hptr1.try_protect(n1, shared1)) {}
146   hptr1.clear();
147   hptr1.set(n2);
148   if (hptr2.try_protect(n3, shared3)) {}
149   swap(hptr1, hptr2);
150   hptr3.clear();
151
152   DEBUG_PRINT("");
153
154   DEBUG_PRINT("=== retire n0 " << n0);
155   n0->retire(default_hazptr_domain(), &mineReclaimFnDelete);
156   DEBUG_PRINT("=== retire n1 " << n1);
157   n1->retire(default_hazptr_domain(), &mineReclaimFnFree);
158   DEBUG_PRINT("=== retire n2 " << n2);
159   n2->retire(mineDomain0, &mineReclaimFnFree);
160   DEBUG_PRINT("=== retire n3 " << n3);
161   n3->retire(mineDomain1, &mineReclaimFnFree);
162
163   DEBUG_PRINT("========== end of scope");
164 }
165
166 TEST(Hazptr, LIFO) {
167   using T = uint32_t;
168   DEBUG_PRINT("========== start of test scope");
169   CHECK_GT(FLAGS_num_threads, 0);
170   for (int i = 0; i < FLAGS_num_reps; ++i) {
171     DEBUG_PRINT("========== start of rep scope");
172     LockFreeLIFO<T> s;
173     std::vector<std::thread> threads(FLAGS_num_threads);
174     for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
175       threads[tid] = std::thread([&s, tid]() {
176         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
177           s.push(j);
178           T res;
179           while (!s.pop(res)) {}
180         }
181       });
182     }
183     for (auto& t : threads) {
184       t.join();
185     }
186     DEBUG_PRINT("========== end of rep scope");
187   }
188   DEBUG_PRINT("========== end of test scope");
189 }
190
191 TEST(Hazptr, SWMRLIST) {
192   using T = uint64_t;
193   DEBUG_PRINT("========== start of test scope");
194   hazptr_domain custom_domain;
195
196   CHECK_GT(FLAGS_num_threads, 0);
197   for (int i = 0; i < FLAGS_num_reps; ++i) {
198     DEBUG_PRINT("========== start of rep scope");
199     SWMRListSet<T> s(custom_domain);
200     std::vector<std::thread> threads(FLAGS_num_threads);
201     for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
202       threads[tid] = std::thread([&s, tid]() {
203         for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
204           s.contains(j);
205         }
206       });
207     }
208     for (int j = 0; j < 10; ++j) {
209       s.add(j);
210     }
211     for (int j = 0; j < 10; ++j) {
212       s.remove(j);
213     }
214     for (auto& t : threads) {
215       t.join();
216     }
217     DEBUG_PRINT("========== end of rep scope");
218   }
219   DEBUG_PRINT("========== end of test scope");
220 }
221
222 TEST(Hazptr, WIDECAS) {
223   DEBUG_PRINT("========== start of test scope");
224
225   WideCAS s;
226   std::string u = "";
227   std::string v = "11112222";
228   auto ret = s.cas(u, v);
229   CHECK(ret);
230   u = "";
231   v = "11112222";
232   ret = s.cas(u, v);
233   CHECK(!ret);
234   u = "11112222";
235   v = "22223333";
236   ret = s.cas(u, v);
237   CHECK(ret);
238   u = "22223333";
239   v = "333344445555";
240   ret = s.cas(u, v);
241   CHECK(ret);
242
243   DEBUG_PRINT("========== end of test scope");
244 }
245
246 int main(int argc, char** argv) {
247   DEBUG_PRINT("================================================= start main");
248   testing::InitGoogleTest(&argc, argv);
249   google::ParseCommandLineFlags(&argc, &argv, true);
250   auto ret = RUN_ALL_TESTS();
251   DEBUG_PRINT("================================================= end main");
252   return ret;
253 }