Copyright 2012 -> 2013
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/ThreadLocal.h"
18
19 #include <map>
20 #include <unordered_map>
21 #include <set>
22 #include <atomic>
23 #include <mutex>
24 #include <condition_variable>
25 #include <thread>
26 #include <boost/thread/tss.hpp>
27 #include <gtest/gtest.h>
28 #include <gflags/gflags.h>
29 #include <glog/logging.h>
30 #include "folly/Benchmark.h"
31
32 using namespace folly;
33
34 struct Widget {
35   static int totalVal_;
36   int val_;
37   ~Widget() {
38     totalVal_ += val_;
39   }
40
41   static void customDeleter(Widget* w, TLPDestructionMode mode) {
42     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
43     delete w;
44   }
45 };
46 int Widget::totalVal_ = 0;
47
48 TEST(ThreadLocalPtr, BasicDestructor) {
49   Widget::totalVal_ = 0;
50   ThreadLocalPtr<Widget> w;
51   std::thread([&w]() {
52       w.reset(new Widget());
53       w.get()->val_ += 10;
54     }).join();
55   EXPECT_EQ(10, Widget::totalVal_);
56 }
57
58 TEST(ThreadLocalPtr, CustomDeleter1) {
59   Widget::totalVal_ = 0;
60   {
61     ThreadLocalPtr<Widget> w;
62     std::thread([&w]() {
63         w.reset(new Widget(), Widget::customDeleter);
64         w.get()->val_ += 10;
65       }).join();
66     EXPECT_EQ(10, Widget::totalVal_);
67   }
68   EXPECT_EQ(10, Widget::totalVal_);
69 }
70
71 TEST(ThreadLocalPtr, resetNull) {
72   ThreadLocalPtr<int> tl;
73   EXPECT_FALSE(tl);
74   tl.reset(new int(4));
75   EXPECT_TRUE(static_cast<bool>(tl));
76   EXPECT_EQ(*tl.get(), 4);
77   tl.reset();
78   EXPECT_FALSE(tl);
79 }
80
81 // Test deleting the ThreadLocalPtr object
82 TEST(ThreadLocalPtr, CustomDeleter2) {
83   Widget::totalVal_ = 0;
84   std::thread t;
85   std::mutex mutex;
86   std::condition_variable cv;
87   enum class State {
88     START,
89     DONE,
90     EXIT
91   };
92   State state = State::START;
93   {
94     ThreadLocalPtr<Widget> w;
95     t = std::thread([&]() {
96         w.reset(new Widget(), Widget::customDeleter);
97         w.get()->val_ += 10;
98
99         // Notify main thread that we're done
100         {
101           std::unique_lock<std::mutex> lock(mutex);
102           state = State::DONE;
103           cv.notify_all();
104         }
105
106         // Wait for main thread to allow us to exit
107         {
108           std::unique_lock<std::mutex> lock(mutex);
109           while (state != State::EXIT) {
110             cv.wait(lock);
111           }
112         }
113     });
114
115     // Wait for main thread to start (and set w.get()->val_)
116     {
117       std::unique_lock<std::mutex> lock(mutex);
118       while (state != State::DONE) {
119         cv.wait(lock);
120       }
121     }
122
123     // Thread started but hasn't exited yet
124     EXPECT_EQ(0, Widget::totalVal_);
125
126     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
127   }
128
129   EXPECT_EQ(1010, Widget::totalVal_);
130
131   // Allow thread to exit
132   {
133     std::unique_lock<std::mutex> lock(mutex);
134     state = State::EXIT;
135     cv.notify_all();
136   }
137   t.join();
138
139   EXPECT_EQ(1010, Widget::totalVal_);
140 }
141
142 TEST(ThreadLocal, BasicDestructor) {
143   Widget::totalVal_ = 0;
144   ThreadLocal<Widget> w;
145   std::thread([&w]() { w->val_ += 10; }).join();
146   EXPECT_EQ(10, Widget::totalVal_);
147 }
148
149 TEST(ThreadLocal, SimpleRepeatDestructor) {
150   Widget::totalVal_ = 0;
151   {
152     ThreadLocal<Widget> w;
153     w->val_ += 10;
154   }
155   {
156     ThreadLocal<Widget> w;
157     w->val_ += 10;
158   }
159   EXPECT_EQ(20, Widget::totalVal_);
160 }
161
162 TEST(ThreadLocal, InterleavedDestructors) {
163   Widget::totalVal_ = 0;
164   ThreadLocal<Widget>* w = NULL;
165   int wVersion = 0;
166   const int wVersionMax = 2;
167   int thIter = 0;
168   std::mutex lock;
169   auto th = std::thread([&]() {
170     int wVersionPrev = 0;
171     while (true) {
172       while (true) {
173         std::lock_guard<std::mutex> g(lock);
174         if (wVersion > wVersionMax) {
175           return;
176         }
177         if (wVersion > wVersionPrev) {
178           // We have a new version of w, so it should be initialized to zero
179           EXPECT_EQ((*w)->val_, 0);
180           break;
181         }
182       }
183       std::lock_guard<std::mutex> g(lock);
184       wVersionPrev = wVersion;
185       (*w)->val_ += 10;
186       ++thIter;
187     }
188   });
189   FOR_EACH_RANGE(i, 0, wVersionMax) {
190     int thIterPrev = 0;
191     {
192       std::lock_guard<std::mutex> g(lock);
193       thIterPrev = thIter;
194       delete w;
195       w = new ThreadLocal<Widget>();
196       ++wVersion;
197     }
198     while (true) {
199       std::lock_guard<std::mutex> g(lock);
200       if (thIter > thIterPrev) {
201         break;
202       }
203     }
204   }
205   {
206     std::lock_guard<std::mutex> g(lock);
207     wVersion = wVersionMax + 1;
208   }
209   th.join();
210   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
211 }
212
213 class SimpleThreadCachedInt {
214
215   class NewTag;
216   ThreadLocal<int,NewTag> val_;
217
218  public:
219   void add(int val) {
220     *val_ += val;
221   }
222
223   int read() {
224     int ret = 0;
225     for (const auto& i : val_.accessAllThreads()) {
226       ret += i;
227     }
228     return ret;
229   }
230 };
231
232 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
233   const int kNumThreads = 10;
234   SimpleThreadCachedInt stci;
235   std::atomic<bool> run(true);
236   std::atomic<int> totalAtomic(0);
237   std::vector<std::thread> threads;
238   for (int i = 0; i < kNumThreads; ++i) {
239     threads.push_back(std::thread([&,i]() {
240       stci.add(1);
241       totalAtomic.fetch_add(1);
242       while (run.load()) { usleep(100); }
243     }));
244   }
245   while (totalAtomic.load() != kNumThreads) { usleep(100); }
246   EXPECT_EQ(kNumThreads, stci.read());
247   run.store(false);
248   for (auto& t : threads) {
249     t.join();
250   }
251 }
252
253 TEST(ThreadLocal, resetNull) {
254   ThreadLocal<int> tl;
255   tl.reset(new int(4));
256   EXPECT_EQ(*tl.get(), 4);
257   tl.reset();
258   EXPECT_EQ(*tl.get(), 0);
259   tl.reset(new int(5));
260   EXPECT_EQ(*tl.get(), 5);
261 }
262
263 namespace {
264 struct Tag {};
265
266 struct Foo {
267   folly::ThreadLocal<int, Tag> tl;
268 };
269 }  // namespace
270
271 TEST(ThreadLocal, Movable1) {
272   Foo a;
273   Foo b;
274   EXPECT_TRUE(a.tl.get() != b.tl.get());
275
276   a = Foo();
277   b = Foo();
278   EXPECT_TRUE(a.tl.get() != b.tl.get());
279 }
280
281 TEST(ThreadLocal, Movable2) {
282   std::map<int, Foo> map;
283
284   map[42];
285   map[10];
286   map[23];
287   map[100];
288
289   std::set<void*> tls;
290   for (auto& m : map) {
291     tls.insert(m.second.tl.get());
292   }
293
294   // Make sure that we have 4 different instances of *tl
295   EXPECT_EQ(4, tls.size());
296 }
297
298 // Simple reference implementation using pthread_get_specific
299 template<typename T>
300 class PThreadGetSpecific {
301  public:
302   PThreadGetSpecific() : key_(0) {
303     pthread_key_create(&key_, OnThreadExit);
304   }
305
306   T* get() const {
307     return static_cast<T*>(pthread_getspecific(key_));
308   }
309
310   void reset(T* t) {
311     delete get();
312     pthread_setspecific(key_, t);
313   }
314   static void OnThreadExit(void* obj) {
315     delete static_cast<T*>(obj);
316   }
317  private:
318   pthread_key_t key_;
319 };
320
321 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
322
323 #define REG(var)                                                \
324   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
325     const int itersPerThread = iters / FLAGS_numThreads;        \
326     std::vector<std::thread> threads;                           \
327     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
328       threads.push_back(std::thread([&]() {                     \
329         var.reset(new int(0));                                  \
330         for (int i = 0; i < itersPerThread; ++i) {              \
331           ++(*var.get());                                       \
332         }                                                       \
333       }));                                                      \
334     }                                                           \
335     for (auto& t : threads) {                                   \
336       t.join();                                                 \
337     }                                                           \
338   }
339
340 ThreadLocalPtr<int> tlp;
341 REG(tlp);
342 PThreadGetSpecific<int> pthread_get_specific;
343 REG(pthread_get_specific);
344 boost::thread_specific_ptr<int> boost_tsp;
345 REG(boost_tsp);
346 BENCHMARK_DRAW_LINE();
347
348 int main(int argc, char** argv) {
349   testing::InitGoogleTest(&argc, argv);
350   google::ParseCommandLineFlags(&argc, &argv, true);
351   google::SetCommandLineOptionWithMode(
352     "bm_max_iters", "100000000", google::SET_FLAG_IF_DEFAULT
353   );
354   if (FLAGS_benchmark) {
355     folly::runBenchmarks();
356   }
357   return RUN_ALL_TESTS();
358 }
359
360 /*
361 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
362
363 Benchmark                               Iters   Total t    t/iter iter/sec
364 ------------------------------------------------------------------------------
365 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
366  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
367  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
368 ------------------------------------------------------------------------------
369 */