folly copyright 2015 -> copyright 2016
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/ThreadLocal.h>
18
19 #include <dlfcn.h>
20 #include <sys/types.h>
21 #include <sys/wait.h>
22 #include <unistd.h>
23
24 #include <array>
25 #include <atomic>
26 #include <chrono>
27 #include <condition_variable>
28 #include <limits.h>
29 #include <map>
30 #include <mutex>
31 #include <set>
32 #include <thread>
33 #include <unordered_map>
34
35 #include <boost/thread/tss.hpp>
36 #include <gflags/gflags.h>
37 #include <glog/logging.h>
38 #include <gtest/gtest.h>
39
40 #include <folly/Benchmark.h>
41 #include <folly/Baton.h>
42 #include <folly/experimental/io/FsUtil.h>
43
44 using namespace folly;
45
46 struct Widget {
47   static int totalVal_;
48   int val_;
49   ~Widget() {
50     totalVal_ += val_;
51   }
52
53   static void customDeleter(Widget* w, TLPDestructionMode mode) {
54     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
55     delete w;
56   }
57 };
58 int Widget::totalVal_ = 0;
59
60 TEST(ThreadLocalPtr, BasicDestructor) {
61   Widget::totalVal_ = 0;
62   ThreadLocalPtr<Widget> w;
63   std::thread([&w]() {
64       w.reset(new Widget());
65       w.get()->val_ += 10;
66     }).join();
67   EXPECT_EQ(10, Widget::totalVal_);
68 }
69
70 TEST(ThreadLocalPtr, CustomDeleter1) {
71   Widget::totalVal_ = 0;
72   {
73     ThreadLocalPtr<Widget> w;
74     std::thread([&w]() {
75         w.reset(new Widget(), Widget::customDeleter);
76         w.get()->val_ += 10;
77       }).join();
78     EXPECT_EQ(10, Widget::totalVal_);
79   }
80   EXPECT_EQ(10, Widget::totalVal_);
81 }
82
83 TEST(ThreadLocalPtr, resetNull) {
84   ThreadLocalPtr<int> tl;
85   EXPECT_FALSE(tl);
86   tl.reset(new int(4));
87   EXPECT_TRUE(static_cast<bool>(tl));
88   EXPECT_EQ(*tl.get(), 4);
89   tl.reset();
90   EXPECT_FALSE(tl);
91 }
92
93 TEST(ThreadLocalPtr, TestRelease) {
94   Widget::totalVal_ = 0;
95   ThreadLocalPtr<Widget> w;
96   std::unique_ptr<Widget> wPtr;
97   std::thread([&w, &wPtr]() {
98       w.reset(new Widget());
99       w.get()->val_ += 10;
100
101       wPtr.reset(w.release());
102     }).join();
103   EXPECT_EQ(0, Widget::totalVal_);
104   wPtr.reset();
105   EXPECT_EQ(10, Widget::totalVal_);
106 }
107
108 TEST(ThreadLocalPtr, CreateOnThreadExit) {
109   Widget::totalVal_ = 0;
110   ThreadLocal<Widget> w;
111   ThreadLocalPtr<int> tl;
112
113   std::thread([&] {
114     tl.reset(new int(1),
115              [&](int* ptr, TLPDestructionMode /* mode */) {
116                delete ptr;
117                // This test ensures Widgets allocated here are not leaked.
118                ++w.get()->val_;
119                ThreadLocal<Widget> wl;
120                ++wl.get()->val_;
121              });
122   }).join();
123   EXPECT_EQ(2, Widget::totalVal_);
124 }
125
126 // Test deleting the ThreadLocalPtr object
127 TEST(ThreadLocalPtr, CustomDeleter2) {
128   Widget::totalVal_ = 0;
129   std::thread t;
130   std::mutex mutex;
131   std::condition_variable cv;
132   enum class State {
133     START,
134     DONE,
135     EXIT
136   };
137   State state = State::START;
138   {
139     ThreadLocalPtr<Widget> w;
140     t = std::thread([&]() {
141         w.reset(new Widget(), Widget::customDeleter);
142         w.get()->val_ += 10;
143
144         // Notify main thread that we're done
145         {
146           std::unique_lock<std::mutex> lock(mutex);
147           state = State::DONE;
148           cv.notify_all();
149         }
150
151         // Wait for main thread to allow us to exit
152         {
153           std::unique_lock<std::mutex> lock(mutex);
154           while (state != State::EXIT) {
155             cv.wait(lock);
156           }
157         }
158     });
159
160     // Wait for main thread to start (and set w.get()->val_)
161     {
162       std::unique_lock<std::mutex> lock(mutex);
163       while (state != State::DONE) {
164         cv.wait(lock);
165       }
166     }
167
168     // Thread started but hasn't exited yet
169     EXPECT_EQ(0, Widget::totalVal_);
170
171     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
172   }
173
174   EXPECT_EQ(1010, Widget::totalVal_);
175
176   // Allow thread to exit
177   {
178     std::unique_lock<std::mutex> lock(mutex);
179     state = State::EXIT;
180     cv.notify_all();
181   }
182   t.join();
183
184   EXPECT_EQ(1010, Widget::totalVal_);
185 }
186
187 TEST(ThreadLocal, BasicDestructor) {
188   Widget::totalVal_ = 0;
189   ThreadLocal<Widget> w;
190   std::thread([&w]() { w->val_ += 10; }).join();
191   EXPECT_EQ(10, Widget::totalVal_);
192 }
193
194 TEST(ThreadLocal, SimpleRepeatDestructor) {
195   Widget::totalVal_ = 0;
196   {
197     ThreadLocal<Widget> w;
198     w->val_ += 10;
199   }
200   {
201     ThreadLocal<Widget> w;
202     w->val_ += 10;
203   }
204   EXPECT_EQ(20, Widget::totalVal_);
205 }
206
207 TEST(ThreadLocal, InterleavedDestructors) {
208   Widget::totalVal_ = 0;
209   std::unique_ptr<ThreadLocal<Widget>> w;
210   int wVersion = 0;
211   const int wVersionMax = 2;
212   int thIter = 0;
213   std::mutex lock;
214   auto th = std::thread([&]() {
215     int wVersionPrev = 0;
216     while (true) {
217       while (true) {
218         std::lock_guard<std::mutex> g(lock);
219         if (wVersion > wVersionMax) {
220           return;
221         }
222         if (wVersion > wVersionPrev) {
223           // We have a new version of w, so it should be initialized to zero
224           EXPECT_EQ((*w)->val_, 0);
225           break;
226         }
227       }
228       std::lock_guard<std::mutex> g(lock);
229       wVersionPrev = wVersion;
230       (*w)->val_ += 10;
231       ++thIter;
232     }
233   });
234   FOR_EACH_RANGE(i, 0, wVersionMax) {
235     int thIterPrev = 0;
236     {
237       std::lock_guard<std::mutex> g(lock);
238       thIterPrev = thIter;
239       w.reset(new ThreadLocal<Widget>());
240       ++wVersion;
241     }
242     while (true) {
243       std::lock_guard<std::mutex> g(lock);
244       if (thIter > thIterPrev) {
245         break;
246       }
247     }
248   }
249   {
250     std::lock_guard<std::mutex> g(lock);
251     wVersion = wVersionMax + 1;
252   }
253   th.join();
254   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
255 }
256
257 class SimpleThreadCachedInt {
258
259   class NewTag;
260   ThreadLocal<int,NewTag> val_;
261
262  public:
263   void add(int val) {
264     *val_ += val;
265   }
266
267   int read() {
268     int ret = 0;
269     for (const auto& i : val_.accessAllThreads()) {
270       ret += i;
271     }
272     return ret;
273   }
274 };
275
276 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
277   const int kNumThreads = 10;
278   SimpleThreadCachedInt stci;
279   std::atomic<bool> run(true);
280   std::atomic<int> totalAtomic(0);
281   std::vector<std::thread> threads;
282   for (int i = 0; i < kNumThreads; ++i) {
283     threads.push_back(std::thread([&,i]() {
284       stci.add(1);
285       totalAtomic.fetch_add(1);
286       while (run.load()) { usleep(100); }
287     }));
288   }
289   while (totalAtomic.load() != kNumThreads) { usleep(100); }
290   EXPECT_EQ(kNumThreads, stci.read());
291   run.store(false);
292   for (auto& t : threads) {
293     t.join();
294   }
295 }
296
297 TEST(ThreadLocal, resetNull) {
298   ThreadLocal<int> tl;
299   tl.reset(new int(4));
300   EXPECT_EQ(*tl.get(), 4);
301   tl.reset();
302   EXPECT_EQ(*tl.get(), 0);
303   tl.reset(new int(5));
304   EXPECT_EQ(*tl.get(), 5);
305 }
306
307 namespace {
308 struct Tag {};
309
310 struct Foo {
311   folly::ThreadLocal<int, Tag> tl;
312 };
313 }  // namespace
314
315 TEST(ThreadLocal, Movable1) {
316   Foo a;
317   Foo b;
318   EXPECT_TRUE(a.tl.get() != b.tl.get());
319
320   a = Foo();
321   b = Foo();
322   EXPECT_TRUE(a.tl.get() != b.tl.get());
323 }
324
325 TEST(ThreadLocal, Movable2) {
326   std::map<int, Foo> map;
327
328   map[42];
329   map[10];
330   map[23];
331   map[100];
332
333   std::set<void*> tls;
334   for (auto& m : map) {
335     tls.insert(m.second.tl.get());
336   }
337
338   // Make sure that we have 4 different instances of *tl
339   EXPECT_EQ(4, tls.size());
340 }
341
342 namespace {
343
344 constexpr size_t kFillObjectSize = 300;
345
346 std::atomic<uint64_t> gDestroyed;
347
348 /**
349  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
350  * (so deleting one of these from another thread would cause a failure)
351  *
352  * Verify it explicitly and on destruction.
353  */
354 class FillObject {
355  public:
356   explicit FillObject(uint64_t idx) : idx_(idx) {
357     uint64_t v = val();
358     for (size_t i = 0; i < kFillObjectSize; ++i) {
359       data_[i] = v;
360     }
361   }
362
363   void check() {
364     uint64_t v = val();
365     for (size_t i = 0; i < kFillObjectSize; ++i) {
366       CHECK_EQ(v, data_[i]);
367     }
368   }
369
370   ~FillObject() {
371     ++gDestroyed;
372   }
373
374  private:
375   uint64_t val() const {
376     return (idx_ << 40) | uint64_t(pthread_self());
377   }
378
379   uint64_t idx_;
380   uint64_t data_[kFillObjectSize];
381 };
382
383 }  // namespace
384
385 #if FOLLY_HAVE_STD_THIS_THREAD_SLEEP_FOR
386 TEST(ThreadLocal, Stress) {
387   constexpr size_t numFillObjects = 250;
388   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
389
390   constexpr size_t numThreads = 32;
391   constexpr size_t numReps = 20;
392
393   std::vector<std::thread> threads;
394   threads.reserve(numThreads);
395
396   for (size_t i = 0; i < numThreads; ++i) {
397     threads.emplace_back([&objects] {
398       for (size_t rep = 0; rep < numReps; ++rep) {
399         for (size_t i = 0; i < objects.size(); ++i) {
400           objects[i].reset(new FillObject(rep * objects.size() + i));
401           std::this_thread::sleep_for(std::chrono::microseconds(100));
402         }
403         for (size_t i = 0; i < objects.size(); ++i) {
404           objects[i]->check();
405         }
406       }
407     });
408   }
409
410   for (auto& t : threads) {
411     t.join();
412   }
413
414   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
415 }
416 #endif
417
418 // Yes, threads and fork don't mix
419 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
420 // stupid or desperate enough to try, we shouldn't stand in your way.
421 namespace {
422 class HoldsOne {
423  public:
424   HoldsOne() : value_(1) { }
425   // Do an actual access to catch the buggy case where this == nullptr
426   int value() const { return value_; }
427  private:
428   int value_;
429 };
430
431 struct HoldsOneTag {};
432
433 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
434
435 int totalValue() {
436   int value = 0;
437   for (auto& p : ptr.accessAllThreads()) {
438     value += p.value();
439   }
440   return value;
441 }
442
443 }  // namespace
444
445 #ifdef FOLLY_HAVE_PTHREAD_ATFORK
446 TEST(ThreadLocal, Fork) {
447   EXPECT_EQ(1, ptr->value());  // ensure created
448   EXPECT_EQ(1, totalValue());
449   // Spawn a new thread
450
451   std::mutex mutex;
452   bool started = false;
453   std::condition_variable startedCond;
454   bool stopped = false;
455   std::condition_variable stoppedCond;
456
457   std::thread t([&] () {
458     EXPECT_EQ(1, ptr->value());  // ensure created
459     {
460       std::unique_lock<std::mutex> lock(mutex);
461       started = true;
462       startedCond.notify_all();
463     }
464     {
465       std::unique_lock<std::mutex> lock(mutex);
466       while (!stopped) {
467         stoppedCond.wait(lock);
468       }
469     }
470   });
471
472   {
473     std::unique_lock<std::mutex> lock(mutex);
474     while (!started) {
475       startedCond.wait(lock);
476     }
477   }
478
479   EXPECT_EQ(2, totalValue());
480
481   pid_t pid = fork();
482   if (pid == 0) {
483     // in child
484     int v = totalValue();
485
486     // exit successfully if v == 1 (one thread)
487     // diagnostic error code otherwise :)
488     switch (v) {
489     case 1: _exit(0);
490     case 0: _exit(1);
491     }
492     _exit(2);
493   } else if (pid > 0) {
494     // in parent
495     int status;
496     EXPECT_EQ(pid, waitpid(pid, &status, 0));
497     EXPECT_TRUE(WIFEXITED(status));
498     EXPECT_EQ(0, WEXITSTATUS(status));
499   } else {
500     EXPECT_TRUE(false) << "fork failed";
501   }
502
503   EXPECT_EQ(2, totalValue());
504
505   {
506     std::unique_lock<std::mutex> lock(mutex);
507     stopped = true;
508     stoppedCond.notify_all();
509   }
510
511   t.join();
512
513   EXPECT_EQ(1, totalValue());
514 }
515 #endif
516
517 struct HoldsOneTag2 {};
518
519 TEST(ThreadLocal, Fork2) {
520   // A thread-local tag that was used in the parent from a *different* thread
521   // (but not the forking thread) would cause the child to hang in a
522   // ThreadLocalPtr's object destructor. Yeah.
523   ThreadLocal<HoldsOne, HoldsOneTag2> p;
524   {
525     // use tag in different thread
526     std::thread t([&p] { p.get(); });
527     t.join();
528   }
529   pid_t pid = fork();
530   if (pid == 0) {
531     {
532       ThreadLocal<HoldsOne, HoldsOneTag2> q;
533       q.get();
534     }
535     _exit(0);
536   } else if (pid > 0) {
537     int status;
538     EXPECT_EQ(pid, waitpid(pid, &status, 0));
539     EXPECT_TRUE(WIFEXITED(status));
540     EXPECT_EQ(0, WEXITSTATUS(status));
541   } else {
542     EXPECT_TRUE(false) << "fork failed";
543   }
544 }
545
546 TEST(ThreadLocal, SharedLibrary) {
547   auto exe = fs::executable_path();
548   auto lib = exe.parent_path() / "lib_thread_local_test.so";
549   auto handle = dlopen(lib.string().c_str(), RTLD_LAZY);
550   EXPECT_NE(nullptr, handle);
551
552   typedef void (*useA_t)();
553   dlerror();
554   useA_t useA = (useA_t) dlsym(handle, "useA");
555
556   const char *dlsym_error = dlerror();
557   EXPECT_EQ(nullptr, dlsym_error);
558
559   useA();
560
561   folly::Baton<> b11, b12, b21, b22;
562
563   std::thread t1([&]() {
564       useA();
565       b11.post();
566       b12.wait();
567     });
568
569   std::thread t2([&]() {
570       useA();
571       b21.post();
572       b22.wait();
573     });
574
575   b11.wait();
576   b21.wait();
577
578   dlclose(handle);
579
580   b12.post();
581   b22.post();
582
583   t1.join();
584   t2.join();
585 }
586
587 namespace folly { namespace threadlocal_detail {
588 struct PthreadKeyUnregisterTester {
589   PthreadKeyUnregister p;
590   constexpr PthreadKeyUnregisterTester() = default;
591 };
592 }}
593
594 TEST(ThreadLocal, UnregisterClassHasConstExprCtor) {
595   folly::threadlocal_detail::PthreadKeyUnregisterTester x;
596   // yep!
597   SUCCEED();
598 }
599
600 // clang is unable to compile this code unless in c++14 mode.
601 #if __cplusplus >= 201402L
602 namespace {
603 // This will fail to compile unless ThreadLocal{Ptr} has a constexpr
604 // default constructor. This ensures that ThreadLocal is safe to use in
605 // static constructors without worrying about initialization order
606 class ConstexprThreadLocalCompile {
607   ThreadLocal<int> a_;
608   ThreadLocalPtr<int> b_;
609
610   constexpr ConstexprThreadLocalCompile() {}
611 };
612 }
613 #endif
614
615 // Simple reference implementation using pthread_get_specific
616 template<typename T>
617 class PThreadGetSpecific {
618  public:
619   PThreadGetSpecific() : key_(0) {
620     pthread_key_create(&key_, OnThreadExit);
621   }
622
623   T* get() const {
624     return static_cast<T*>(pthread_getspecific(key_));
625   }
626
627   void reset(T* t) {
628     delete get();
629     pthread_setspecific(key_, t);
630   }
631   static void OnThreadExit(void* obj) {
632     delete static_cast<T*>(obj);
633   }
634  private:
635   pthread_key_t key_;
636 };
637
638 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
639
640 #define REG(var)                                                \
641   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
642     const int itersPerThread = iters / FLAGS_numThreads;        \
643     std::vector<std::thread> threads;                           \
644     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
645       threads.push_back(std::thread([&]() {                     \
646         var.reset(new int(0));                                  \
647         for (int i = 0; i < itersPerThread; ++i) {              \
648           ++(*var.get());                                       \
649         }                                                       \
650       }));                                                      \
651     }                                                           \
652     for (auto& t : threads) {                                   \
653       t.join();                                                 \
654     }                                                           \
655   }
656
657 ThreadLocalPtr<int> tlp;
658 REG(tlp);
659 PThreadGetSpecific<int> pthread_get_specific;
660 REG(pthread_get_specific);
661 boost::thread_specific_ptr<int> boost_tsp;
662 REG(boost_tsp);
663 BENCHMARK_DRAW_LINE();
664
665 int main(int argc, char** argv) {
666   testing::InitGoogleTest(&argc, argv);
667   gflags::ParseCommandLineFlags(&argc, &argv, true);
668   gflags::SetCommandLineOptionWithMode(
669     "bm_max_iters", "100000000", gflags::SET_FLAG_IF_DEFAULT
670   );
671   if (FLAGS_benchmark) {
672     folly::runBenchmarks();
673   }
674   return RUN_ALL_TESTS();
675 }
676
677 /*
678 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
679
680 Benchmark                               Iters   Total t    t/iter iter/sec
681 ------------------------------------------------------------------------------
682 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
683  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
684  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
685 ------------------------------------------------------------------------------
686 */