a5de491618c153cd2f94d15c2e67841ac1432e51
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/ThreadLocal.h>
18
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 #include <unistd.h>
22
23 #include <array>
24 #include <atomic>
25 #include <chrono>
26 #include <condition_variable>
27 #include <map>
28 #include <mutex>
29 #include <set>
30 #include <thread>
31 #include <unordered_map>
32
33 #include <boost/thread/tss.hpp>
34 #include <gflags/gflags.h>
35 #include <glog/logging.h>
36 #include <gtest/gtest.h>
37
38 #include <folly/Benchmark.h>
39
40 using namespace folly;
41
42 struct Widget {
43   static int totalVal_;
44   int val_;
45   ~Widget() {
46     totalVal_ += val_;
47   }
48
49   static void customDeleter(Widget* w, TLPDestructionMode mode) {
50     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
51     delete w;
52   }
53 };
54 int Widget::totalVal_ = 0;
55
56 TEST(ThreadLocalPtr, BasicDestructor) {
57   Widget::totalVal_ = 0;
58   ThreadLocalPtr<Widget> w;
59   std::thread([&w]() {
60       w.reset(new Widget());
61       w.get()->val_ += 10;
62     }).join();
63   EXPECT_EQ(10, Widget::totalVal_);
64 }
65
66 TEST(ThreadLocalPtr, CustomDeleter1) {
67   Widget::totalVal_ = 0;
68   {
69     ThreadLocalPtr<Widget> w;
70     std::thread([&w]() {
71         w.reset(new Widget(), Widget::customDeleter);
72         w.get()->val_ += 10;
73       }).join();
74     EXPECT_EQ(10, Widget::totalVal_);
75   }
76   EXPECT_EQ(10, Widget::totalVal_);
77 }
78
79 TEST(ThreadLocalPtr, resetNull) {
80   ThreadLocalPtr<int> tl;
81   EXPECT_FALSE(tl);
82   tl.reset(new int(4));
83   EXPECT_TRUE(static_cast<bool>(tl));
84   EXPECT_EQ(*tl.get(), 4);
85   tl.reset();
86   EXPECT_FALSE(tl);
87 }
88
89 TEST(ThreadLocalPtr, TestRelease) {
90   Widget::totalVal_ = 0;
91   ThreadLocalPtr<Widget> w;
92   std::unique_ptr<Widget> wPtr;
93   std::thread([&w, &wPtr]() {
94       w.reset(new Widget());
95       w.get()->val_ += 10;
96
97       wPtr.reset(w.release());
98     }).join();
99   EXPECT_EQ(0, Widget::totalVal_);
100   wPtr.reset();
101   EXPECT_EQ(10, Widget::totalVal_);
102 }
103
104 // Test deleting the ThreadLocalPtr object
105 TEST(ThreadLocalPtr, CustomDeleter2) {
106   Widget::totalVal_ = 0;
107   std::thread t;
108   std::mutex mutex;
109   std::condition_variable cv;
110   enum class State {
111     START,
112     DONE,
113     EXIT
114   };
115   State state = State::START;
116   {
117     ThreadLocalPtr<Widget> w;
118     t = std::thread([&]() {
119         w.reset(new Widget(), Widget::customDeleter);
120         w.get()->val_ += 10;
121
122         // Notify main thread that we're done
123         {
124           std::unique_lock<std::mutex> lock(mutex);
125           state = State::DONE;
126           cv.notify_all();
127         }
128
129         // Wait for main thread to allow us to exit
130         {
131           std::unique_lock<std::mutex> lock(mutex);
132           while (state != State::EXIT) {
133             cv.wait(lock);
134           }
135         }
136     });
137
138     // Wait for main thread to start (and set w.get()->val_)
139     {
140       std::unique_lock<std::mutex> lock(mutex);
141       while (state != State::DONE) {
142         cv.wait(lock);
143       }
144     }
145
146     // Thread started but hasn't exited yet
147     EXPECT_EQ(0, Widget::totalVal_);
148
149     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
150   }
151
152   EXPECT_EQ(1010, Widget::totalVal_);
153
154   // Allow thread to exit
155   {
156     std::unique_lock<std::mutex> lock(mutex);
157     state = State::EXIT;
158     cv.notify_all();
159   }
160   t.join();
161
162   EXPECT_EQ(1010, Widget::totalVal_);
163 }
164
165 TEST(ThreadLocal, BasicDestructor) {
166   Widget::totalVal_ = 0;
167   ThreadLocal<Widget> w;
168   std::thread([&w]() { w->val_ += 10; }).join();
169   EXPECT_EQ(10, Widget::totalVal_);
170 }
171
172 TEST(ThreadLocal, SimpleRepeatDestructor) {
173   Widget::totalVal_ = 0;
174   {
175     ThreadLocal<Widget> w;
176     w->val_ += 10;
177   }
178   {
179     ThreadLocal<Widget> w;
180     w->val_ += 10;
181   }
182   EXPECT_EQ(20, Widget::totalVal_);
183 }
184
185 TEST(ThreadLocal, InterleavedDestructors) {
186   Widget::totalVal_ = 0;
187   ThreadLocal<Widget>* w = nullptr;
188   int wVersion = 0;
189   const int wVersionMax = 2;
190   int thIter = 0;
191   std::mutex lock;
192   auto th = std::thread([&]() {
193     int wVersionPrev = 0;
194     while (true) {
195       while (true) {
196         std::lock_guard<std::mutex> g(lock);
197         if (wVersion > wVersionMax) {
198           return;
199         }
200         if (wVersion > wVersionPrev) {
201           // We have a new version of w, so it should be initialized to zero
202           EXPECT_EQ((*w)->val_, 0);
203           break;
204         }
205       }
206       std::lock_guard<std::mutex> g(lock);
207       wVersionPrev = wVersion;
208       (*w)->val_ += 10;
209       ++thIter;
210     }
211   });
212   FOR_EACH_RANGE(i, 0, wVersionMax) {
213     int thIterPrev = 0;
214     {
215       std::lock_guard<std::mutex> g(lock);
216       thIterPrev = thIter;
217       delete w;
218       w = new ThreadLocal<Widget>();
219       ++wVersion;
220     }
221     while (true) {
222       std::lock_guard<std::mutex> g(lock);
223       if (thIter > thIterPrev) {
224         break;
225       }
226     }
227   }
228   {
229     std::lock_guard<std::mutex> g(lock);
230     wVersion = wVersionMax + 1;
231   }
232   th.join();
233   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
234 }
235
236 class SimpleThreadCachedInt {
237
238   class NewTag;
239   ThreadLocal<int,NewTag> val_;
240
241  public:
242   void add(int val) {
243     *val_ += val;
244   }
245
246   int read() {
247     int ret = 0;
248     for (const auto& i : val_.accessAllThreads()) {
249       ret += i;
250     }
251     return ret;
252   }
253 };
254
255 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
256   const int kNumThreads = 10;
257   SimpleThreadCachedInt stci;
258   std::atomic<bool> run(true);
259   std::atomic<int> totalAtomic(0);
260   std::vector<std::thread> threads;
261   for (int i = 0; i < kNumThreads; ++i) {
262     threads.push_back(std::thread([&,i]() {
263       stci.add(1);
264       totalAtomic.fetch_add(1);
265       while (run.load()) { usleep(100); }
266     }));
267   }
268   while (totalAtomic.load() != kNumThreads) { usleep(100); }
269   EXPECT_EQ(kNumThreads, stci.read());
270   run.store(false);
271   for (auto& t : threads) {
272     t.join();
273   }
274 }
275
276 TEST(ThreadLocal, resetNull) {
277   ThreadLocal<int> tl;
278   tl.reset(new int(4));
279   EXPECT_EQ(*tl.get(), 4);
280   tl.reset();
281   EXPECT_EQ(*tl.get(), 0);
282   tl.reset(new int(5));
283   EXPECT_EQ(*tl.get(), 5);
284 }
285
286 namespace {
287 struct Tag {};
288
289 struct Foo {
290   folly::ThreadLocal<int, Tag> tl;
291 };
292 }  // namespace
293
294 TEST(ThreadLocal, Movable1) {
295   Foo a;
296   Foo b;
297   EXPECT_TRUE(a.tl.get() != b.tl.get());
298
299   a = Foo();
300   b = Foo();
301   EXPECT_TRUE(a.tl.get() != b.tl.get());
302 }
303
304 TEST(ThreadLocal, Movable2) {
305   std::map<int, Foo> map;
306
307   map[42];
308   map[10];
309   map[23];
310   map[100];
311
312   std::set<void*> tls;
313   for (auto& m : map) {
314     tls.insert(m.second.tl.get());
315   }
316
317   // Make sure that we have 4 different instances of *tl
318   EXPECT_EQ(4, tls.size());
319 }
320
321 namespace {
322
323 constexpr size_t kFillObjectSize = 300;
324
325 std::atomic<uint64_t> gDestroyed;
326
327 /**
328  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
329  * (so deleting one of these from another thread would cause a failure)
330  *
331  * Verify it explicitly and on destruction.
332  */
333 class FillObject {
334  public:
335   explicit FillObject(uint64_t idx) : idx_(idx) {
336     uint64_t v = val();
337     for (size_t i = 0; i < kFillObjectSize; ++i) {
338       data_[i] = v;
339     }
340   }
341
342   void check() {
343     uint64_t v = val();
344     for (size_t i = 0; i < kFillObjectSize; ++i) {
345       CHECK_EQ(v, data_[i]);
346     }
347   }
348
349   ~FillObject() {
350     ++gDestroyed;
351   }
352
353  private:
354   uint64_t val() const {
355     return (idx_ << 40) | uint64_t(pthread_self());
356   }
357
358   uint64_t idx_;
359   uint64_t data_[kFillObjectSize];
360 };
361
362 }  // namespace
363
364 #if FOLLY_HAVE_STD__THIS_THREAD__SLEEP_FOR
365 TEST(ThreadLocal, Stress) {
366   constexpr size_t numFillObjects = 250;
367   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
368
369   constexpr size_t numThreads = 32;
370   constexpr size_t numReps = 20;
371
372   std::vector<std::thread> threads;
373   threads.reserve(numThreads);
374
375   for (size_t i = 0; i < numThreads; ++i) {
376     threads.emplace_back([&objects] {
377       for (size_t rep = 0; rep < numReps; ++rep) {
378         for (size_t i = 0; i < objects.size(); ++i) {
379           objects[i].reset(new FillObject(rep * objects.size() + i));
380           std::this_thread::sleep_for(std::chrono::microseconds(100));
381         }
382         for (size_t i = 0; i < objects.size(); ++i) {
383           objects[i]->check();
384         }
385       }
386     });
387   }
388
389   for (auto& t : threads) {
390     t.join();
391   }
392
393   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
394 }
395 #endif
396
397 // Yes, threads and fork don't mix
398 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
399 // stupid or desperate enough to try, we shouldn't stand in your way.
400 namespace {
401 class HoldsOne {
402  public:
403   HoldsOne() : value_(1) { }
404   // Do an actual access to catch the buggy case where this == nullptr
405   int value() const { return value_; }
406  private:
407   int value_;
408 };
409
410 struct HoldsOneTag {};
411
412 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
413
414 int totalValue() {
415   int value = 0;
416   for (auto& p : ptr.accessAllThreads()) {
417     value += p.value();
418   }
419   return value;
420 }
421
422 }  // namespace
423
424 TEST(ThreadLocal, Fork) {
425   EXPECT_EQ(1, ptr->value());  // ensure created
426   EXPECT_EQ(1, totalValue());
427   // Spawn a new thread
428
429   std::mutex mutex;
430   bool started = false;
431   std::condition_variable startedCond;
432   bool stopped = false;
433   std::condition_variable stoppedCond;
434
435   std::thread t([&] () {
436     EXPECT_EQ(1, ptr->value());  // ensure created
437     {
438       std::unique_lock<std::mutex> lock(mutex);
439       started = true;
440       startedCond.notify_all();
441     }
442     {
443       std::unique_lock<std::mutex> lock(mutex);
444       while (!stopped) {
445         stoppedCond.wait(lock);
446       }
447     }
448   });
449
450   {
451     std::unique_lock<std::mutex> lock(mutex);
452     while (!started) {
453       startedCond.wait(lock);
454     }
455   }
456
457   EXPECT_EQ(2, totalValue());
458
459   pid_t pid = fork();
460   if (pid == 0) {
461     // in child
462     int v = totalValue();
463
464     // exit successfully if v == 1 (one thread)
465     // diagnostic error code otherwise :)
466     switch (v) {
467     case 1: _exit(0);
468     case 0: _exit(1);
469     }
470     _exit(2);
471   } else if (pid > 0) {
472     // in parent
473     int status;
474     EXPECT_EQ(pid, waitpid(pid, &status, 0));
475     EXPECT_TRUE(WIFEXITED(status));
476     EXPECT_EQ(0, WEXITSTATUS(status));
477   } else {
478     EXPECT_TRUE(false) << "fork failed";
479   }
480
481   EXPECT_EQ(2, totalValue());
482
483   {
484     std::unique_lock<std::mutex> lock(mutex);
485     stopped = true;
486     stoppedCond.notify_all();
487   }
488
489   t.join();
490
491   EXPECT_EQ(1, totalValue());
492 }
493
494 struct HoldsOneTag2 {};
495
496 TEST(ThreadLocal, Fork2) {
497   // A thread-local tag that was used in the parent from a *different* thread
498   // (but not the forking thread) would cause the child to hang in a
499   // ThreadLocalPtr's object destructor. Yeah.
500   ThreadLocal<HoldsOne, HoldsOneTag2> p;
501   {
502     // use tag in different thread
503     std::thread t([&p] { p.get(); });
504     t.join();
505   }
506   pid_t pid = fork();
507   if (pid == 0) {
508     {
509       ThreadLocal<HoldsOne, HoldsOneTag2> q;
510       q.get();
511     }
512     _exit(0);
513   } else if (pid > 0) {
514     int status;
515     EXPECT_EQ(pid, waitpid(pid, &status, 0));
516     EXPECT_TRUE(WIFEXITED(status));
517     EXPECT_EQ(0, WEXITSTATUS(status));
518   } else {
519     EXPECT_TRUE(false) << "fork failed";
520   }
521 }
522
523 // Simple reference implementation using pthread_get_specific
524 template<typename T>
525 class PThreadGetSpecific {
526  public:
527   PThreadGetSpecific() : key_(0) {
528     pthread_key_create(&key_, OnThreadExit);
529   }
530
531   T* get() const {
532     return static_cast<T*>(pthread_getspecific(key_));
533   }
534
535   void reset(T* t) {
536     delete get();
537     pthread_setspecific(key_, t);
538   }
539   static void OnThreadExit(void* obj) {
540     delete static_cast<T*>(obj);
541   }
542  private:
543   pthread_key_t key_;
544 };
545
546 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
547
548 #define REG(var)                                                \
549   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
550     const int itersPerThread = iters / FLAGS_numThreads;        \
551     std::vector<std::thread> threads;                           \
552     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
553       threads.push_back(std::thread([&]() {                     \
554         var.reset(new int(0));                                  \
555         for (int i = 0; i < itersPerThread; ++i) {              \
556           ++(*var.get());                                       \
557         }                                                       \
558       }));                                                      \
559     }                                                           \
560     for (auto& t : threads) {                                   \
561       t.join();                                                 \
562     }                                                           \
563   }
564
565 ThreadLocalPtr<int> tlp;
566 REG(tlp);
567 PThreadGetSpecific<int> pthread_get_specific;
568 REG(pthread_get_specific);
569 boost::thread_specific_ptr<int> boost_tsp;
570 REG(boost_tsp);
571 BENCHMARK_DRAW_LINE();
572
573 int main(int argc, char** argv) {
574   testing::InitGoogleTest(&argc, argv);
575   google::ParseCommandLineFlags(&argc, &argv, true);
576   google::SetCommandLineOptionWithMode(
577     "bm_max_iters", "100000000", google::SET_FLAG_IF_DEFAULT
578   );
579   if (FLAGS_benchmark) {
580     folly::runBenchmarks();
581   }
582   return RUN_ALL_TESTS();
583 }
584
585 /*
586 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
587
588 Benchmark                               Iters   Total t    t/iter iter/sec
589 ------------------------------------------------------------------------------
590 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
591  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
592  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
593 ------------------------------------------------------------------------------
594 */