d78e07eac2156e625437b4452c908cd418f04603
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/ThreadLocal.h"
18
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 #include <unistd.h>
22
23 #include <array>
24 #include <atomic>
25 #include <chrono>
26 #include <condition_variable>
27 #include <map>
28 #include <mutex>
29 #include <set>
30 #include <thread>
31 #include <unordered_map>
32
33 #include <boost/thread/tss.hpp>
34 #include <gflags/gflags.h>
35 #include <glog/logging.h>
36 #include <gtest/gtest.h>
37
38 #include "folly/Benchmark.h"
39
40 using namespace folly;
41
42 struct Widget {
43   static int totalVal_;
44   int val_;
45   ~Widget() {
46     totalVal_ += val_;
47   }
48
49   static void customDeleter(Widget* w, TLPDestructionMode mode) {
50     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
51     delete w;
52   }
53 };
54 int Widget::totalVal_ = 0;
55
56 TEST(ThreadLocalPtr, BasicDestructor) {
57   Widget::totalVal_ = 0;
58   ThreadLocalPtr<Widget> w;
59   std::thread([&w]() {
60       w.reset(new Widget());
61       w.get()->val_ += 10;
62     }).join();
63   EXPECT_EQ(10, Widget::totalVal_);
64 }
65
66 TEST(ThreadLocalPtr, CustomDeleter1) {
67   Widget::totalVal_ = 0;
68   {
69     ThreadLocalPtr<Widget> w;
70     std::thread([&w]() {
71         w.reset(new Widget(), Widget::customDeleter);
72         w.get()->val_ += 10;
73       }).join();
74     EXPECT_EQ(10, Widget::totalVal_);
75   }
76   EXPECT_EQ(10, Widget::totalVal_);
77 }
78
79 TEST(ThreadLocalPtr, resetNull) {
80   ThreadLocalPtr<int> tl;
81   EXPECT_FALSE(tl);
82   tl.reset(new int(4));
83   EXPECT_TRUE(static_cast<bool>(tl));
84   EXPECT_EQ(*tl.get(), 4);
85   tl.reset();
86   EXPECT_FALSE(tl);
87 }
88
89 // Test deleting the ThreadLocalPtr object
90 TEST(ThreadLocalPtr, CustomDeleter2) {
91   Widget::totalVal_ = 0;
92   std::thread t;
93   std::mutex mutex;
94   std::condition_variable cv;
95   enum class State {
96     START,
97     DONE,
98     EXIT
99   };
100   State state = State::START;
101   {
102     ThreadLocalPtr<Widget> w;
103     t = std::thread([&]() {
104         w.reset(new Widget(), Widget::customDeleter);
105         w.get()->val_ += 10;
106
107         // Notify main thread that we're done
108         {
109           std::unique_lock<std::mutex> lock(mutex);
110           state = State::DONE;
111           cv.notify_all();
112         }
113
114         // Wait for main thread to allow us to exit
115         {
116           std::unique_lock<std::mutex> lock(mutex);
117           while (state != State::EXIT) {
118             cv.wait(lock);
119           }
120         }
121     });
122
123     // Wait for main thread to start (and set w.get()->val_)
124     {
125       std::unique_lock<std::mutex> lock(mutex);
126       while (state != State::DONE) {
127         cv.wait(lock);
128       }
129     }
130
131     // Thread started but hasn't exited yet
132     EXPECT_EQ(0, Widget::totalVal_);
133
134     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
135   }
136
137   EXPECT_EQ(1010, Widget::totalVal_);
138
139   // Allow thread to exit
140   {
141     std::unique_lock<std::mutex> lock(mutex);
142     state = State::EXIT;
143     cv.notify_all();
144   }
145   t.join();
146
147   EXPECT_EQ(1010, Widget::totalVal_);
148 }
149
150 TEST(ThreadLocal, BasicDestructor) {
151   Widget::totalVal_ = 0;
152   ThreadLocal<Widget> w;
153   std::thread([&w]() { w->val_ += 10; }).join();
154   EXPECT_EQ(10, Widget::totalVal_);
155 }
156
157 TEST(ThreadLocal, SimpleRepeatDestructor) {
158   Widget::totalVal_ = 0;
159   {
160     ThreadLocal<Widget> w;
161     w->val_ += 10;
162   }
163   {
164     ThreadLocal<Widget> w;
165     w->val_ += 10;
166   }
167   EXPECT_EQ(20, Widget::totalVal_);
168 }
169
170 TEST(ThreadLocal, InterleavedDestructors) {
171   Widget::totalVal_ = 0;
172   ThreadLocal<Widget>* w = NULL;
173   int wVersion = 0;
174   const int wVersionMax = 2;
175   int thIter = 0;
176   std::mutex lock;
177   auto th = std::thread([&]() {
178     int wVersionPrev = 0;
179     while (true) {
180       while (true) {
181         std::lock_guard<std::mutex> g(lock);
182         if (wVersion > wVersionMax) {
183           return;
184         }
185         if (wVersion > wVersionPrev) {
186           // We have a new version of w, so it should be initialized to zero
187           EXPECT_EQ((*w)->val_, 0);
188           break;
189         }
190       }
191       std::lock_guard<std::mutex> g(lock);
192       wVersionPrev = wVersion;
193       (*w)->val_ += 10;
194       ++thIter;
195     }
196   });
197   FOR_EACH_RANGE(i, 0, wVersionMax) {
198     int thIterPrev = 0;
199     {
200       std::lock_guard<std::mutex> g(lock);
201       thIterPrev = thIter;
202       delete w;
203       w = new ThreadLocal<Widget>();
204       ++wVersion;
205     }
206     while (true) {
207       std::lock_guard<std::mutex> g(lock);
208       if (thIter > thIterPrev) {
209         break;
210       }
211     }
212   }
213   {
214     std::lock_guard<std::mutex> g(lock);
215     wVersion = wVersionMax + 1;
216   }
217   th.join();
218   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
219 }
220
221 class SimpleThreadCachedInt {
222
223   class NewTag;
224   ThreadLocal<int,NewTag> val_;
225
226  public:
227   void add(int val) {
228     *val_ += val;
229   }
230
231   int read() {
232     int ret = 0;
233     for (const auto& i : val_.accessAllThreads()) {
234       ret += i;
235     }
236     return ret;
237   }
238 };
239
240 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
241   const int kNumThreads = 10;
242   SimpleThreadCachedInt stci;
243   std::atomic<bool> run(true);
244   std::atomic<int> totalAtomic(0);
245   std::vector<std::thread> threads;
246   for (int i = 0; i < kNumThreads; ++i) {
247     threads.push_back(std::thread([&,i]() {
248       stci.add(1);
249       totalAtomic.fetch_add(1);
250       while (run.load()) { usleep(100); }
251     }));
252   }
253   while (totalAtomic.load() != kNumThreads) { usleep(100); }
254   EXPECT_EQ(kNumThreads, stci.read());
255   run.store(false);
256   for (auto& t : threads) {
257     t.join();
258   }
259 }
260
261 TEST(ThreadLocal, resetNull) {
262   ThreadLocal<int> tl;
263   tl.reset(new int(4));
264   EXPECT_EQ(*tl.get(), 4);
265   tl.reset();
266   EXPECT_EQ(*tl.get(), 0);
267   tl.reset(new int(5));
268   EXPECT_EQ(*tl.get(), 5);
269 }
270
271 namespace {
272 struct Tag {};
273
274 struct Foo {
275   folly::ThreadLocal<int, Tag> tl;
276 };
277 }  // namespace
278
279 TEST(ThreadLocal, Movable1) {
280   Foo a;
281   Foo b;
282   EXPECT_TRUE(a.tl.get() != b.tl.get());
283
284   a = Foo();
285   b = Foo();
286   EXPECT_TRUE(a.tl.get() != b.tl.get());
287 }
288
289 TEST(ThreadLocal, Movable2) {
290   std::map<int, Foo> map;
291
292   map[42];
293   map[10];
294   map[23];
295   map[100];
296
297   std::set<void*> tls;
298   for (auto& m : map) {
299     tls.insert(m.second.tl.get());
300   }
301
302   // Make sure that we have 4 different instances of *tl
303   EXPECT_EQ(4, tls.size());
304 }
305
306 namespace {
307
308 constexpr size_t kFillObjectSize = 300;
309
310 std::atomic<uint64_t> gDestroyed;
311
312 /**
313  * Fill a chunk of memory with a unique-ish pattern that includes the thread id
314  * (so deleting one of these from another thread would cause a failure)
315  *
316  * Verify it explicitly and on destruction.
317  */
318 class FillObject {
319  public:
320   explicit FillObject(uint64_t idx) : idx_(idx) {
321     uint64_t v = val();
322     for (size_t i = 0; i < kFillObjectSize; ++i) {
323       data_[i] = v;
324     }
325   }
326
327   void check() {
328     uint64_t v = val();
329     for (size_t i = 0; i < kFillObjectSize; ++i) {
330       CHECK_EQ(v, data_[i]);
331     }
332   }
333
334   ~FillObject() {
335     ++gDestroyed;
336   }
337
338  private:
339   uint64_t val() const {
340     return (idx_ << 40) | uint64_t(pthread_self());
341   }
342
343   uint64_t idx_;
344   uint64_t data_[kFillObjectSize];
345 };
346
347 }  // namespace
348
349 #if FOLLY_HAVE_STD__THIS_THREAD__SLEEP_FOR
350 TEST(ThreadLocal, Stress) {
351   constexpr size_t numFillObjects = 250;
352   std::array<ThreadLocalPtr<FillObject>, numFillObjects> objects;
353
354   constexpr size_t numThreads = 32;
355   constexpr size_t numReps = 20;
356
357   std::vector<std::thread> threads;
358   threads.reserve(numThreads);
359
360   for (size_t i = 0; i < numThreads; ++i) {
361     threads.emplace_back([&objects] {
362       for (size_t rep = 0; rep < numReps; ++rep) {
363         for (size_t i = 0; i < objects.size(); ++i) {
364           objects[i].reset(new FillObject(rep * objects.size() + i));
365           std::this_thread::sleep_for(std::chrono::microseconds(100));
366         }
367         for (size_t i = 0; i < objects.size(); ++i) {
368           objects[i]->check();
369         }
370       }
371     });
372   }
373
374   for (auto& t : threads) {
375     t.join();
376   }
377
378   EXPECT_EQ(numFillObjects * numThreads * numReps, gDestroyed);
379 }
380 #endif
381
382 // Yes, threads and fork don't mix
383 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
384 // stupid or desperate enough to try, we shouldn't stand in your way.
385 namespace {
386 class HoldsOne {
387  public:
388   HoldsOne() : value_(1) { }
389   // Do an actual access to catch the buggy case where this == nullptr
390   int value() const { return value_; }
391  private:
392   int value_;
393 };
394
395 struct HoldsOneTag {};
396
397 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
398
399 int totalValue() {
400   int value = 0;
401   for (auto& p : ptr.accessAllThreads()) {
402     value += p.value();
403   }
404   return value;
405 }
406
407 }  // namespace
408
409 TEST(ThreadLocal, Fork) {
410   EXPECT_EQ(1, ptr->value());  // ensure created
411   EXPECT_EQ(1, totalValue());
412   // Spawn a new thread
413
414   std::mutex mutex;
415   bool started = false;
416   std::condition_variable startedCond;
417   bool stopped = false;
418   std::condition_variable stoppedCond;
419
420   std::thread t([&] () {
421     EXPECT_EQ(1, ptr->value());  // ensure created
422     {
423       std::unique_lock<std::mutex> lock(mutex);
424       started = true;
425       startedCond.notify_all();
426     }
427     {
428       std::unique_lock<std::mutex> lock(mutex);
429       while (!stopped) {
430         stoppedCond.wait(lock);
431       }
432     }
433   });
434
435   {
436     std::unique_lock<std::mutex> lock(mutex);
437     while (!started) {
438       startedCond.wait(lock);
439     }
440   }
441
442   EXPECT_EQ(2, totalValue());
443
444   pid_t pid = fork();
445   if (pid == 0) {
446     // in child
447     int v = totalValue();
448
449     // exit successfully if v == 1 (one thread)
450     // diagnostic error code otherwise :)
451     switch (v) {
452     case 1: _exit(0);
453     case 0: _exit(1);
454     }
455     _exit(2);
456   } else if (pid > 0) {
457     // in parent
458     int status;
459     EXPECT_EQ(pid, waitpid(pid, &status, 0));
460     EXPECT_TRUE(WIFEXITED(status));
461     EXPECT_EQ(0, WEXITSTATUS(status));
462   } else {
463     EXPECT_TRUE(false) << "fork failed";
464   }
465
466   EXPECT_EQ(2, totalValue());
467
468   {
469     std::unique_lock<std::mutex> lock(mutex);
470     stopped = true;
471     stoppedCond.notify_all();
472   }
473
474   t.join();
475
476   EXPECT_EQ(1, totalValue());
477 }
478
479 // Simple reference implementation using pthread_get_specific
480 template<typename T>
481 class PThreadGetSpecific {
482  public:
483   PThreadGetSpecific() : key_(0) {
484     pthread_key_create(&key_, OnThreadExit);
485   }
486
487   T* get() const {
488     return static_cast<T*>(pthread_getspecific(key_));
489   }
490
491   void reset(T* t) {
492     delete get();
493     pthread_setspecific(key_, t);
494   }
495   static void OnThreadExit(void* obj) {
496     delete static_cast<T*>(obj);
497   }
498  private:
499   pthread_key_t key_;
500 };
501
502 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
503
504 #define REG(var)                                                \
505   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
506     const int itersPerThread = iters / FLAGS_numThreads;        \
507     std::vector<std::thread> threads;                           \
508     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
509       threads.push_back(std::thread([&]() {                     \
510         var.reset(new int(0));                                  \
511         for (int i = 0; i < itersPerThread; ++i) {              \
512           ++(*var.get());                                       \
513         }                                                       \
514       }));                                                      \
515     }                                                           \
516     for (auto& t : threads) {                                   \
517       t.join();                                                 \
518     }                                                           \
519   }
520
521 ThreadLocalPtr<int> tlp;
522 REG(tlp);
523 PThreadGetSpecific<int> pthread_get_specific;
524 REG(pthread_get_specific);
525 boost::thread_specific_ptr<int> boost_tsp;
526 REG(boost_tsp);
527 BENCHMARK_DRAW_LINE();
528
529 int main(int argc, char** argv) {
530   testing::InitGoogleTest(&argc, argv);
531   google::ParseCommandLineFlags(&argc, &argv, true);
532   google::SetCommandLineOptionWithMode(
533     "bm_max_iters", "100000000", google::SET_FLAG_IF_DEFAULT
534   );
535   if (FLAGS_benchmark) {
536     folly::runBenchmarks();
537   }
538   return RUN_ALL_TESTS();
539 }
540
541 /*
542 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
543
544 Benchmark                               Iters   Total t    t/iter iter/sec
545 ------------------------------------------------------------------------------
546 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
547  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
548  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
549 ------------------------------------------------------------------------------
550 */