Pull from FB rev 63ce89e2f2301e6bba44a111cc7d4218022156f6
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2012 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/ThreadLocal.h"
18
19 #include <map>
20 #include <unordered_map>
21 #include <set>
22 #include <atomic>
23 #include <mutex>
24 #include <condition_variable>
25 #include <thread>
26 #include <boost/thread/tss.hpp>
27 #include <gtest/gtest.h>
28 #include <gflags/gflags.h>
29 #include <glog/logging.h>
30 #include "folly/Benchmark.h"
31
32 using namespace folly;
33
34 struct Widget {
35   static int totalVal_;
36   int val_;
37   ~Widget() {
38     totalVal_ += val_;
39   }
40
41   static void customDeleter(Widget* w, TLPDestructionMode mode) {
42     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
43     delete w;
44   }
45 };
46 int Widget::totalVal_ = 0;
47
48 TEST(ThreadLocalPtr, BasicDestructor) {
49   Widget::totalVal_ = 0;
50   ThreadLocalPtr<Widget> w;
51   std::thread([&w]() {
52       w.reset(new Widget());
53       w.get()->val_ += 10;
54     }).join();
55   EXPECT_EQ(10, Widget::totalVal_);
56 }
57
58 TEST(ThreadLocalPtr, CustomDeleter1) {
59   Widget::totalVal_ = 0;
60   {
61     ThreadLocalPtr<Widget> w;
62     std::thread([&w]() {
63         w.reset(new Widget(), Widget::customDeleter);
64         w.get()->val_ += 10;
65       }).join();
66     EXPECT_EQ(10, Widget::totalVal_);
67   }
68   EXPECT_EQ(10, Widget::totalVal_);
69 }
70
71 // Test deleting the ThreadLocalPtr object
72 TEST(ThreadLocalPtr, CustomDeleter2) {
73   Widget::totalVal_ = 0;
74   std::thread t;
75   std::mutex mutex;
76   std::condition_variable cv;
77   enum class State {
78     START,
79     DONE,
80     EXIT
81   };
82   State state = State::START;
83   {
84     ThreadLocalPtr<Widget> w;
85     t = std::thread([&]() {
86         w.reset(new Widget(), Widget::customDeleter);
87         w.get()->val_ += 10;
88
89         // Notify main thread that we're done
90         {
91           std::unique_lock<std::mutex> lock(mutex);
92           state = State::DONE;
93           cv.notify_all();
94         }
95
96         // Wait for main thread to allow us to exit
97         {
98           std::unique_lock<std::mutex> lock(mutex);
99           while (state != State::EXIT) {
100             cv.wait(lock);
101           }
102         }
103     });
104
105     // Wait for main thread to start (and set w.get()->val_)
106     {
107       std::unique_lock<std::mutex> lock(mutex);
108       while (state != State::DONE) {
109         cv.wait(lock);
110       }
111     }
112
113     // Thread started but hasn't exited yet
114     EXPECT_EQ(0, Widget::totalVal_);
115
116     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
117   }
118
119   EXPECT_EQ(1010, Widget::totalVal_);
120
121   // Allow thread to exit
122   {
123     std::unique_lock<std::mutex> lock(mutex);
124     state = State::EXIT;
125     cv.notify_all();
126   }
127   t.join();
128
129   EXPECT_EQ(1010, Widget::totalVal_);
130 }
131
132 TEST(ThreadLocal, BasicDestructor) {
133   Widget::totalVal_ = 0;
134   ThreadLocal<Widget> w;
135   std::thread([&w]() { w->val_ += 10; }).join();
136   EXPECT_EQ(10, Widget::totalVal_);
137 }
138
139 TEST(ThreadLocal, SimpleRepeatDestructor) {
140   Widget::totalVal_ = 0;
141   {
142     ThreadLocal<Widget> w;
143     w->val_ += 10;
144   }
145   {
146     ThreadLocal<Widget> w;
147     w->val_ += 10;
148   }
149   EXPECT_EQ(20, Widget::totalVal_);
150 }
151
152 TEST(ThreadLocal, InterleavedDestructors) {
153   Widget::totalVal_ = 0;
154   ThreadLocal<Widget>* w = NULL;
155   int wVersion = 0;
156   const int wVersionMax = 2;
157   int thIter = 0;
158   std::mutex lock;
159   auto th = std::thread([&]() {
160     int wVersionPrev = 0;
161     while (true) {
162       while (true) {
163         std::lock_guard<std::mutex> g(lock);
164         if (wVersion > wVersionMax) {
165           return;
166         }
167         if (wVersion > wVersionPrev) {
168           // We have a new version of w, so it should be initialized to zero
169           EXPECT_EQ((*w)->val_, 0);
170           break;
171         }
172       }
173       std::lock_guard<std::mutex> g(lock);
174       wVersionPrev = wVersion;
175       (*w)->val_ += 10;
176       ++thIter;
177     }
178   });
179   FOR_EACH_RANGE(i, 0, wVersionMax) {
180     int thIterPrev = 0;
181     {
182       std::lock_guard<std::mutex> g(lock);
183       thIterPrev = thIter;
184       delete w;
185       w = new ThreadLocal<Widget>();
186       ++wVersion;
187     }
188     while (true) {
189       std::lock_guard<std::mutex> g(lock);
190       if (thIter > thIterPrev) {
191         break;
192       }
193     }
194   }
195   {
196     std::lock_guard<std::mutex> g(lock);
197     wVersion = wVersionMax + 1;
198   }
199   th.join();
200   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
201 }
202
203 class SimpleThreadCachedInt {
204
205   class NewTag;
206   ThreadLocal<int,NewTag> val_;
207
208  public:
209   void add(int val) {
210     *val_ += val;
211   }
212
213   int read() {
214     int ret = 0;
215     for (const auto& i : val_.accessAllThreads()) {
216       ret += i;
217     }
218     return ret;
219   }
220 };
221
222 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
223   const int kNumThreads = 10;
224   SimpleThreadCachedInt stci;
225   std::atomic<bool> run(true);
226   std::atomic<int> totalAtomic(0);
227   std::vector<std::thread> threads;
228   for (int i = 0; i < kNumThreads; ++i) {
229     threads.push_back(std::thread([&,i]() {
230       stci.add(1);
231       totalAtomic.fetch_add(1);
232       while (run.load()) { usleep(100); }
233     }));
234   }
235   while (totalAtomic.load() != kNumThreads) { usleep(100); }
236   EXPECT_EQ(kNumThreads, stci.read());
237   run.store(false);
238   for (auto& t : threads) {
239     t.join();
240   }
241 }
242
243 TEST(ThreadLocal, resetNull) {
244   ThreadLocal<int> tl;
245   tl.reset(new int(4));
246   EXPECT_EQ(*tl.get(), 4);
247   tl.reset();
248   EXPECT_EQ(*tl.get(), 0);
249   tl.reset(new int(5));
250   EXPECT_EQ(*tl.get(), 5);
251 }
252
253 namespace {
254 struct Tag {};
255
256 struct Foo {
257   folly::ThreadLocal<int, Tag> tl;
258 };
259 }  // namespace
260
261 TEST(ThreadLocal, Movable1) {
262   Foo a;
263   Foo b;
264   EXPECT_TRUE(a.tl.get() != b.tl.get());
265
266   a = Foo();
267   b = Foo();
268   EXPECT_TRUE(a.tl.get() != b.tl.get());
269 }
270
271 TEST(ThreadLocal, Movable2) {
272   std::map<int, Foo> map;
273
274   map[42];
275   map[10];
276   map[23];
277   map[100];
278
279   std::set<void*> tls;
280   for (auto& m : map) {
281     tls.insert(m.second.tl.get());
282   }
283
284   // Make sure that we have 4 different instances of *tl
285   EXPECT_EQ(4, tls.size());
286 }
287
288 // Simple reference implementation using pthread_get_specific
289 template<typename T>
290 class PThreadGetSpecific {
291  public:
292   PThreadGetSpecific() : key_(0) {
293     pthread_key_create(&key_, OnThreadExit);
294   }
295
296   T* get() const {
297     return static_cast<T*>(pthread_getspecific(key_));
298   }
299
300   void reset(T* t) {
301     delete get();
302     pthread_setspecific(key_, t);
303   }
304   static void OnThreadExit(void* obj) {
305     delete static_cast<T*>(obj);
306   }
307  private:
308   pthread_key_t key_;
309 };
310
311 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
312
313 #define REG(var)                                                \
314   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
315     const int itersPerThread = iters / FLAGS_numThreads;        \
316     std::vector<std::thread> threads;                           \
317     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
318       threads.push_back(std::thread([&]() {                     \
319         var.reset(new int(0));                                  \
320         for (int i = 0; i < itersPerThread; ++i) {              \
321           ++(*var.get());                                       \
322         }                                                       \
323       }));                                                      \
324     }                                                           \
325     for (auto& t : threads) {                                   \
326       t.join();                                                 \
327     }                                                           \
328   }
329
330 ThreadLocalPtr<int> tlp;
331 REG(tlp);
332 PThreadGetSpecific<int> pthread_get_specific;
333 REG(pthread_get_specific);
334 boost::thread_specific_ptr<int> boost_tsp;
335 REG(boost_tsp);
336 BENCHMARK_DRAW_LINE();
337
338 int main(int argc, char** argv) {
339   testing::InitGoogleTest(&argc, argv);
340   google::ParseCommandLineFlags(&argc, &argv, true);
341   google::SetCommandLineOptionWithMode(
342     "bm_max_iters", "100000000", google::SET_FLAG_IF_DEFAULT
343   );
344   if (FLAGS_benchmark) {
345     folly::runBenchmarks();
346   }
347   return RUN_ALL_TESTS();
348 }
349
350 /*
351 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
352
353 Benchmark                               Iters   Total t    t/iter iter/sec
354 ------------------------------------------------------------------------------
355 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
356  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
357  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
358 ------------------------------------------------------------------------------
359 */