Fix a bad bug in folly::ThreadLocal (incorrect using of pthread)
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/ThreadLocal.h>
18
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 #include <unistd.h>
22
23 #include <array>
24 #include <atomic>
25 #include <chrono>
26 #include <condition_variable>
27 #include <map>
28 #include <mutex>
29 #include <set>
30 #include <thread>
31 #include <unordered_map>
32
33 #include <boost/thread/tss.hpp>
34 #include <gflags/gflags.h>
35 #include <glog/logging.h>
36 #include <gtest/gtest.h>
37
38 #include <folly/Benchmark.h>
39
40 using namespace folly;
41
42 struct Widget {
43   static int totalVal_;
44   int val_;
45   ~Widget() {
46     totalVal_ += val_;
47   }
48
49   static void customDeleter(Widget* w, TLPDestructionMode mode) {
50     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
51     delete w;
52   }
53 };
54 int Widget::totalVal_ = 0;
55
56 TEST(ThreadLocalPtr, BasicDestructor) {
57   Widget::totalVal_ = 0;
58   ThreadLocalPtr<Widget> w;
59   std::thread([&w]() {
60       w.reset(new Widget());
61       w.get()->val_ += 10;
62     }).join();
63   EXPECT_EQ(10, Widget::totalVal_);
64 }
65
66 TEST(ThreadLocalPtr, CustomDeleter1) {
67   Widget::totalVal_ = 0;
68   {
69     ThreadLocalPtr<Widget> w;
70     std::thread([&w]() {
71         w.reset(new Widget(), Widget::customDeleter);
72         w.get()->val_ += 10;
73       }).join();
74     EXPECT_EQ(10, Widget::totalVal_);
75   }
76   EXPECT_EQ(10, Widget::totalVal_);
77 }
78
79 TEST(ThreadLocalPtr, resetNull) {
80   ThreadLocalPtr<int> tl;
81   EXPECT_FALSE(tl);
82   tl.reset(new int(4));
83   EXPECT_TRUE(static_cast<bool>(tl));
84   EXPECT_EQ(*tl.get(), 4);
85   tl.reset();
86   EXPECT_FALSE(tl);
87 }
88
89 TEST(ThreadLocalPtr, TestRelease) {
90   Widget::totalVal_ = 0;
91   ThreadLocalPtr<Widget> w;
92   std::unique_ptr<Widget> wPtr;
93   std::thread([&w, &wPtr]() {
94       w.reset(new Widget());
95       w.get()->val_ += 10;
96
97       wPtr.reset(w.release());
98     }).join();
99   EXPECT_EQ(0, Widget::totalVal_);
100   wPtr.reset();
101   EXPECT_EQ(10, Widget::totalVal_);
102 }
103
104 TEST(ThreadLocalPtr, CreateOnThreadExit) {
105   Widget::totalVal_ = 0;
106   ThreadLocal<Widget> w;
107   ThreadLocalPtr<int> tl;
108
109   std::thread([&] {
110       tl.reset(new int(1), [&] (int* ptr, TLPDestructionMode mode) {
111         delete ptr;
112         // This test ensures Widgets allocated here are not leaked.
113         ++w.get()->val_;
114         ThreadLocal<Widget> wl;
115         ++wl.get()->val_;
116       });
117     }).join();
118   EXPECT_EQ(2, Widget::totalVal_);
119 }
120
121 // Test deleting the ThreadLocalPtr object
122 TEST(ThreadLocalPtr, CustomDeleter2) {
123   Widget::totalVal_ = 0;
124   std::thread t;
125   std::mutex mutex;
126   std::condition_variable cv;
127   enum class State {
128     START,
129     DONE,
130     EXIT
131   };
132   State state = State::START;
133   {
134     ThreadLocalPtr<Widget> w;
135     t = std::thread([&]() {
136         w.reset(new Widget(), Widget::customDeleter);
137         w.get()->val_ += 10;
138
139         // Notify main thread that we're done
140         {
141           std::unique_lock<std::mutex> lock(mutex);
142           state = State::DONE;
143           cv.notify_all();
144         }
145
146         // Wait for main thread to allow us to exit
147         {
148           std::unique_lock<std::mutex> lock(mutex);
149           while (state != State::EXIT) {
150             cv.wait(lock);
151           }
152         }
153     });
154
155     // Wait for main thread to start (and set w.get()->val_)
156     {
157       std::unique_lock<std::mutex> lock(mutex);
158       while (state != State::DONE) {
159         cv.wait(lock);
160       }
161     }
162
163     // Thread started but hasn't exited yet
164     EXPECT_EQ(0, Widget::totalVal_);
165
166     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
167   }
168
169   EXPECT_EQ(1010, Widget::totalVal_);
170
171   // Allow thread to exit
172   {
173     std::unique_lock<std::mutex> lock(mutex);
174     state = State::EXIT;
175     cv.notify_all();
176   }
177   t.join();
178
179   EXPECT_EQ(1010, Widget::totalVal_);
180 }
181
182 TEST(ThreadLocal, BasicDestructor) {
183   Widget::totalVal_ = 0;
184   ThreadLocal<Widget> w;
185   std::thread([&w]() { w->val_ += 10; }).join();
186   EXPECT_EQ(10, Widget::totalVal_);
187 }
188
189 TEST(ThreadLocal, SimpleRepeatDestructor) {
190   Widget::totalVal_ = 0;
191   {
192     ThreadLocal<Widget> w;
193     w->val_ += 10;
194   }
195   {
196     ThreadLocal<Widget> w;
197     w->val_ += 10;
198   }
199   EXPECT_EQ(20, Widget::totalVal_);
200 }
201
202 TEST(ThreadLocal, InterleavedDestructors) {
203   Widget::totalVal_ = 0;
204   std::unique_ptr<ThreadLocal<Widget>> w;
205   int wVersion = 0;
206   const int wVersionMax = 2;
207   int thIter = 0;
208   std::mutex lock;
209   auto th = std::thread([&]() {
210     int wVersionPrev = 0;
211     while (true) {
212       while (true) {
213         std::lock_guard<std::mutex> g(lock);
214         if (wVersion > wVersionMax) {
215           return;
216         }
217         if (wVersion > wVersionPrev) {
218           // We have a new version of w, so it should be initialized to zero
219           EXPECT_EQ((*w)->val_, 0);
220           break;
221         }
222       }
223       std::lock_guard<std::mutex> g(lock);
224       wVersionPrev = wVersion;
225       (*w)->val_ += 10;
226       ++thIter;
227     }
228   });
229   FOR_EACH_RANGE(i, 0, wVersionMax) {
230     int thIterPrev = 0;
231     {
232       std::lock_guard<std::mutex> g(lock);
233       thIterPrev = thIter;
234       w.reset(new ThreadLocal<Widget>());
235       ++wVersion;
236     }
237     while (true) {
238       std::lock_guard<std::mutex> g(lock);
239       if (thIter > thIterPrev) {
240         break;
241       }
242     }
243   }
244   {
245     std::lock_guard<std::mutex> g(lock);
246     wVersion = wVersionMax + 1;
247   }
248   th.join();
249   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
250 }
251
252 class SimpleThreadCachedInt {
253
254   class NewTag;
255   ThreadLocal<int,NewTag> val_;
256
257  public:
258   void add(int val) {
259     *val_ += val;
260   }
261
262   int read() {
263     int ret = 0;
264     for (const auto& i : val_.accessAllThreads()) {
265       ret += i;
266     }
267     return ret;
268   }
269 };
270
271 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
272   const int kNumThreads = 10;
273   SimpleThreadCachedInt stci;
274   std::atomic<bool> run(true);
275   std::atomic<int> totalAtomic(0);
276   std::vector<std::thread> threads;
277   for (int i = 0; i < kNumThreads; ++i) {
278     threads.push_back(std::thread([&,i]() {
279       stci.add(1);
280       totalAtomic.fetch_add(1);
281       while (run.load()) { usleep(100); }
282     }));
283   }
284   while (totalAtomic.load() != kNumThreads) { usleep(100); }
285   EXPECT_EQ(kNumThreads, stci.read());
286   run.store(false);
287   for (auto& t : threads) {
288     t.join();
289   }
290 }
291
292 TEST(ThreadLocal, resetNull) {
293   ThreadLocal<int> tl;
294   tl.reset(new int(4));
295   EXPECT_EQ(*tl.get(), 4);
296   tl.reset();
297   EXPECT_EQ(*tl.get(), 0);
298   tl.reset(new int(5));
299   EXPECT_EQ(*tl.get(), 5);
300 }
301
302 namespace {
303 struct Tag {};
304
305 struct Foo {
306   folly::ThreadLocal<int, Tag> tl;
307 };
308 }  // namespace
309
310 TEST(ThreadLocal, Movable1) {
311   Foo a;
312   Foo b;
313   EXPECT_TRUE(a.tl.get() != b.tl.get());
314
315   a = Foo();
316   b = Foo();
317   EXPECT_TRUE(a.tl.get() != b.tl.get());
318 }
319
320 TEST(ThreadLocal, Movable2) {
321   std::map<int, Foo> map;
322
323   map[42];
324   map[10];
325   map[23];
326   map[100];
327
328   std::set<void*> tls;
329   for (auto& m : map) {
330     tls.insert(m.second.tl.get());
331   }
332
333   // Make sure that we have 4 different instances of *tl
334   EXPECT_EQ(4, tls.size());
335 }
336
337 namespace {
338
339 constexpr size_t kFillObjectSize = 300;
340
341 std::atomic<uint64_t> gDestroyed;
342
343 /**
344  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
345  * (so deleting one of these from another thread would cause a failure)
346  *
347  * Verify it explicitly and on destruction.
348  */
349 class FillObject {
350  public:
351   explicit FillObject(uint64_t idx) : idx_(idx) {
352     uint64_t v = val();
353     for (size_t i = 0; i < kFillObjectSize; ++i) {
354       data_[i] = v;
355     }
356   }
357
358   void check() {
359     uint64_t v = val();
360     for (size_t i = 0; i < kFillObjectSize; ++i) {
361       CHECK_EQ(v, data_[i]);
362     }
363   }
364
365   ~FillObject() {
366     ++gDestroyed;
367   }
368
369  private:
370   uint64_t val() const {
371     return (idx_ << 40) | uint64_t(pthread_self());
372   }
373
374   uint64_t idx_;
375   uint64_t data_[kFillObjectSize];
376 };
377
378 }  // namespace
379
380 #if FOLLY_HAVE_STD_THIS_THREAD_SLEEP_FOR
381 TEST(ThreadLocal, Stress) {
382   constexpr size_t numFillObjects = 250;
383   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
384
385   constexpr size_t numThreads = 32;
386   constexpr size_t numReps = 20;
387
388   std::vector<std::thread> threads;
389   threads.reserve(numThreads);
390
391   for (size_t i = 0; i < numThreads; ++i) {
392     threads.emplace_back([&objects] {
393       for (size_t rep = 0; rep < numReps; ++rep) {
394         for (size_t i = 0; i < objects.size(); ++i) {
395           objects[i].reset(new FillObject(rep * objects.size() + i));
396           std::this_thread::sleep_for(std::chrono::microseconds(100));
397         }
398         for (size_t i = 0; i < objects.size(); ++i) {
399           objects[i]->check();
400         }
401       }
402     });
403   }
404
405   for (auto& t : threads) {
406     t.join();
407   }
408
409   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
410 }
411 #endif
412
413 // Yes, threads and fork don't mix
414 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
415 // stupid or desperate enough to try, we shouldn't stand in your way.
416 namespace {
417 class HoldsOne {
418  public:
419   HoldsOne() : value_(1) { }
420   // Do an actual access to catch the buggy case where this == nullptr
421   int value() const { return value_; }
422  private:
423   int value_;
424 };
425
426 struct HoldsOneTag {};
427
428 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
429
430 int totalValue() {
431   int value = 0;
432   for (auto& p : ptr.accessAllThreads()) {
433     value += p.value();
434   }
435   return value;
436 }
437
438 }  // namespace
439
440 #ifdef FOLLY_HAVE_PTHREAD_ATFORK
441 TEST(ThreadLocal, Fork) {
442   EXPECT_EQ(1, ptr->value());  // ensure created
443   EXPECT_EQ(1, totalValue());
444   // Spawn a new thread
445
446   std::mutex mutex;
447   bool started = false;
448   std::condition_variable startedCond;
449   bool stopped = false;
450   std::condition_variable stoppedCond;
451
452   std::thread t([&] () {
453     EXPECT_EQ(1, ptr->value());  // ensure created
454     {
455       std::unique_lock<std::mutex> lock(mutex);
456       started = true;
457       startedCond.notify_all();
458     }
459     {
460       std::unique_lock<std::mutex> lock(mutex);
461       while (!stopped) {
462         stoppedCond.wait(lock);
463       }
464     }
465   });
466
467   {
468     std::unique_lock<std::mutex> lock(mutex);
469     while (!started) {
470       startedCond.wait(lock);
471     }
472   }
473
474   EXPECT_EQ(2, totalValue());
475
476   pid_t pid = fork();
477   if (pid == 0) {
478     // in child
479     int v = totalValue();
480
481     // exit successfully if v == 1 (one thread)
482     // diagnostic error code otherwise :)
483     switch (v) {
484     case 1: _exit(0);
485     case 0: _exit(1);
486     }
487     _exit(2);
488   } else if (pid > 0) {
489     // in parent
490     int status;
491     EXPECT_EQ(pid, waitpid(pid, &status, 0));
492     EXPECT_TRUE(WIFEXITED(status));
493     EXPECT_EQ(0, WEXITSTATUS(status));
494   } else {
495     EXPECT_TRUE(false) << "fork failed";
496   }
497
498   EXPECT_EQ(2, totalValue());
499
500   {
501     std::unique_lock<std::mutex> lock(mutex);
502     stopped = true;
503     stoppedCond.notify_all();
504   }
505
506   t.join();
507
508   EXPECT_EQ(1, totalValue());
509 }
510 #endif
511
512 struct HoldsOneTag2 {};
513
514 TEST(ThreadLocal, Fork2) {
515   // A thread-local tag that was used in the parent from a *different* thread
516   // (but not the forking thread) would cause the child to hang in a
517   // ThreadLocalPtr's object destructor. Yeah.
518   ThreadLocal<HoldsOne, HoldsOneTag2> p;
519   {
520     // use tag in different thread
521     std::thread t([&p] { p.get(); });
522     t.join();
523   }
524   pid_t pid = fork();
525   if (pid == 0) {
526     {
527       ThreadLocal<HoldsOne, HoldsOneTag2> q;
528       q.get();
529     }
530     _exit(0);
531   } else if (pid > 0) {
532     int status;
533     EXPECT_EQ(pid, waitpid(pid, &status, 0));
534     EXPECT_TRUE(WIFEXITED(status));
535     EXPECT_EQ(0, WEXITSTATUS(status));
536   } else {
537     EXPECT_TRUE(false) << "fork failed";
538   }
539 }
540
541 // Simple reference implementation using pthread_get_specific
542 template<typename T>
543 class PThreadGetSpecific {
544  public:
545   PThreadGetSpecific() : key_(0) {
546     pthread_key_create(&key_, OnThreadExit);
547   }
548
549   T* get() const {
550     return static_cast<T*>(pthread_getspecific(key_));
551   }
552
553   void reset(T* t) {
554     delete get();
555     pthread_setspecific(key_, t);
556   }
557   static void OnThreadExit(void* obj) {
558     delete static_cast<T*>(obj);
559   }
560  private:
561   pthread_key_t key_;
562 };
563
564 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
565
566 #define REG(var)                                                \
567   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
568     const int itersPerThread = iters / FLAGS_numThreads;        \
569     std::vector<std::thread> threads;                           \
570     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
571       threads.push_back(std::thread([&]() {                     \
572         var.reset(new int(0));                                  \
573         for (int i = 0; i < itersPerThread; ++i) {              \
574           ++(*var.get());                                       \
575         }                                                       \
576       }));                                                      \
577     }                                                           \
578     for (auto& t : threads) {                                   \
579       t.join();                                                 \
580     }                                                           \
581   }
582
583 ThreadLocalPtr<int> tlp;
584 REG(tlp);
585 PThreadGetSpecific<int> pthread_get_specific;
586 REG(pthread_get_specific);
587 boost::thread_specific_ptr<int> boost_tsp;
588 REG(boost_tsp);
589 BENCHMARK_DRAW_LINE();
590
591 int main(int argc, char** argv) {
592   testing::InitGoogleTest(&argc, argv);
593   gflags::ParseCommandLineFlags(&argc, &argv, true);
594   gflags::SetCommandLineOptionWithMode(
595     "bm_max_iters", "100000000", gflags::SET_FLAG_IF_DEFAULT
596   );
597   if (FLAGS_benchmark) {
598     folly::runBenchmarks();
599   }
600   return RUN_ALL_TESTS();
601 }
602
603 /*
604 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
605
606 Benchmark                               Iters   Total t    t/iter iter/sec
607 ------------------------------------------------------------------------------
608 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
609  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
610  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
611 ------------------------------------------------------------------------------
612 */