Fix folly/ThreadLocal with clang after PthreadKeyUnregister change
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/ThreadLocal.h>
18
19 #include <dlfcn.h>
20 #include <sys/types.h>
21 #include <sys/wait.h>
22 #include <unistd.h>
23
24 #include <array>
25 #include <atomic>
26 #include <chrono>
27 #include <condition_variable>
28 #include <limits.h>
29 #include <map>
30 #include <mutex>
31 #include <set>
32 #include <thread>
33 #include <unordered_map>
34
35 #include <boost/thread/tss.hpp>
36 #include <gflags/gflags.h>
37 #include <glog/logging.h>
38 #include <gtest/gtest.h>
39
40 #include <folly/Benchmark.h>
41 #include <folly/Baton.h>
42
43 using namespace folly;
44
45 struct Widget {
46   static int totalVal_;
47   int val_;
48   ~Widget() {
49     totalVal_ += val_;
50   }
51
52   static void customDeleter(Widget* w, TLPDestructionMode mode) {
53     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
54     delete w;
55   }
56 };
57 int Widget::totalVal_ = 0;
58
59 TEST(ThreadLocalPtr, BasicDestructor) {
60   Widget::totalVal_ = 0;
61   ThreadLocalPtr<Widget> w;
62   std::thread([&w]() {
63       w.reset(new Widget());
64       w.get()->val_ += 10;
65     }).join();
66   EXPECT_EQ(10, Widget::totalVal_);
67 }
68
69 TEST(ThreadLocalPtr, CustomDeleter1) {
70   Widget::totalVal_ = 0;
71   {
72     ThreadLocalPtr<Widget> w;
73     std::thread([&w]() {
74         w.reset(new Widget(), Widget::customDeleter);
75         w.get()->val_ += 10;
76       }).join();
77     EXPECT_EQ(10, Widget::totalVal_);
78   }
79   EXPECT_EQ(10, Widget::totalVal_);
80 }
81
82 TEST(ThreadLocalPtr, resetNull) {
83   ThreadLocalPtr<int> tl;
84   EXPECT_FALSE(tl);
85   tl.reset(new int(4));
86   EXPECT_TRUE(static_cast<bool>(tl));
87   EXPECT_EQ(*tl.get(), 4);
88   tl.reset();
89   EXPECT_FALSE(tl);
90 }
91
92 TEST(ThreadLocalPtr, TestRelease) {
93   Widget::totalVal_ = 0;
94   ThreadLocalPtr<Widget> w;
95   std::unique_ptr<Widget> wPtr;
96   std::thread([&w, &wPtr]() {
97       w.reset(new Widget());
98       w.get()->val_ += 10;
99
100       wPtr.reset(w.release());
101     }).join();
102   EXPECT_EQ(0, Widget::totalVal_);
103   wPtr.reset();
104   EXPECT_EQ(10, Widget::totalVal_);
105 }
106
107 TEST(ThreadLocalPtr, CreateOnThreadExit) {
108   Widget::totalVal_ = 0;
109   ThreadLocal<Widget> w;
110   ThreadLocalPtr<int> tl;
111
112   std::thread([&] {
113       tl.reset(new int(1), [&] (int* ptr, TLPDestructionMode mode) {
114         delete ptr;
115         // This test ensures Widgets allocated here are not leaked.
116         ++w.get()->val_;
117         ThreadLocal<Widget> wl;
118         ++wl.get()->val_;
119       });
120     }).join();
121   EXPECT_EQ(2, Widget::totalVal_);
122 }
123
124 // Test deleting the ThreadLocalPtr object
125 TEST(ThreadLocalPtr, CustomDeleter2) {
126   Widget::totalVal_ = 0;
127   std::thread t;
128   std::mutex mutex;
129   std::condition_variable cv;
130   enum class State {
131     START,
132     DONE,
133     EXIT
134   };
135   State state = State::START;
136   {
137     ThreadLocalPtr<Widget> w;
138     t = std::thread([&]() {
139         w.reset(new Widget(), Widget::customDeleter);
140         w.get()->val_ += 10;
141
142         // Notify main thread that we're done
143         {
144           std::unique_lock<std::mutex> lock(mutex);
145           state = State::DONE;
146           cv.notify_all();
147         }
148
149         // Wait for main thread to allow us to exit
150         {
151           std::unique_lock<std::mutex> lock(mutex);
152           while (state != State::EXIT) {
153             cv.wait(lock);
154           }
155         }
156     });
157
158     // Wait for main thread to start (and set w.get()->val_)
159     {
160       std::unique_lock<std::mutex> lock(mutex);
161       while (state != State::DONE) {
162         cv.wait(lock);
163       }
164     }
165
166     // Thread started but hasn't exited yet
167     EXPECT_EQ(0, Widget::totalVal_);
168
169     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
170   }
171
172   EXPECT_EQ(1010, Widget::totalVal_);
173
174   // Allow thread to exit
175   {
176     std::unique_lock<std::mutex> lock(mutex);
177     state = State::EXIT;
178     cv.notify_all();
179   }
180   t.join();
181
182   EXPECT_EQ(1010, Widget::totalVal_);
183 }
184
185 TEST(ThreadLocal, BasicDestructor) {
186   Widget::totalVal_ = 0;
187   ThreadLocal<Widget> w;
188   std::thread([&w]() { w->val_ += 10; }).join();
189   EXPECT_EQ(10, Widget::totalVal_);
190 }
191
192 TEST(ThreadLocal, SimpleRepeatDestructor) {
193   Widget::totalVal_ = 0;
194   {
195     ThreadLocal<Widget> w;
196     w->val_ += 10;
197   }
198   {
199     ThreadLocal<Widget> w;
200     w->val_ += 10;
201   }
202   EXPECT_EQ(20, Widget::totalVal_);
203 }
204
205 TEST(ThreadLocal, InterleavedDestructors) {
206   Widget::totalVal_ = 0;
207   std::unique_ptr<ThreadLocal<Widget>> w;
208   int wVersion = 0;
209   const int wVersionMax = 2;
210   int thIter = 0;
211   std::mutex lock;
212   auto th = std::thread([&]() {
213     int wVersionPrev = 0;
214     while (true) {
215       while (true) {
216         std::lock_guard<std::mutex> g(lock);
217         if (wVersion > wVersionMax) {
218           return;
219         }
220         if (wVersion > wVersionPrev) {
221           // We have a new version of w, so it should be initialized to zero
222           EXPECT_EQ((*w)->val_, 0);
223           break;
224         }
225       }
226       std::lock_guard<std::mutex> g(lock);
227       wVersionPrev = wVersion;
228       (*w)->val_ += 10;
229       ++thIter;
230     }
231   });
232   FOR_EACH_RANGE(i, 0, wVersionMax) {
233     int thIterPrev = 0;
234     {
235       std::lock_guard<std::mutex> g(lock);
236       thIterPrev = thIter;
237       w.reset(new ThreadLocal<Widget>());
238       ++wVersion;
239     }
240     while (true) {
241       std::lock_guard<std::mutex> g(lock);
242       if (thIter > thIterPrev) {
243         break;
244       }
245     }
246   }
247   {
248     std::lock_guard<std::mutex> g(lock);
249     wVersion = wVersionMax + 1;
250   }
251   th.join();
252   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
253 }
254
255 class SimpleThreadCachedInt {
256
257   class NewTag;
258   ThreadLocal<int,NewTag> val_;
259
260  public:
261   void add(int val) {
262     *val_ += val;
263   }
264
265   int read() {
266     int ret = 0;
267     for (const auto& i : val_.accessAllThreads()) {
268       ret += i;
269     }
270     return ret;
271   }
272 };
273
274 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
275   const int kNumThreads = 10;
276   SimpleThreadCachedInt stci;
277   std::atomic<bool> run(true);
278   std::atomic<int> totalAtomic(0);
279   std::vector<std::thread> threads;
280   for (int i = 0; i < kNumThreads; ++i) {
281     threads.push_back(std::thread([&,i]() {
282       stci.add(1);
283       totalAtomic.fetch_add(1);
284       while (run.load()) { usleep(100); }
285     }));
286   }
287   while (totalAtomic.load() != kNumThreads) { usleep(100); }
288   EXPECT_EQ(kNumThreads, stci.read());
289   run.store(false);
290   for (auto& t : threads) {
291     t.join();
292   }
293 }
294
295 TEST(ThreadLocal, resetNull) {
296   ThreadLocal<int> tl;
297   tl.reset(new int(4));
298   EXPECT_EQ(*tl.get(), 4);
299   tl.reset();
300   EXPECT_EQ(*tl.get(), 0);
301   tl.reset(new int(5));
302   EXPECT_EQ(*tl.get(), 5);
303 }
304
305 namespace {
306 struct Tag {};
307
308 struct Foo {
309   folly::ThreadLocal<int, Tag> tl;
310 };
311 }  // namespace
312
313 TEST(ThreadLocal, Movable1) {
314   Foo a;
315   Foo b;
316   EXPECT_TRUE(a.tl.get() != b.tl.get());
317
318   a = Foo();
319   b = Foo();
320   EXPECT_TRUE(a.tl.get() != b.tl.get());
321 }
322
323 TEST(ThreadLocal, Movable2) {
324   std::map<int, Foo> map;
325
326   map[42];
327   map[10];
328   map[23];
329   map[100];
330
331   std::set<void*> tls;
332   for (auto& m : map) {
333     tls.insert(m.second.tl.get());
334   }
335
336   // Make sure that we have 4 different instances of *tl
337   EXPECT_EQ(4, tls.size());
338 }
339
340 namespace {
341
342 constexpr size_t kFillObjectSize = 300;
343
344 std::atomic<uint64_t> gDestroyed;
345
346 /**
347  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
348  * (so deleting one of these from another thread would cause a failure)
349  *
350  * Verify it explicitly and on destruction.
351  */
352 class FillObject {
353  public:
354   explicit FillObject(uint64_t idx) : idx_(idx) {
355     uint64_t v = val();
356     for (size_t i = 0; i < kFillObjectSize; ++i) {
357       data_[i] = v;
358     }
359   }
360
361   void check() {
362     uint64_t v = val();
363     for (size_t i = 0; i < kFillObjectSize; ++i) {
364       CHECK_EQ(v, data_[i]);
365     }
366   }
367
368   ~FillObject() {
369     ++gDestroyed;
370   }
371
372  private:
373   uint64_t val() const {
374     return (idx_ << 40) | uint64_t(pthread_self());
375   }
376
377   uint64_t idx_;
378   uint64_t data_[kFillObjectSize];
379 };
380
381 }  // namespace
382
383 #if FOLLY_HAVE_STD_THIS_THREAD_SLEEP_FOR
384 TEST(ThreadLocal, Stress) {
385   constexpr size_t numFillObjects = 250;
386   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
387
388   constexpr size_t numThreads = 32;
389   constexpr size_t numReps = 20;
390
391   std::vector<std::thread> threads;
392   threads.reserve(numThreads);
393
394   for (size_t i = 0; i < numThreads; ++i) {
395     threads.emplace_back([&objects] {
396       for (size_t rep = 0; rep < numReps; ++rep) {
397         for (size_t i = 0; i < objects.size(); ++i) {
398           objects[i].reset(new FillObject(rep * objects.size() + i));
399           std::this_thread::sleep_for(std::chrono::microseconds(100));
400         }
401         for (size_t i = 0; i < objects.size(); ++i) {
402           objects[i]->check();
403         }
404       }
405     });
406   }
407
408   for (auto& t : threads) {
409     t.join();
410   }
411
412   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
413 }
414 #endif
415
416 // Yes, threads and fork don't mix
417 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
418 // stupid or desperate enough to try, we shouldn't stand in your way.
419 namespace {
420 class HoldsOne {
421  public:
422   HoldsOne() : value_(1) { }
423   // Do an actual access to catch the buggy case where this == nullptr
424   int value() const { return value_; }
425  private:
426   int value_;
427 };
428
429 struct HoldsOneTag {};
430
431 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
432
433 int totalValue() {
434   int value = 0;
435   for (auto& p : ptr.accessAllThreads()) {
436     value += p.value();
437   }
438   return value;
439 }
440
441 }  // namespace
442
443 #ifdef FOLLY_HAVE_PTHREAD_ATFORK
444 TEST(ThreadLocal, Fork) {
445   EXPECT_EQ(1, ptr->value());  // ensure created
446   EXPECT_EQ(1, totalValue());
447   // Spawn a new thread
448
449   std::mutex mutex;
450   bool started = false;
451   std::condition_variable startedCond;
452   bool stopped = false;
453   std::condition_variable stoppedCond;
454
455   std::thread t([&] () {
456     EXPECT_EQ(1, ptr->value());  // ensure created
457     {
458       std::unique_lock<std::mutex> lock(mutex);
459       started = true;
460       startedCond.notify_all();
461     }
462     {
463       std::unique_lock<std::mutex> lock(mutex);
464       while (!stopped) {
465         stoppedCond.wait(lock);
466       }
467     }
468   });
469
470   {
471     std::unique_lock<std::mutex> lock(mutex);
472     while (!started) {
473       startedCond.wait(lock);
474     }
475   }
476
477   EXPECT_EQ(2, totalValue());
478
479   pid_t pid = fork();
480   if (pid == 0) {
481     // in child
482     int v = totalValue();
483
484     // exit successfully if v == 1 (one thread)
485     // diagnostic error code otherwise :)
486     switch (v) {
487     case 1: _exit(0);
488     case 0: _exit(1);
489     }
490     _exit(2);
491   } else if (pid > 0) {
492     // in parent
493     int status;
494     EXPECT_EQ(pid, waitpid(pid, &status, 0));
495     EXPECT_TRUE(WIFEXITED(status));
496     EXPECT_EQ(0, WEXITSTATUS(status));
497   } else {
498     EXPECT_TRUE(false) << "fork failed";
499   }
500
501   EXPECT_EQ(2, totalValue());
502
503   {
504     std::unique_lock<std::mutex> lock(mutex);
505     stopped = true;
506     stoppedCond.notify_all();
507   }
508
509   t.join();
510
511   EXPECT_EQ(1, totalValue());
512 }
513 #endif
514
515 struct HoldsOneTag2 {};
516
517 TEST(ThreadLocal, Fork2) {
518   // A thread-local tag that was used in the parent from a *different* thread
519   // (but not the forking thread) would cause the child to hang in a
520   // ThreadLocalPtr's object destructor. Yeah.
521   ThreadLocal<HoldsOne, HoldsOneTag2> p;
522   {
523     // use tag in different thread
524     std::thread t([&p] { p.get(); });
525     t.join();
526   }
527   pid_t pid = fork();
528   if (pid == 0) {
529     {
530       ThreadLocal<HoldsOne, HoldsOneTag2> q;
531       q.get();
532     }
533     _exit(0);
534   } else if (pid > 0) {
535     int status;
536     EXPECT_EQ(pid, waitpid(pid, &status, 0));
537     EXPECT_TRUE(WIFEXITED(status));
538     EXPECT_EQ(0, WEXITSTATUS(status));
539   } else {
540     EXPECT_TRUE(false) << "fork failed";
541   }
542 }
543
544 TEST(ThreadLocal, SharedLibrary)
545 {
546   auto handle = dlopen("./_bin/folly/test/lib_thread_local_test.so",
547                        RTLD_LAZY);
548   EXPECT_NE(nullptr, handle);
549
550   typedef void (*useA_t)();
551   dlerror();
552   useA_t useA = (useA_t) dlsym(handle, "useA");
553
554   const char *dlsym_error = dlerror();
555   EXPECT_EQ(nullptr, dlsym_error);
556
557   useA();
558
559   folly::Baton<> b11, b12, b21, b22;
560
561   std::thread t1([&]() {
562       useA();
563       b11.post();
564       b12.wait();
565     });
566
567   std::thread t2([&]() {
568       useA();
569       b21.post();
570       b22.wait();
571     });
572
573   b11.wait();
574   b21.wait();
575
576   dlclose(handle);
577
578   b12.post();
579   b22.post();
580
581   t1.join();
582   t2.join();
583 }
584
585 // clang is unable to compile this code unless in c++14 mode.
586 #if __cplusplus >= 201402L
587 namespace {
588 // This will fail to compile unless ThreadLocal{Ptr} has a constexpr
589 // default constructor. This ensures that ThreadLocal is safe to use in
590 // static constructors without worrying about initialization order
591 class ConstexprThreadLocalCompile {
592   ThreadLocal<int> a_;
593   ThreadLocalPtr<int> b_;
594
595   constexpr ConstexprThreadLocalCompile() {}
596 };
597 }
598 #endif
599
600 // Simple reference implementation using pthread_get_specific
601 template<typename T>
602 class PThreadGetSpecific {
603  public:
604   PThreadGetSpecific() : key_(0) {
605     pthread_key_create(&key_, OnThreadExit);
606   }
607
608   T* get() const {
609     return static_cast<T*>(pthread_getspecific(key_));
610   }
611
612   void reset(T* t) {
613     delete get();
614     pthread_setspecific(key_, t);
615   }
616   static void OnThreadExit(void* obj) {
617     delete static_cast<T*>(obj);
618   }
619  private:
620   pthread_key_t key_;
621 };
622
623 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
624
625 #define REG(var)                                                \
626   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
627     const int itersPerThread = iters / FLAGS_numThreads;        \
628     std::vector<std::thread> threads;                           \
629     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
630       threads.push_back(std::thread([&]() {                     \
631         var.reset(new int(0));                                  \
632         for (int i = 0; i < itersPerThread; ++i) {              \
633           ++(*var.get());                                       \
634         }                                                       \
635       }));                                                      \
636     }                                                           \
637     for (auto& t : threads) {                                   \
638       t.join();                                                 \
639     }                                                           \
640   }
641
642 ThreadLocalPtr<int> tlp;
643 REG(tlp);
644 PThreadGetSpecific<int> pthread_get_specific;
645 REG(pthread_get_specific);
646 boost::thread_specific_ptr<int> boost_tsp;
647 REG(boost_tsp);
648 BENCHMARK_DRAW_LINE();
649
650 int main(int argc, char** argv) {
651   testing::InitGoogleTest(&argc, argv);
652   gflags::ParseCommandLineFlags(&argc, &argv, true);
653   gflags::SetCommandLineOptionWithMode(
654     "bm_max_iters", "100000000", gflags::SET_FLAG_IF_DEFAULT
655   );
656   if (FLAGS_benchmark) {
657     folly::runBenchmarks();
658   }
659   return RUN_ALL_TESTS();
660 }
661
662 /*
663 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
664
665 Benchmark                               Iters   Total t    t/iter iter/sec
666 ------------------------------------------------------------------------------
667 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
668  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
669  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
670 ------------------------------------------------------------------------------
671 */