Override for include-guard
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/ThreadLocal.h>
18
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 #include <unistd.h>
22
23 #include <array>
24 #include <atomic>
25 #include <chrono>
26 #include <condition_variable>
27 #include <map>
28 #include <mutex>
29 #include <set>
30 #include <thread>
31 #include <unordered_map>
32
33 #include <boost/thread/tss.hpp>
34 #include <gflags/gflags.h>
35 #include <glog/logging.h>
36 #include <gtest/gtest.h>
37
38 #include <folly/Benchmark.h>
39
40 using namespace folly;
41
42 struct Widget {
43   static int totalVal_;
44   int val_;
45   ~Widget() {
46     totalVal_ += val_;
47   }
48
49   static void customDeleter(Widget* w, TLPDestructionMode mode) {
50     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
51     delete w;
52   }
53 };
54 int Widget::totalVal_ = 0;
55
56 TEST(ThreadLocalPtr, BasicDestructor) {
57   Widget::totalVal_ = 0;
58   ThreadLocalPtr<Widget> w;
59   std::thread([&w]() {
60       w.reset(new Widget());
61       w.get()->val_ += 10;
62     }).join();
63   EXPECT_EQ(10, Widget::totalVal_);
64 }
65
66 TEST(ThreadLocalPtr, CustomDeleter1) {
67   Widget::totalVal_ = 0;
68   {
69     ThreadLocalPtr<Widget> w;
70     std::thread([&w]() {
71         w.reset(new Widget(), Widget::customDeleter);
72         w.get()->val_ += 10;
73       }).join();
74     EXPECT_EQ(10, Widget::totalVal_);
75   }
76   EXPECT_EQ(10, Widget::totalVal_);
77 }
78
79 TEST(ThreadLocalPtr, resetNull) {
80   ThreadLocalPtr<int> tl;
81   EXPECT_FALSE(tl);
82   tl.reset(new int(4));
83   EXPECT_TRUE(static_cast<bool>(tl));
84   EXPECT_EQ(*tl.get(), 4);
85   tl.reset();
86   EXPECT_FALSE(tl);
87 }
88
89 TEST(ThreadLocalPtr, TestRelease) {
90   Widget::totalVal_ = 0;
91   ThreadLocalPtr<Widget> w;
92   std::unique_ptr<Widget> wPtr;
93   std::thread([&w, &wPtr]() {
94       w.reset(new Widget());
95       w.get()->val_ += 10;
96
97       wPtr.reset(w.release());
98     }).join();
99   EXPECT_EQ(0, Widget::totalVal_);
100   wPtr.reset();
101   EXPECT_EQ(10, Widget::totalVal_);
102 }
103
104 // Test deleting the ThreadLocalPtr object
105 TEST(ThreadLocalPtr, CustomDeleter2) {
106   Widget::totalVal_ = 0;
107   std::thread t;
108   std::mutex mutex;
109   std::condition_variable cv;
110   enum class State {
111     START,
112     DONE,
113     EXIT
114   };
115   State state = State::START;
116   {
117     ThreadLocalPtr<Widget> w;
118     t = std::thread([&]() {
119         w.reset(new Widget(), Widget::customDeleter);
120         w.get()->val_ += 10;
121
122         // Notify main thread that we're done
123         {
124           std::unique_lock<std::mutex> lock(mutex);
125           state = State::DONE;
126           cv.notify_all();
127         }
128
129         // Wait for main thread to allow us to exit
130         {
131           std::unique_lock<std::mutex> lock(mutex);
132           while (state != State::EXIT) {
133             cv.wait(lock);
134           }
135         }
136     });
137
138     // Wait for main thread to start (and set w.get()->val_)
139     {
140       std::unique_lock<std::mutex> lock(mutex);
141       while (state != State::DONE) {
142         cv.wait(lock);
143       }
144     }
145
146     // Thread started but hasn't exited yet
147     EXPECT_EQ(0, Widget::totalVal_);
148
149     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
150   }
151
152   EXPECT_EQ(1010, Widget::totalVal_);
153
154   // Allow thread to exit
155   {
156     std::unique_lock<std::mutex> lock(mutex);
157     state = State::EXIT;
158     cv.notify_all();
159   }
160   t.join();
161
162   EXPECT_EQ(1010, Widget::totalVal_);
163 }
164
165 TEST(ThreadLocal, BasicDestructor) {
166   Widget::totalVal_ = 0;
167   ThreadLocal<Widget> w;
168   std::thread([&w]() { w->val_ += 10; }).join();
169   EXPECT_EQ(10, Widget::totalVal_);
170 }
171
172 TEST(ThreadLocal, SimpleRepeatDestructor) {
173   Widget::totalVal_ = 0;
174   {
175     ThreadLocal<Widget> w;
176     w->val_ += 10;
177   }
178   {
179     ThreadLocal<Widget> w;
180     w->val_ += 10;
181   }
182   EXPECT_EQ(20, Widget::totalVal_);
183 }
184
185 TEST(ThreadLocal, InterleavedDestructors) {
186   Widget::totalVal_ = 0;
187   std::unique_ptr<ThreadLocal<Widget>> w;
188   int wVersion = 0;
189   const int wVersionMax = 2;
190   int thIter = 0;
191   std::mutex lock;
192   auto th = std::thread([&]() {
193     int wVersionPrev = 0;
194     while (true) {
195       while (true) {
196         std::lock_guard<std::mutex> g(lock);
197         if (wVersion > wVersionMax) {
198           return;
199         }
200         if (wVersion > wVersionPrev) {
201           // We have a new version of w, so it should be initialized to zero
202           EXPECT_EQ((*w)->val_, 0);
203           break;
204         }
205       }
206       std::lock_guard<std::mutex> g(lock);
207       wVersionPrev = wVersion;
208       (*w)->val_ += 10;
209       ++thIter;
210     }
211   });
212   FOR_EACH_RANGE(i, 0, wVersionMax) {
213     int thIterPrev = 0;
214     {
215       std::lock_guard<std::mutex> g(lock);
216       thIterPrev = thIter;
217       w.reset(new ThreadLocal<Widget>());
218       ++wVersion;
219     }
220     while (true) {
221       std::lock_guard<std::mutex> g(lock);
222       if (thIter > thIterPrev) {
223         break;
224       }
225     }
226   }
227   {
228     std::lock_guard<std::mutex> g(lock);
229     wVersion = wVersionMax + 1;
230   }
231   th.join();
232   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
233 }
234
235 class SimpleThreadCachedInt {
236
237   class NewTag;
238   ThreadLocal<int,NewTag> val_;
239
240  public:
241   void add(int val) {
242     *val_ += val;
243   }
244
245   int read() {
246     int ret = 0;
247     for (const auto& i : val_.accessAllThreads()) {
248       ret += i;
249     }
250     return ret;
251   }
252 };
253
254 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
255   const int kNumThreads = 10;
256   SimpleThreadCachedInt stci;
257   std::atomic<bool> run(true);
258   std::atomic<int> totalAtomic(0);
259   std::vector<std::thread> threads;
260   for (int i = 0; i < kNumThreads; ++i) {
261     threads.push_back(std::thread([&,i]() {
262       stci.add(1);
263       totalAtomic.fetch_add(1);
264       while (run.load()) { usleep(100); }
265     }));
266   }
267   while (totalAtomic.load() != kNumThreads) { usleep(100); }
268   EXPECT_EQ(kNumThreads, stci.read());
269   run.store(false);
270   for (auto& t : threads) {
271     t.join();
272   }
273 }
274
275 TEST(ThreadLocal, resetNull) {
276   ThreadLocal<int> tl;
277   tl.reset(new int(4));
278   EXPECT_EQ(*tl.get(), 4);
279   tl.reset();
280   EXPECT_EQ(*tl.get(), 0);
281   tl.reset(new int(5));
282   EXPECT_EQ(*tl.get(), 5);
283 }
284
285 namespace {
286 struct Tag {};
287
288 struct Foo {
289   folly::ThreadLocal<int, Tag> tl;
290 };
291 }  // namespace
292
293 TEST(ThreadLocal, Movable1) {
294   Foo a;
295   Foo b;
296   EXPECT_TRUE(a.tl.get() != b.tl.get());
297
298   a = Foo();
299   b = Foo();
300   EXPECT_TRUE(a.tl.get() != b.tl.get());
301 }
302
303 TEST(ThreadLocal, Movable2) {
304   std::map<int, Foo> map;
305
306   map[42];
307   map[10];
308   map[23];
309   map[100];
310
311   std::set<void*> tls;
312   for (auto& m : map) {
313     tls.insert(m.second.tl.get());
314   }
315
316   // Make sure that we have 4 different instances of *tl
317   EXPECT_EQ(4, tls.size());
318 }
319
320 namespace {
321
322 constexpr size_t kFillObjectSize = 300;
323
324 std::atomic<uint64_t> gDestroyed;
325
326 /**
327  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
328  * (so deleting one of these from another thread would cause a failure)
329  *
330  * Verify it explicitly and on destruction.
331  */
332 class FillObject {
333  public:
334   explicit FillObject(uint64_t idx) : idx_(idx) {
335     uint64_t v = val();
336     for (size_t i = 0; i < kFillObjectSize; ++i) {
337       data_[i] = v;
338     }
339   }
340
341   void check() {
342     uint64_t v = val();
343     for (size_t i = 0; i < kFillObjectSize; ++i) {
344       CHECK_EQ(v, data_[i]);
345     }
346   }
347
348   ~FillObject() {
349     ++gDestroyed;
350   }
351
352  private:
353   uint64_t val() const {
354     return (idx_ << 40) | uint64_t(pthread_self());
355   }
356
357   uint64_t idx_;
358   uint64_t data_[kFillObjectSize];
359 };
360
361 }  // namespace
362
363 #if FOLLY_HAVE_STD__THIS_THREAD__SLEEP_FOR
364 TEST(ThreadLocal, Stress) {
365   constexpr size_t numFillObjects = 250;
366   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
367
368   constexpr size_t numThreads = 32;
369   constexpr size_t numReps = 20;
370
371   std::vector<std::thread> threads;
372   threads.reserve(numThreads);
373
374   for (size_t i = 0; i < numThreads; ++i) {
375     threads.emplace_back([&objects] {
376       for (size_t rep = 0; rep < numReps; ++rep) {
377         for (size_t i = 0; i < objects.size(); ++i) {
378           objects[i].reset(new FillObject(rep * objects.size() + i));
379           std::this_thread::sleep_for(std::chrono::microseconds(100));
380         }
381         for (size_t i = 0; i < objects.size(); ++i) {
382           objects[i]->check();
383         }
384       }
385     });
386   }
387
388   for (auto& t : threads) {
389     t.join();
390   }
391
392   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
393 }
394 #endif
395
396 // Yes, threads and fork don't mix
397 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
398 // stupid or desperate enough to try, we shouldn't stand in your way.
399 namespace {
400 class HoldsOne {
401  public:
402   HoldsOne() : value_(1) { }
403   // Do an actual access to catch the buggy case where this == nullptr
404   int value() const { return value_; }
405  private:
406   int value_;
407 };
408
409 struct HoldsOneTag {};
410
411 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
412
413 int totalValue() {
414   int value = 0;
415   for (auto& p : ptr.accessAllThreads()) {
416     value += p.value();
417   }
418   return value;
419 }
420
421 }  // namespace
422
423 TEST(ThreadLocal, Fork) {
424   EXPECT_EQ(1, ptr->value());  // ensure created
425   EXPECT_EQ(1, totalValue());
426   // Spawn a new thread
427
428   std::mutex mutex;
429   bool started = false;
430   std::condition_variable startedCond;
431   bool stopped = false;
432   std::condition_variable stoppedCond;
433
434   std::thread t([&] () {
435     EXPECT_EQ(1, ptr->value());  // ensure created
436     {
437       std::unique_lock<std::mutex> lock(mutex);
438       started = true;
439       startedCond.notify_all();
440     }
441     {
442       std::unique_lock<std::mutex> lock(mutex);
443       while (!stopped) {
444         stoppedCond.wait(lock);
445       }
446     }
447   });
448
449   {
450     std::unique_lock<std::mutex> lock(mutex);
451     while (!started) {
452       startedCond.wait(lock);
453     }
454   }
455
456   EXPECT_EQ(2, totalValue());
457
458   pid_t pid = fork();
459   if (pid == 0) {
460     // in child
461     int v = totalValue();
462
463     // exit successfully if v == 1 (one thread)
464     // diagnostic error code otherwise :)
465     switch (v) {
466     case 1: _exit(0);
467     case 0: _exit(1);
468     }
469     _exit(2);
470   } else if (pid > 0) {
471     // in parent
472     int status;
473     EXPECT_EQ(pid, waitpid(pid, &status, 0));
474     EXPECT_TRUE(WIFEXITED(status));
475     EXPECT_EQ(0, WEXITSTATUS(status));
476   } else {
477     EXPECT_TRUE(false) << "fork failed";
478   }
479
480   EXPECT_EQ(2, totalValue());
481
482   {
483     std::unique_lock<std::mutex> lock(mutex);
484     stopped = true;
485     stoppedCond.notify_all();
486   }
487
488   t.join();
489
490   EXPECT_EQ(1, totalValue());
491 }
492
493 struct HoldsOneTag2 {};
494
495 TEST(ThreadLocal, Fork2) {
496   // A thread-local tag that was used in the parent from a *different* thread
497   // (but not the forking thread) would cause the child to hang in a
498   // ThreadLocalPtr's object destructor. Yeah.
499   ThreadLocal<HoldsOne, HoldsOneTag2> p;
500   {
501     // use tag in different thread
502     std::thread t([&p] { p.get(); });
503     t.join();
504   }
505   pid_t pid = fork();
506   if (pid == 0) {
507     {
508       ThreadLocal<HoldsOne, HoldsOneTag2> q;
509       q.get();
510     }
511     _exit(0);
512   } else if (pid > 0) {
513     int status;
514     EXPECT_EQ(pid, waitpid(pid, &status, 0));
515     EXPECT_TRUE(WIFEXITED(status));
516     EXPECT_EQ(0, WEXITSTATUS(status));
517   } else {
518     EXPECT_TRUE(false) << "fork failed";
519   }
520 }
521
522 // Simple reference implementation using pthread_get_specific
523 template<typename T>
524 class PThreadGetSpecific {
525  public:
526   PThreadGetSpecific() : key_(0) {
527     pthread_key_create(&key_, OnThreadExit);
528   }
529
530   T* get() const {
531     return static_cast<T*>(pthread_getspecific(key_));
532   }
533
534   void reset(T* t) {
535     delete get();
536     pthread_setspecific(key_, t);
537   }
538   static void OnThreadExit(void* obj) {
539     delete static_cast<T*>(obj);
540   }
541  private:
542   pthread_key_t key_;
543 };
544
545 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
546
547 #define REG(var)                                                \
548   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
549     const int itersPerThread = iters / FLAGS_numThreads;        \
550     std::vector<std::thread> threads;                           \
551     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
552       threads.push_back(std::thread([&]() {                     \
553         var.reset(new int(0));                                  \
554         for (int i = 0; i < itersPerThread; ++i) {              \
555           ++(*var.get());                                       \
556         }                                                       \
557       }));                                                      \
558     }                                                           \
559     for (auto& t : threads) {                                   \
560       t.join();                                                 \
561     }                                                           \
562   }
563
564 ThreadLocalPtr<int> tlp;
565 REG(tlp);
566 PThreadGetSpecific<int> pthread_get_specific;
567 REG(pthread_get_specific);
568 boost::thread_specific_ptr<int> boost_tsp;
569 REG(boost_tsp);
570 BENCHMARK_DRAW_LINE();
571
572 int main(int argc, char** argv) {
573   testing::InitGoogleTest(&argc, argv);
574   gflags::ParseCommandLineFlags(&argc, &argv, true);
575   gflags::SetCommandLineOptionWithMode(
576     "bm_max_iters", "100000000", gflags::SET_FLAG_IF_DEFAULT
577   );
578   if (FLAGS_benchmark) {
579     folly::runBenchmarks();
580   }
581   return RUN_ALL_TESTS();
582 }
583
584 /*
585 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
586
587 Benchmark                               Iters   Total t    t/iter iter/sec
588 ------------------------------------------------------------------------------
589 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
590  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
591  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
592 ------------------------------------------------------------------------------
593 */