Adding support for in-place use of ProducerConsumerQueue.
[folly.git] / folly / test / ConcurrentSkipListTest.cpp
1 /*
2  * Copyright 2012 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 // @author: Xin Liu <xliux@fb.com>
18
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <set>
#include <vector>

#include <boost/thread.hpp>
#include <glog/logging.h>
#include <gflags/gflags.h>
#include "folly/ConcurrentSkipList.h"
#include "folly/Foreach.h"
#include "gtest/gtest.h"
28
29 DEFINE_int32(num_threads, 12, "num concurrent threads to test");
30
31 namespace {
32
33 using namespace folly;
34 using std::vector;
35
36 typedef int ValueType;
37 typedef detail::SkipListNode<ValueType> SkipListNodeType;
38 typedef ConcurrentSkipList<ValueType> SkipListType;
39 typedef SkipListType::Accessor SkipListAccessor;
40 typedef vector<ValueType> VectorType;
41 typedef std::set<ValueType> SetType;
42
43 static const int kHeadHeight = 2;
44 static const int kMaxValue = 5000;
45
46 static void randomAdding(int size,
47     SkipListAccessor skipList,
48     SetType *verifier,
49     int maxValue = kMaxValue) {
50   for (int i = 0; i < size; ++i) {
51     int32_t r = rand() % maxValue;
52     verifier->insert(r);
53     skipList.add(r);
54   }
55 }
56
57 static void randomRemoval(int size,
58     SkipListAccessor skipList,
59     SetType *verifier,
60     int maxValue=kMaxValue) {
61   for (int i = 0; i < size; ++i) {
62     int32_t r = rand() % maxValue;
63     verifier->insert(r);
64     skipList.remove(r);
65   }
66 }
67
68 static void sumAllValues(SkipListAccessor skipList, int64_t *sum) {
69   *sum = 0;
70   FOR_EACH(it, skipList) {
71     *sum += *it;
72   }
73   VLOG(20) << "sum = " << sum;
74 }
75
76 static void concurrentSkip(const vector<ValueType> *values,
77     SkipListAccessor skipList) {
78   int64_t sum = 0;
79   SkipListAccessor::Skipper skipper(skipList);
80   FOR_EACH(it, *values) {
81     if (skipper.to(*it)) sum += *it;
82   }
83   VLOG(20) << "sum = " << sum;
84 }
85
86 bool verifyEqual(SkipListAccessor skipList,
87     const SetType &verifier) {
88   EXPECT_EQ(verifier.size(), skipList.size());
89   FOR_EACH(it, verifier) {
90     CHECK(skipList.contains(*it)) << *it;
91     SkipListType::const_iterator iter = skipList.find(*it);
92     CHECK(iter != skipList.end());
93     EXPECT_EQ(*iter, *it);
94   }
95   EXPECT_TRUE(std::equal(verifier.begin(), verifier.end(), skipList.begin()));
96   return true;
97 }
98
99 TEST(ConcurrentSkipList, SequentialAccess) {
100   {
101     LOG(INFO) << "nodetype size=" << sizeof(SkipListNodeType);
102
103     auto skipList(SkipListType::create(kHeadHeight));
104     EXPECT_TRUE(skipList.first() == NULL);
105     EXPECT_TRUE(skipList.last() == NULL);
106
107     skipList.add(3);
108     EXPECT_TRUE(skipList.contains(3));
109     EXPECT_FALSE(skipList.contains(2));
110     EXPECT_EQ(3, *skipList.first());
111     EXPECT_EQ(3, *skipList.last());
112
113     EXPECT_EQ(3, *skipList.find(3));
114     EXPECT_FALSE(skipList.find(3) == skipList.end());
115     EXPECT_TRUE(skipList.find(2) == skipList.end());
116
117     {
118       SkipListAccessor::Skipper skipper(skipList);
119       skipper.to(3);
120       CHECK_EQ(3, *skipper);
121     }
122
123     skipList.add(2);
124     EXPECT_EQ(2, *skipList.first());
125     EXPECT_EQ(3, *skipList.last());
126     skipList.add(5);
127     EXPECT_EQ(5, *skipList.last());
128     skipList.add(3);
129     EXPECT_EQ(5, *skipList.last());
130     auto ret = skipList.insert(9);
131     EXPECT_EQ(9, *ret.first);
132     EXPECT_TRUE(ret.second);
133
134     ret = skipList.insert(5);
135     EXPECT_EQ(5, *ret.first);
136     EXPECT_FALSE(ret.second);
137
138     EXPECT_EQ(2, *skipList.first());
139     EXPECT_EQ(9, *skipList.last());
140     EXPECT_TRUE(skipList.pop_back());
141     EXPECT_EQ(5, *skipList.last());
142     EXPECT_TRUE(skipList.pop_back());
143     EXPECT_EQ(3, *skipList.last());
144
145     skipList.add(9);
146     skipList.add(5);
147
148     CHECK(skipList.contains(2));
149     CHECK(skipList.contains(3));
150     CHECK(skipList.contains(5));
151     CHECK(skipList.contains(9));
152     CHECK(!skipList.contains(4));
153
154     // lower_bound
155     auto it = skipList.lower_bound(5);
156     EXPECT_EQ(5, *it);
157     it = skipList.lower_bound(4);
158     EXPECT_EQ(5, *it);
159     it = skipList.lower_bound(9);
160     EXPECT_EQ(9, *it);
161     it = skipList.lower_bound(12);
162     EXPECT_FALSE(it.good());
163
164     it = skipList.begin();
165     EXPECT_EQ(2, *it);
166
167     // skipper test
168     SkipListAccessor::Skipper skipper(skipList);
169     skipper.to(3);
170     EXPECT_EQ(3, skipper.data());
171     skipper.to(5);
172     EXPECT_EQ(5, skipper.data());
173     CHECK(!skipper.to(7));
174
175     skipList.remove(5);
176     skipList.remove(3);
177     CHECK(skipper.to(9));
178     EXPECT_EQ(9, skipper.data());
179
180     CHECK(!skipList.contains(3));
181     skipList.add(3);
182     CHECK(skipList.contains(3));
183     int pos = 0;
184     FOR_EACH(it, skipList) {
185       LOG(INFO) << "pos= " << pos++ << " value= " << *it;
186     }
187   }
188
189   {
190     auto skipList(SkipListType::create(kHeadHeight));
191
192     SetType verifier;
193     randomAdding(10000, skipList, &verifier);
194     verifyEqual(skipList, verifier);
195
196     // test skipper
197     SkipListAccessor::Skipper skipper(skipList);
198     int num_skips = 1000;
199     for (int i = 0; i < num_skips; ++i) {
200       int n = i * kMaxValue / num_skips;
201       bool found = skipper.to(n);
202       EXPECT_EQ(found, (verifier.find(n) != verifier.end()));
203     }
204   }
205
206 }
207
208 void testConcurrentAdd(int numThreads) {
209   auto skipList(SkipListType::create(kHeadHeight));
210
211   vector<boost::thread> threads;
212   vector<SetType> verifiers(numThreads);
213   for (int i = 0; i < numThreads; ++i) {
214     threads.push_back(boost::thread(
215           &randomAdding, 100, skipList, &verifiers[i], kMaxValue));
216   }
217   for (int i = 0; i < threads.size(); ++i) {
218     threads[i].join();
219   }
220
221   SetType all;
222   FOR_EACH(s, verifiers) {
223     all.insert(s->begin(), s->end());
224   }
225   verifyEqual(skipList, all);
226 }
227
228 TEST(ConcurrentSkipList, ConcurrentAdd) {
229   // test it many times
230   for (int numThreads = 10; numThreads < 10000; numThreads += 1000) {
231     testConcurrentAdd(numThreads);
232   }
233 }
234
235 void testConcurrentRemoval(int numThreads, int maxValue) {
236   auto skipList = SkipListType::create(kHeadHeight);
237   for (int i = 0; i < maxValue; ++i) {
238     skipList.add(i);
239   }
240
241   vector<boost::thread> threads;
242   vector<SetType > verifiers(numThreads);
243   for (int i = 0; i < numThreads; ++i) {
244     threads.push_back(boost::thread(
245           &randomRemoval, 100, skipList, &verifiers[i], maxValue));
246   }
247   FOR_EACH(t, threads) {
248     (*t).join();
249   }
250
251   SetType all;
252   FOR_EACH(s, verifiers) {
253     all.insert(s->begin(), s->end());
254   }
255
256   CHECK_EQ(maxValue, all.size() + skipList.size());
257   for (int i = 0; i < maxValue; ++i) {
258     if (all.find(i) != all.end()) {
259       CHECK(!skipList.contains(i)) << i;
260     } else {
261       CHECK(skipList.contains(i)) << i;
262     }
263   }
264 }
265
266 TEST(ConcurrentSkipList, ConcurrentRemove) {
267   for (int numThreads = 10; numThreads < 1000; numThreads += 100) {
268     testConcurrentRemoval(numThreads, 100 * numThreads);
269   }
270 }
271
272 static void testConcurrentAccess(
273     int numInsertions, int numDeletions, int maxValue) {
274   auto skipList = SkipListType::create(kHeadHeight);
275
276   vector<SetType> verifiers(FLAGS_num_threads);
277   vector<int64_t> sums(FLAGS_num_threads);
278   vector<vector<ValueType> > skipValues(FLAGS_num_threads);
279
280   for (int i = 0; i < FLAGS_num_threads; ++i) {
281     for (int j = 0; j < numInsertions; ++j) {
282       skipValues[i].push_back(rand() % (maxValue + 1));
283     }
284     std::sort(skipValues[i].begin(), skipValues[i].end());
285   }
286
287   vector<boost::thread> threads;
288   for (int i = 0; i < FLAGS_num_threads; ++i) {
289     switch (i % 8) {
290       case 0:
291       case 1:
292         threads.push_back(boost::thread(
293               randomAdding, numInsertions, skipList, &verifiers[i], maxValue));
294         break;
295       case 2:
296         threads.push_back(boost::thread(
297               randomRemoval, numDeletions, skipList, &verifiers[i], maxValue));
298         break;
299       case 3:
300         threads.push_back(boost::thread(
301               concurrentSkip, &skipValues[i], skipList));
302         break;
303       default:
304         threads.push_back(boost::thread(sumAllValues, skipList, &sums[i]));
305         break;
306     }
307   }
308
309   FOR_EACH(t, threads) {
310     (*t).join();
311   }
312   // just run through it, no need to verify the correctness.
313 }
314
315 TEST(ConcurrentSkipList, ConcurrentAccess) {
316   testConcurrentAccess(10000, 100, kMaxValue);
317   testConcurrentAccess(100000, 10000, kMaxValue * 10);
318   testConcurrentAccess(1000000, 100000, kMaxValue);
319 }
320
321 }  // namespace
322
323 int main(int argc, char* argv[]) {
324   testing::InitGoogleTest(&argc, argv);
325   google::InitGoogleLogging(argv[0]);
326   google::ParseCommandLineFlags(&argc, &argv, true);
327
328   return RUN_ALL_TESTS();
329 }