Make ThreadLocalPtr behave sanely around fork()
[folly.git] / folly / test / ThreadLocalTest.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/ThreadLocal.h"
18
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 #include <map>
22 #include <unordered_map>
23 #include <set>
24 #include <atomic>
25 #include <mutex>
26 #include <condition_variable>
27 #include <thread>
28 #include <unistd.h>
29 #include <boost/thread/tss.hpp>
30 #include <gtest/gtest.h>
31 #include <gflags/gflags.h>
32 #include <glog/logging.h>
33 #include "folly/Benchmark.h"
34
35 using namespace folly;
36
37 struct Widget {
38   static int totalVal_;
39   int val_;
40   ~Widget() {
41     totalVal_ += val_;
42   }
43
44   static void customDeleter(Widget* w, TLPDestructionMode mode) {
45     totalVal_ += (mode == TLPDestructionMode::ALL_THREADS) * 1000;
46     delete w;
47   }
48 };
49 int Widget::totalVal_ = 0;
50
51 TEST(ThreadLocalPtr, BasicDestructor) {
52   Widget::totalVal_ = 0;
53   ThreadLocalPtr<Widget> w;
54   std::thread([&w]() {
55       w.reset(new Widget());
56       w.get()->val_ += 10;
57     }).join();
58   EXPECT_EQ(10, Widget::totalVal_);
59 }
60
61 TEST(ThreadLocalPtr, CustomDeleter1) {
62   Widget::totalVal_ = 0;
63   {
64     ThreadLocalPtr<Widget> w;
65     std::thread([&w]() {
66         w.reset(new Widget(), Widget::customDeleter);
67         w.get()->val_ += 10;
68       }).join();
69     EXPECT_EQ(10, Widget::totalVal_);
70   }
71   EXPECT_EQ(10, Widget::totalVal_);
72 }
73
74 TEST(ThreadLocalPtr, resetNull) {
75   ThreadLocalPtr<int> tl;
76   EXPECT_FALSE(tl);
77   tl.reset(new int(4));
78   EXPECT_TRUE(static_cast<bool>(tl));
79   EXPECT_EQ(*tl.get(), 4);
80   tl.reset();
81   EXPECT_FALSE(tl);
82 }
83
84 // Test deleting the ThreadLocalPtr object
85 TEST(ThreadLocalPtr, CustomDeleter2) {
86   Widget::totalVal_ = 0;
87   std::thread t;
88   std::mutex mutex;
89   std::condition_variable cv;
90   enum class State {
91     START,
92     DONE,
93     EXIT
94   };
95   State state = State::START;
96   {
97     ThreadLocalPtr<Widget> w;
98     t = std::thread([&]() {
99         w.reset(new Widget(), Widget::customDeleter);
100         w.get()->val_ += 10;
101
102         // Notify main thread that we're done
103         {
104           std::unique_lock<std::mutex> lock(mutex);
105           state = State::DONE;
106           cv.notify_all();
107         }
108
109         // Wait for main thread to allow us to exit
110         {
111           std::unique_lock<std::mutex> lock(mutex);
112           while (state != State::EXIT) {
113             cv.wait(lock);
114           }
115         }
116     });
117
118     // Wait for main thread to start (and set w.get()->val_)
119     {
120       std::unique_lock<std::mutex> lock(mutex);
121       while (state != State::DONE) {
122         cv.wait(lock);
123       }
124     }
125
126     // Thread started but hasn't exited yet
127     EXPECT_EQ(0, Widget::totalVal_);
128
129     // Destroy ThreadLocalPtr<Widget> (by letting it go out of scope)
130   }
131
132   EXPECT_EQ(1010, Widget::totalVal_);
133
134   // Allow thread to exit
135   {
136     std::unique_lock<std::mutex> lock(mutex);
137     state = State::EXIT;
138     cv.notify_all();
139   }
140   t.join();
141
142   EXPECT_EQ(1010, Widget::totalVal_);
143 }
144
145 TEST(ThreadLocal, BasicDestructor) {
146   Widget::totalVal_ = 0;
147   ThreadLocal<Widget> w;
148   std::thread([&w]() { w->val_ += 10; }).join();
149   EXPECT_EQ(10, Widget::totalVal_);
150 }
151
152 TEST(ThreadLocal, SimpleRepeatDestructor) {
153   Widget::totalVal_ = 0;
154   {
155     ThreadLocal<Widget> w;
156     w->val_ += 10;
157   }
158   {
159     ThreadLocal<Widget> w;
160     w->val_ += 10;
161   }
162   EXPECT_EQ(20, Widget::totalVal_);
163 }
164
165 TEST(ThreadLocal, InterleavedDestructors) {
166   Widget::totalVal_ = 0;
167   ThreadLocal<Widget>* w = NULL;
168   int wVersion = 0;
169   const int wVersionMax = 2;
170   int thIter = 0;
171   std::mutex lock;
172   auto th = std::thread([&]() {
173     int wVersionPrev = 0;
174     while (true) {
175       while (true) {
176         std::lock_guard<std::mutex> g(lock);
177         if (wVersion > wVersionMax) {
178           return;
179         }
180         if (wVersion > wVersionPrev) {
181           // We have a new version of w, so it should be initialized to zero
182           EXPECT_EQ((*w)->val_, 0);
183           break;
184         }
185       }
186       std::lock_guard<std::mutex> g(lock);
187       wVersionPrev = wVersion;
188       (*w)->val_ += 10;
189       ++thIter;
190     }
191   });
192   FOR_EACH_RANGE(i, 0, wVersionMax) {
193     int thIterPrev = 0;
194     {
195       std::lock_guard<std::mutex> g(lock);
196       thIterPrev = thIter;
197       delete w;
198       w = new ThreadLocal<Widget>();
199       ++wVersion;
200     }
201     while (true) {
202       std::lock_guard<std::mutex> g(lock);
203       if (thIter > thIterPrev) {
204         break;
205       }
206     }
207   }
208   {
209     std::lock_guard<std::mutex> g(lock);
210     wVersion = wVersionMax + 1;
211   }
212   th.join();
213   EXPECT_EQ(wVersionMax * 10, Widget::totalVal_);
214 }
215
216 class SimpleThreadCachedInt {
217
218   class NewTag;
219   ThreadLocal<int,NewTag> val_;
220
221  public:
222   void add(int val) {
223     *val_ += val;
224   }
225
226   int read() {
227     int ret = 0;
228     for (const auto& i : val_.accessAllThreads()) {
229       ret += i;
230     }
231     return ret;
232   }
233 };
234
235 TEST(ThreadLocalPtr, AccessAllThreadsCounter) {
236   const int kNumThreads = 10;
237   SimpleThreadCachedInt stci;
238   std::atomic<bool> run(true);
239   std::atomic<int> totalAtomic(0);
240   std::vector<std::thread> threads;
241   for (int i = 0; i < kNumThreads; ++i) {
242     threads.push_back(std::thread([&,i]() {
243       stci.add(1);
244       totalAtomic.fetch_add(1);
245       while (run.load()) { usleep(100); }
246     }));
247   }
248   while (totalAtomic.load() != kNumThreads) { usleep(100); }
249   EXPECT_EQ(kNumThreads, stci.read());
250   run.store(false);
251   for (auto& t : threads) {
252     t.join();
253   }
254 }
255
256 TEST(ThreadLocal, resetNull) {
257   ThreadLocal<int> tl;
258   tl.reset(new int(4));
259   EXPECT_EQ(*tl.get(), 4);
260   tl.reset();
261   EXPECT_EQ(*tl.get(), 0);
262   tl.reset(new int(5));
263   EXPECT_EQ(*tl.get(), 5);
264 }
265
266 namespace {
267 struct Tag {};
268
269 struct Foo {
270   folly::ThreadLocal<int, Tag> tl;
271 };
272 }  // namespace
273
274 TEST(ThreadLocal, Movable1) {
275   Foo a;
276   Foo b;
277   EXPECT_TRUE(a.tl.get() != b.tl.get());
278
279   a = Foo();
280   b = Foo();
281   EXPECT_TRUE(a.tl.get() != b.tl.get());
282 }
283
284 TEST(ThreadLocal, Movable2) {
285   std::map<int, Foo> map;
286
287   map[42];
288   map[10];
289   map[23];
290   map[100];
291
292   std::set<void*> tls;
293   for (auto& m : map) {
294     tls.insert(m.second.tl.get());
295   }
296
297   // Make sure that we have 4 different instances of *tl
298   EXPECT_EQ(4, tls.size());
299 }
300
301 // Yes, threads and fork don't mix
302 // (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
303 // stupid or desperate enough to try, we shouldn't stand in your way.
304 namespace {
305 class HoldsOne {
306  public:
307   HoldsOne() : value_(1) { }
308   // Do an actual access to catch the buggy case where this == nullptr
309   int value() const { return value_; }
310  private:
311   int value_;
312 };
313
314 struct HoldsOneTag {};
315
316 ThreadLocal<HoldsOne, HoldsOneTag> ptr;
317
318 int totalValue() {
319   int value = 0;
320   for (auto& p : ptr.accessAllThreads()) {
321     value += p.value();
322   }
323   return value;
324 }
325
326 }  // namespace
327
328 TEST(ThreadLocal, Fork) {
329   EXPECT_EQ(1, ptr->value());  // ensure created
330   EXPECT_EQ(1, totalValue());
331   // Spawn a new thread
332
333   std::mutex mutex;
334   bool started = false;
335   std::condition_variable startedCond;
336   bool stopped = false;
337   std::condition_variable stoppedCond;
338
339   std::thread t([&] () {
340     EXPECT_EQ(1, ptr->value());  // ensure created
341     {
342       std::unique_lock<std::mutex> lock(mutex);
343       started = true;
344       startedCond.notify_all();
345     }
346     {
347       std::unique_lock<std::mutex> lock(mutex);
348       while (!stopped) {
349         stoppedCond.wait(lock);
350       }
351     }
352   });
353
354   {
355     std::unique_lock<std::mutex> lock(mutex);
356     while (!started) {
357       startedCond.wait(lock);
358     }
359   }
360
361   EXPECT_EQ(2, totalValue());
362
363   pid_t pid = fork();
364   if (pid == 0) {
365     // in child
366     int v = totalValue();
367
368     // exit successfully if v == 1 (one thread)
369     // diagnostic error code otherwise :)
370     switch (v) {
371     case 1: _exit(0);
372     case 0: _exit(1);
373     }
374     _exit(2);
375   } else if (pid > 0) {
376     // in parent
377     int status;
378     EXPECT_EQ(pid, waitpid(pid, &status, 0));
379     EXPECT_TRUE(WIFEXITED(status));
380     EXPECT_EQ(0, WEXITSTATUS(status));
381   } else {
382     EXPECT_TRUE(false) << "fork failed";
383   }
384
385   EXPECT_EQ(2, totalValue());
386
387   {
388     std::unique_lock<std::mutex> lock(mutex);
389     stopped = true;
390     stoppedCond.notify_all();
391   }
392
393   t.join();
394
395   EXPECT_EQ(1, totalValue());
396 }
397
398 // Simple reference implementation using pthread_get_specific
399 template<typename T>
400 class PThreadGetSpecific {
401  public:
402   PThreadGetSpecific() : key_(0) {
403     pthread_key_create(&key_, OnThreadExit);
404   }
405
406   T* get() const {
407     return static_cast<T*>(pthread_getspecific(key_));
408   }
409
410   void reset(T* t) {
411     delete get();
412     pthread_setspecific(key_, t);
413   }
414   static void OnThreadExit(void* obj) {
415     delete static_cast<T*>(obj);
416   }
417  private:
418   pthread_key_t key_;
419 };
420
421 DEFINE_int32(numThreads, 8, "Number simultaneous threads for benchmarks.");
422
423 #define REG(var)                                                \
424   BENCHMARK(FB_CONCATENATE(BM_mt_, var), iters) {               \
425     const int itersPerThread = iters / FLAGS_numThreads;        \
426     std::vector<std::thread> threads;                           \
427     for (int i = 0; i < FLAGS_numThreads; ++i) {                \
428       threads.push_back(std::thread([&]() {                     \
429         var.reset(new int(0));                                  \
430         for (int i = 0; i < itersPerThread; ++i) {              \
431           ++(*var.get());                                       \
432         }                                                       \
433       }));                                                      \
434     }                                                           \
435     for (auto& t : threads) {                                   \
436       t.join();                                                 \
437     }                                                           \
438   }
439
440 ThreadLocalPtr<int> tlp;
441 REG(tlp);
442 PThreadGetSpecific<int> pthread_get_specific;
443 REG(pthread_get_specific);
444 boost::thread_specific_ptr<int> boost_tsp;
445 REG(boost_tsp);
446 BENCHMARK_DRAW_LINE();
447
448 int main(int argc, char** argv) {
449   testing::InitGoogleTest(&argc, argv);
450   google::ParseCommandLineFlags(&argc, &argv, true);
451   google::SetCommandLineOptionWithMode(
452     "bm_max_iters", "100000000", google::SET_FLAG_IF_DEFAULT
453   );
454   if (FLAGS_benchmark) {
455     folly::runBenchmarks();
456   }
457   return RUN_ALL_TESTS();
458 }
459
460 /*
461 Ran with 24 threads on dual 12-core Xeon(R) X5650 @ 2.67GHz with 12-MB caches
462
463 Benchmark                               Iters   Total t    t/iter iter/sec
464 ------------------------------------------------------------------------------
465 *       BM_mt_tlp                   100000000  39.88 ms  398.8 ps  2.335 G
466  +5.91% BM_mt_pthread_get_specific  100000000  42.23 ms  422.3 ps  2.205 G
467  + 295% BM_mt_boost_tsp             100000000  157.8 ms  1.578 ns  604.5 M
468 ------------------------------------------------------------------------------
469 */