Fix intermittent failure in concurrent_skip_list_test
[folly.git] / folly / test / ConcurrentSkipListTest.cpp
1 /*
2  * Copyright 2012 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 // @author: Xin Liu <xliux@fb.com>
18
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <cstdlib>
#include <random>
#include <set>
#include <system_error>
#include <thread>
#include <vector>

#include <glog/logging.h>
#include <gflags/gflags.h>
#include "folly/ConcurrentSkipList.h"
#include "folly/Foreach.h"
#include "folly/String.h"
#include "gtest/gtest.h"
30
31 DEFINE_int32(num_threads, 12, "num concurrent threads to test");
32
33 namespace {
34
35 using namespace folly;
36 using std::vector;
37
38 typedef int ValueType;
39 typedef detail::SkipListNode<ValueType> SkipListNodeType;
40 typedef ConcurrentSkipList<ValueType> SkipListType;
41 typedef SkipListType::Accessor SkipListAccessor;
42 typedef vector<ValueType> VectorType;
43 typedef std::set<ValueType> SetType;
44
45 static const int kHeadHeight = 2;
46 static const int kMaxValue = 5000;
47
48 static void randomAdding(int size,
49     SkipListAccessor skipList,
50     SetType *verifier,
51     int maxValue = kMaxValue) {
52   for (int i = 0; i < size; ++i) {
53     int32_t r = rand() % maxValue;
54     verifier->insert(r);
55     skipList.add(r);
56   }
57 }
58
59 static void randomRemoval(int size,
60     SkipListAccessor skipList,
61     SetType *verifier,
62     int maxValue=kMaxValue) {
63   for (int i = 0; i < size; ++i) {
64     int32_t r = rand() % maxValue;
65     verifier->insert(r);
66     skipList.remove(r);
67   }
68 }
69
70 static void sumAllValues(SkipListAccessor skipList, int64_t *sum) {
71   *sum = 0;
72   FOR_EACH(it, skipList) {
73     *sum += *it;
74   }
75   VLOG(20) << "sum = " << sum;
76 }
77
78 static void concurrentSkip(const vector<ValueType> *values,
79     SkipListAccessor skipList) {
80   int64_t sum = 0;
81   SkipListAccessor::Skipper skipper(skipList);
82   FOR_EACH(it, *values) {
83     if (skipper.to(*it)) sum += *it;
84   }
85   VLOG(20) << "sum = " << sum;
86 }
87
88 bool verifyEqual(SkipListAccessor skipList,
89     const SetType &verifier) {
90   EXPECT_EQ(verifier.size(), skipList.size());
91   FOR_EACH(it, verifier) {
92     CHECK(skipList.contains(*it)) << *it;
93     SkipListType::const_iterator iter = skipList.find(*it);
94     CHECK(iter != skipList.end());
95     EXPECT_EQ(*iter, *it);
96   }
97   EXPECT_TRUE(std::equal(verifier.begin(), verifier.end(), skipList.begin()));
98   return true;
99 }
100
101 TEST(ConcurrentSkipList, SequentialAccess) {
102   {
103     LOG(INFO) << "nodetype size=" << sizeof(SkipListNodeType);
104
105     auto skipList(SkipListType::create(kHeadHeight));
106     EXPECT_TRUE(skipList.first() == NULL);
107     EXPECT_TRUE(skipList.last() == NULL);
108
109     skipList.add(3);
110     EXPECT_TRUE(skipList.contains(3));
111     EXPECT_FALSE(skipList.contains(2));
112     EXPECT_EQ(3, *skipList.first());
113     EXPECT_EQ(3, *skipList.last());
114
115     EXPECT_EQ(3, *skipList.find(3));
116     EXPECT_FALSE(skipList.find(3) == skipList.end());
117     EXPECT_TRUE(skipList.find(2) == skipList.end());
118
119     {
120       SkipListAccessor::Skipper skipper(skipList);
121       skipper.to(3);
122       CHECK_EQ(3, *skipper);
123     }
124
125     skipList.add(2);
126     EXPECT_EQ(2, *skipList.first());
127     EXPECT_EQ(3, *skipList.last());
128     skipList.add(5);
129     EXPECT_EQ(5, *skipList.last());
130     skipList.add(3);
131     EXPECT_EQ(5, *skipList.last());
132     auto ret = skipList.insert(9);
133     EXPECT_EQ(9, *ret.first);
134     EXPECT_TRUE(ret.second);
135
136     ret = skipList.insert(5);
137     EXPECT_EQ(5, *ret.first);
138     EXPECT_FALSE(ret.second);
139
140     EXPECT_EQ(2, *skipList.first());
141     EXPECT_EQ(9, *skipList.last());
142     EXPECT_TRUE(skipList.pop_back());
143     EXPECT_EQ(5, *skipList.last());
144     EXPECT_TRUE(skipList.pop_back());
145     EXPECT_EQ(3, *skipList.last());
146
147     skipList.add(9);
148     skipList.add(5);
149
150     CHECK(skipList.contains(2));
151     CHECK(skipList.contains(3));
152     CHECK(skipList.contains(5));
153     CHECK(skipList.contains(9));
154     CHECK(!skipList.contains(4));
155
156     // lower_bound
157     auto it = skipList.lower_bound(5);
158     EXPECT_EQ(5, *it);
159     it = skipList.lower_bound(4);
160     EXPECT_EQ(5, *it);
161     it = skipList.lower_bound(9);
162     EXPECT_EQ(9, *it);
163     it = skipList.lower_bound(12);
164     EXPECT_FALSE(it.good());
165
166     it = skipList.begin();
167     EXPECT_EQ(2, *it);
168
169     // skipper test
170     SkipListAccessor::Skipper skipper(skipList);
171     skipper.to(3);
172     EXPECT_EQ(3, skipper.data());
173     skipper.to(5);
174     EXPECT_EQ(5, skipper.data());
175     CHECK(!skipper.to(7));
176
177     skipList.remove(5);
178     skipList.remove(3);
179     CHECK(skipper.to(9));
180     EXPECT_EQ(9, skipper.data());
181
182     CHECK(!skipList.contains(3));
183     skipList.add(3);
184     CHECK(skipList.contains(3));
185     int pos = 0;
186     FOR_EACH(it, skipList) {
187       LOG(INFO) << "pos= " << pos++ << " value= " << *it;
188     }
189   }
190
191   {
192     auto skipList(SkipListType::create(kHeadHeight));
193
194     SetType verifier;
195     randomAdding(10000, skipList, &verifier);
196     verifyEqual(skipList, verifier);
197
198     // test skipper
199     SkipListAccessor::Skipper skipper(skipList);
200     int num_skips = 1000;
201     for (int i = 0; i < num_skips; ++i) {
202       int n = i * kMaxValue / num_skips;
203       bool found = skipper.to(n);
204       EXPECT_EQ(found, (verifier.find(n) != verifier.end()));
205     }
206   }
207
208 }
209
210 void testConcurrentAdd(int numThreads) {
211   auto skipList(SkipListType::create(kHeadHeight));
212
213   vector<std::thread> threads;
214   vector<SetType> verifiers(numThreads);
215   try {
216     for (int i = 0; i < numThreads; ++i) {
217       threads.push_back(std::thread(
218             &randomAdding, 100, skipList, &verifiers[i], kMaxValue));
219     }
220   } catch (const std::system_error& e) {
221     LOG(WARNING)
222       << "Caught " << exceptionStr(e)
223       << ": could only create " << threads.size() << " threads out of "
224       << numThreads;
225   }
226   for (int i = 0; i < threads.size(); ++i) {
227     threads[i].join();
228   }
229
230   SetType all;
231   FOR_EACH(s, verifiers) {
232     all.insert(s->begin(), s->end());
233   }
234   verifyEqual(skipList, all);
235 }
236
237 TEST(ConcurrentSkipList, ConcurrentAdd) {
238   // test it many times
239   for (int numThreads = 10; numThreads < 10000; numThreads += 1000) {
240     testConcurrentAdd(numThreads);
241   }
242 }
243
244 void testConcurrentRemoval(int numThreads, int maxValue) {
245   auto skipList = SkipListType::create(kHeadHeight);
246   for (int i = 0; i < maxValue; ++i) {
247     skipList.add(i);
248   }
249
250   vector<std::thread> threads;
251   vector<SetType > verifiers(numThreads);
252   try {
253     for (int i = 0; i < numThreads; ++i) {
254       threads.push_back(std::thread(
255             &randomRemoval, 100, skipList, &verifiers[i], maxValue));
256     }
257   } catch (const std::system_error& e) {
258     LOG(WARNING)
259       << "Caught " << exceptionStr(e)
260       << ": could only create " << threads.size() << " threads out of "
261       << numThreads;
262   }
263   FOR_EACH(t, threads) {
264     (*t).join();
265   }
266
267   SetType all;
268   FOR_EACH(s, verifiers) {
269     all.insert(s->begin(), s->end());
270   }
271
272   CHECK_EQ(maxValue, all.size() + skipList.size());
273   for (int i = 0; i < maxValue; ++i) {
274     if (all.find(i) != all.end()) {
275       CHECK(!skipList.contains(i)) << i;
276     } else {
277       CHECK(skipList.contains(i)) << i;
278     }
279   }
280 }
281
282 TEST(ConcurrentSkipList, ConcurrentRemove) {
283   for (int numThreads = 10; numThreads < 1000; numThreads += 100) {
284     testConcurrentRemoval(numThreads, 100 * numThreads);
285   }
286 }
287
288 static void testConcurrentAccess(
289     int numInsertions, int numDeletions, int maxValue) {
290   auto skipList = SkipListType::create(kHeadHeight);
291
292   vector<SetType> verifiers(FLAGS_num_threads);
293   vector<int64_t> sums(FLAGS_num_threads);
294   vector<vector<ValueType> > skipValues(FLAGS_num_threads);
295
296   for (int i = 0; i < FLAGS_num_threads; ++i) {
297     for (int j = 0; j < numInsertions; ++j) {
298       skipValues[i].push_back(rand() % (maxValue + 1));
299     }
300     std::sort(skipValues[i].begin(), skipValues[i].end());
301   }
302
303   vector<std::thread> threads;
304   for (int i = 0; i < FLAGS_num_threads; ++i) {
305     switch (i % 8) {
306       case 0:
307       case 1:
308         threads.push_back(std::thread(
309               randomAdding, numInsertions, skipList, &verifiers[i], maxValue));
310         break;
311       case 2:
312         threads.push_back(std::thread(
313               randomRemoval, numDeletions, skipList, &verifiers[i], maxValue));
314         break;
315       case 3:
316         threads.push_back(std::thread(
317               concurrentSkip, &skipValues[i], skipList));
318         break;
319       default:
320         threads.push_back(std::thread(sumAllValues, skipList, &sums[i]));
321         break;
322     }
323   }
324
325   FOR_EACH(t, threads) {
326     (*t).join();
327   }
328   // just run through it, no need to verify the correctness.
329 }
330
331 TEST(ConcurrentSkipList, ConcurrentAccess) {
332   testConcurrentAccess(10000, 100, kMaxValue);
333   testConcurrentAccess(100000, 10000, kMaxValue * 10);
334   testConcurrentAccess(1000000, 100000, kMaxValue);
335 }
336
337 }  // namespace
338
339 int main(int argc, char* argv[]) {
340   testing::InitGoogleTest(&argc, argv);
341   google::InitGoogleLogging(argv[0]);
342   google::ParseCommandLineFlags(&argc, &argv, true);
343
344   return RUN_ALL_TESTS();
345 }