16d7a35cab78bd29cfc108cf2e206c6ac8e2ffce
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/ThreadLocal.h>
18
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 #include <unistd.h>
22
23 #include <array>
24 #include <atomic>
25 #include <chrono>
26 #include <condition_variable>
27 #include <limits.h>
28 #include <map>
29 #include <mutex>
30 #include <set>
31 #include <thread>
32 #include <unordered_map>
33
34 #include <boost/thread/tss.hpp>
35 #include <gflags/gflags.h>
36 #include <glog/logging.h>
37 #include <gtest/gtest.h>
38
39 #include <folly/Benchmark.h>
40
41 using namespace folly;
42
43 struct Widget {
44   static int totalVal_;
45   int val_;
46   ~Widget() {
47     totalVal_ += val_;
48   }
49
50   static void customDeleter(Widget* w, TLPDestructionMode mode) {
51     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
52     delete w;
53   }
54 };
55 int Widget::totalVal_ = 0;
56
57 TEST(ThreadLocalPtr, BasicDestructor) {
58   Widget::totalVal_ = 0;
59   ThreadLocalPtr<Widget> w;
60   std::thread([&w]() {
61       w.reset(new Widget());
62       w.get()->val_ += 10;
63     }).join();
64   EXPECT_EQ(10, Widget::totalVal_);
65 }
66
67 TEST(ThreadLocalPtr, CustomDeleter1) {
68   Widget::totalVal_ = 0;
69   {
70     ThreadLocalPtr<Widget> w;
71     std::thread([&w]() {
72         w.reset(new Widget(), Widget::customDeleter);
73         w.get()->val_ += 10;
74       }).join();
75     EXPECT_EQ(10, Widget::totalVal_);
76   }
77   EXPECT_EQ(10, Widget::totalVal_);
78 }
79
80 TEST(ThreadLocalPtr, resetNull) {
81   ThreadLocalPtr<int> tl;
82   EXPECT_FALSE(tl);
83   tl.reset(new int(4));
84   EXPECT_TRUE(static_cast<bool>(tl));
85   EXPECT_EQ(*tl.get(), 4);
86   tl.reset();
87   EXPECT_FALSE(tl);
88 }
89
90 TEST(ThreadLocalPtr, TestRelease) {
91   Widget::totalVal_ = 0;
92   ThreadLocalPtr<Widget> w;
93   std::unique_ptr<Widget> wPtr;
94   std::thread([&w, &wPtr]() {
95       w.reset(new Widget());
96       w.get()->val_ += 10;
97
98       wPtr.reset(w.release());
99     }).join();
100   EXPECT_EQ(0, Widget::totalVal_);
101   wPtr.reset();
102   EXPECT_EQ(10, Widget::totalVal_);
103 }
104
105 TEST(ThreadLocalPtr, CreateOnThreadExit) {
106   Widget::totalVal_ = 0;
107   ThreadLocal<Widget> w;
108   ThreadLocalPtr<int> tl;
109
110   std::thread([&] {
111       tl.reset(new int(1), [&] (int* ptr, TLPDestructionMode mode) {
112         delete ptr;
113         // This test ensures Widgets allocated here are not leaked.
114         ++w.get()->val_;
115         ThreadLocal<Widget> wl;
116         ++wl.get()->val_;
117       });
118     }).join();
119   EXPECT_EQ(2, Widget::totalVal_);
120 }
121
122 // Test deleting the ThreadLocalPtr object
123 TEST(ThreadLocalPtr, CustomDeleter2) {
124   Widget::totalVal_ = 0;
125   std::thread t;
126   std::mutex mutex;
127   std::condition_variable cv;
128   enum class State {
129     START,
130     DONE,
131     EXIT
132   };
133   State state = State::START;
134   {
135     ThreadLocalPtr<Widget> w;
136     t = std::thread([&]() {
137         w.reset(new Widget(), Widget::customDeleter);
138         w.get()->val_ += 10;
139
140         // Notify main thread that we're done
141         {
142           std::unique_lock<std::mutex> lock(mutex);
143           state = State::DONE;
144           cv.notify_all();
145         }
146
147         // Wait for main thread to allow us to exit
148         {
149           std::unique_lock<std::mutex> lock(mutex);
150           while (state != State::EXIT) {
151             cv.wait(lock);
152           }
153         }
154     });
155
156     // Wait for main thread to start (and set w.get()->val_)
157     {
158       std::unique_lock<std::mutex> lock(mutex);
159       while (state != State::DONE) {
160         cv.wait(lock);
161       }
162     }
163
164     // Thread started but hasn't exited yet
165     EXPECT_EQ(0, Widget::totalVal_);
166
167     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
168   }
169
170   EXPECT_EQ(1010, Widget::totalVal_);
171
172   // Allow thread to exit
173   {
174     std::unique_lock<std::mutex> lock(mutex);
175     state = State::EXIT;
176     cv.notify_all();
177   }
178   t.join();
179
180   EXPECT_EQ(1010, Widget::totalVal_);
181 }
182
183 TEST(ThreadLocal, BasicDestructor) {
184   Widget::totalVal_ = 0;
185   ThreadLocal<Widget> w;
186   std::thread([&w]() { w->val_ += 10; }).join();
187   EXPECT_EQ(10, Widget::totalVal_);
188 }
189
190 TEST(ThreadLocal, SimpleRepeatDestructor) {
191   Widget::totalVal_ = 0;
192   {
193     ThreadLocal<Widget> w;
194     w->val_ += 10;
195   }
196   {
197     ThreadLocal<Widget> w;
198     w->val_ += 10;
199   }
200   EXPECT_EQ(20, Widget::totalVal_);
201 }
202
203 TEST(ThreadLocal, InterleavedDestructors) {
204   Widget::totalVal_ = 0;
205   std::unique_ptr<ThreadLocal<Widget>> w;
206   int wVersion = 0;
207   const int wVersionMax = 2;
208   int thIter = 0;
209   std::mutex lock;
210   auto th = std::thread([&]() {
211     int wVersionPrev = 0;
212     while (true) {
213       while (true) {
214         std::lock_guard<std::mutex> g(lock);
215         if (wVersion > wVersionMax) {
216           return;
217         }
218         if (wVersion > wVersionPrev) {
219           // We have a new version of w, so it should be initialized to zero
220           EXPECT_EQ((*w)->val_, 0);
221           break;
222         }
223       }
224       std::lock_guard<std::mutex> g(lock);
225       wVersionPrev = wVersion;
226       (*w)->val_ += 10;
227       ++thIter;
228     }
229   });
230   FOR_EACH_RANGE(i, 0, wVersionMax) {
231     int thIterPrev = 0;
232     {
233       std::lock_guard<std::mutex> g(lock);
234       thIterPrev = thIter;
235       w.reset(new ThreadLocal<Widget>());
236       ++wVersion;
237     }
238     while (true) {
239       std::lock_guard<std::mutex> g(lock);
240       if (thIter > thIterPrev) {
241         break;
242       }
243     }
244   }
245   {
246     std::lock_guard<std::mutex> g(lock);
247     wVersion = wVersionMax + 1;
248   }
249   th.join();
250   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
251 }
252
253 TEST(ThreadLocalPtr, ODRUseEntryIDkInvalid) {
254   // EntryID::kInvalid is odr-used
255   // see http://en.cppreference.com/w/cpp/language/static
256   const uint32_t* pInvalid =
257     &(threadlocal_detail::StaticMeta<void>::EntryID::kInvalid);
258   EXPECT_EQ(std::numeric_limits<uint32_t>::max(), *pInvalid);
259 }
260
261 class SimpleThreadCachedInt {
262
263   class NewTag;
264   ThreadLocal<int,NewTag> val_;
265
266  public:
267   void add(int val) {
268     *val_ += val;
269   }
270
271   int read() {
272     int ret = 0;
273     for (const auto& i : val_.accessAllThreads()) {
274       ret += i;
275     }
276     return ret;
277   }
278 };
279
280 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
281   const int kNumThreads = 10;
282   SimpleThreadCachedInt stci;
283   std::atomic<bool> run(true);
284   std::atomic<int> totalAtomic(0);
285   std::vector<std::thread> threads;
286   for (int i = 0; i < kNumThreads; ++i) {
287     threads.push_back(std::thread([&,i]() {
288       stci.add(1);
289       totalAtomic.fetch_add(1);
290       while (run.load()) { usleep(100); }
291     }));
292   }
293   while (totalAtomic.load() != kNumThreads) { usleep(100); }
294   EXPECT_EQ(kNumThreads, stci.read());
295   run.store(false);
296   for (auto& t : threads) {
297     t.join();
298   }
299 }
300
301 TEST(ThreadLocal, resetNull) {
302   ThreadLocal<int> tl;
303   tl.reset(new int(4));
304   EXPECT_EQ(*tl.get(), 4);
305   tl.reset();
306   EXPECT_EQ(*tl.get(), 0);
307   tl.reset(new int(5));
308   EXPECT_EQ(*tl.get(), 5);
309 }
310
311 namespace {
312 struct Tag {};
313
314 struct Foo {
315   folly::ThreadLocal<int, Tag> tl;
316 };
317 }  // namespace
318
319 TEST(ThreadLocal, Movable1) {
320   Foo a;
321   Foo b;
322   EXPECT_TRUE(a.tl.get() != b.tl.get());
323
324   a = Foo();
325   b = Foo();
326   EXPECT_TRUE(a.tl.get() != b.tl.get());
327 }
328
329 TEST(ThreadLocal, Movable2) {
330   std::map<int, Foo> map;
331
332   map[42];
333   map[10];
334   map[23];
335   map[100];
336
337   std::set<void*> tls;
338   for (auto& m : map) {
339     tls.insert(m.second.tl.get());
340   }
341
342   // Make sure that we have 4 different instances of *tl
343   EXPECT_EQ(4, tls.size());
344 }
345
346 namespace {
347
348 constexpr size_t kFillObjectSize = 300;
349
350 std::atomic<uint64_t> gDestroyed;
351
352 /**
353  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
354  * (so deleting one of these from another thread would cause a failure)
355  *
356  * Verify it explicitly and on destruction.
357  */
358 class FillObject {
359  public:
360   explicit FillObject(uint64_t idx) : idx_(idx) {
361     uint64_t v = val();
362     for (size_t i = 0; i < kFillObjectSize; ++i) {
363       data_[i] = v;
364     }
365   }
366
367   void check() {
368     uint64_t v = val();
369     for (size_t i = 0; i < kFillObjectSize; ++i) {
370       CHECK_EQ(v, data_[i]);
371     }
372   }
373
374   ~FillObject() {
375     ++gDestroyed;
376   }
377
378  private:
379   uint64_t val() const {
380     return (idx_ << 40) | uint64_t(pthread_self());
381   }
382
383   uint64_t idx_;
384   uint64_t data_[kFillObjectSize];
385 };
386
387 }  // namespace
388
389 #if FOLLY_HAVE_STD_THIS_THREAD_SLEEP_FOR
390 TEST(ThreadLocal, Stress) {
391   constexpr size_t numFillObjects = 250;
392   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
393
394   constexpr size_t numThreads = 32;
395   constexpr size_t numReps = 20;
396
397   std::vector<std::thread> threads;
398   threads.reserve(numThreads);
399
400   for (size_t i = 0; i < numThreads; ++i) {
401     threads.emplace_back([&objects] {
402       for (size_t rep = 0; rep < numReps; ++rep) {
403         for (size_t i = 0; i < objects.size(); ++i) {
404           objects[i].reset(new FillObject(rep * objects.size() + i));
405           std::this_thread::sleep_for(std::chrono::microseconds(100));
406         }
407         for (size_t i = 0; i < objects.size(); ++i) {
408           objects[i]->check();
409         }
410       }
411     });
412   }
413
414   for (auto& t : threads) {
415     t.join();
416   }
417
418   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
419 }
420 #endif
421
422 // Yes, threads and fork don't mix
423 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
424 // stupid or desperate enough to try, we shouldn't stand in your way.
425 namespace {
426 class HoldsOne {
427  public:
428   HoldsOne() : value_(1) { }
429   // Do an actual access to catch the buggy case where this == nullptr
430   int value() const { return value_; }
431  private:
432   int value_;
433 };
434
435 struct HoldsOneTag {};
436
437 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
438
439 int totalValue() {
440   int value = 0;
441   for (auto& p : ptr.accessAllThreads()) {
442     value += p.value();
443   }
444   return value;
445 }
446
447 }  // namespace
448
449 #ifdef FOLLY_HAVE_PTHREAD_ATFORK
450 TEST(ThreadLocal, Fork) {
451   EXPECT_EQ(1, ptr->value());  // ensure created
452   EXPECT_EQ(1, totalValue());
453   // Spawn a new thread
454
455   std::mutex mutex;
456   bool started = false;
457   std::condition_variable startedCond;
458   bool stopped = false;
459   std::condition_variable stoppedCond;
460
461   std::thread t([&] () {
462     EXPECT_EQ(1, ptr->value());  // ensure created
463     {
464       std::unique_lock<std::mutex> lock(mutex);
465       started = true;
466       startedCond.notify_all();
467     }
468     {
469       std::unique_lock<std::mutex> lock(mutex);
470       while (!stopped) {
471         stoppedCond.wait(lock);
472       }
473     }
474   });
475
476   {
477     std::unique_lock<std::mutex> lock(mutex);
478     while (!started) {
479       startedCond.wait(lock);
480     }
481   }
482
483   EXPECT_EQ(2, totalValue());
484
485   pid_t pid = fork();
486   if (pid == 0) {
487     // in child
488     int v = totalValue();
489
490     // exit successfully if v == 1 (one thread)
491     // diagnostic error code otherwise :)
492     switch (v) {
493     case 1: _exit(0);
494     case 0: _exit(1);
495     }
496     _exit(2);
497   } else if (pid > 0) {
498     // in parent
499     int status;
500     EXPECT_EQ(pid, waitpid(pid, &status, 0));
501     EXPECT_TRUE(WIFEXITED(status));
502     EXPECT_EQ(0, WEXITSTATUS(status));
503   } else {
504     EXPECT_TRUE(false) << "fork failed";
505   }
506
507   EXPECT_EQ(2, totalValue());
508
509   {
510     std::unique_lock<std::mutex> lock(mutex);
511     stopped = true;
512     stoppedCond.notify_all();
513   }
514
515   t.join();
516
517   EXPECT_EQ(1, totalValue());
518 }
519 #endif
520
521 struct HoldsOneTag2 {};
522
523 TEST(ThreadLocal, Fork2) {
524   // A thread-local tag that was used in the parent from a *different* thread
525   // (but not the forking thread) would cause the child to hang in a
526   // ThreadLocalPtr's object destructor. Yeah.
527   ThreadLocal<HoldsOne, HoldsOneTag2> p;
528   {
529     // use tag in different thread
530     std::thread t([&p] { p.get(); });
531     t.join();
532   }
533   pid_t pid = fork();
534   if (pid == 0) {
535     {
536       ThreadLocal<HoldsOne, HoldsOneTag2> q;
537       q.get();
538     }
539     _exit(0);
540   } else if (pid > 0) {
541     int status;
542     EXPECT_EQ(pid, waitpid(pid, &status, 0));
543     EXPECT_TRUE(WIFEXITED(status));
544     EXPECT_EQ(0, WEXITSTATUS(status));
545   } else {
546     EXPECT_TRUE(false) << "fork failed";
547   }
548 }
549
550 // clang is unable to compile this code unless in c++14 mode.
551 #if __cplusplus >= 201402L
552 namespace {
553 // This will fail to compile unless ThreadLocal{Ptr} has a constexpr
554 // default constructor. This ensures that ThreadLocal is safe to use in
555 // static constructors without worrying about initialization order
556 class ConstexprThreadLocalCompile {
557   ThreadLocal<int> a_;
558   ThreadLocalPtr<int> b_;
559
560   constexpr ConstexprThreadLocalCompile() {}
561 };
562 }
563 #endif
564
565 // Simple reference implementation using pthread_get_specific
566 template<typename T>
567 class PThreadGetSpecific {
568  public:
569   PThreadGetSpecific() : key_(0) {
570     pthread_key_create(&key_, OnThreadExit);
571   }
572
573   T* get() const {
574     return static_cast<T*>(pthread_getspecific(key_));
575   }
576
577   void reset(T* t) {
578     delete get();
579     pthread_setspecific(key_, t);
580   }
581   static void OnThreadExit(void* obj) {
582     delete static_cast<T*>(obj);
583   }
584  private:
585   pthread_key_t key_;
586 };
587
588 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
589
590 #define REG(var)                                                \
591   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
592     const int itersPerThread = iters / FLAGS_numThreads;        \
593     std::vector<std::thread> threads;                           \
594     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
595       threads.push_back(std::thread([&]() {                     \
596         var.reset(new int(0));                                  \
597         for (int i = 0; i < itersPerThread; ++i) {              \
598           ++(*var.get());                                       \
599         }                                                       \
600       }));                                                      \
601     }                                                           \
602     for (auto& t : threads) {                                   \
603       t.join();                                                 \
604     }                                                           \
605   }
606
607 ThreadLocalPtr<int> tlp;
608 REG(tlp);
609 PThreadGetSpecific<int> pthread_get_specific;
610 REG(pthread_get_specific);
611 boost::thread_specific_ptr<int> boost_tsp;
612 REG(boost_tsp);
613 BENCHMARK_DRAW_LINE();
614
615 int main(int argc, char** argv) {
616   testing::InitGoogleTest(&argc, argv);
617   gflags::ParseCommandLineFlags(&argc, &argv, true);
618   gflags::SetCommandLineOptionWithMode(
619     "bm_max_iters", "100000000", gflags::SET_FLAG_IF_DEFAULT
620   );
621   if (FLAGS_benchmark) {
622     folly::runBenchmarks();
623   }
624   return RUN_ALL_TESTS();
625 }
626
627 /*
628 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
629
630 Benchmark                               Iters   Total t    t/iter iter/sec
631 ------------------------------------------------------------------------------
632 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
633  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
634  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
635 ------------------------------------------------------------------------------
636 */