Make ThreadLocalPtr behave sanely around fork()
authorTudor Bosman <tudorb@fb.com>
Thu, 1 Aug 2013 21:21:52 +0000 (14:21 -0700)
committerSara Golemon <sgolemon@fb.com>
Wed, 28 Aug 2013 21:30:11 +0000 (14:30 -0700)
Summary:
Threads and fork still don't mix, but we shouldn't help you shoot yourself in
the foot if you decide to do it.

Test Plan: test added

Reviewed By: mshneer@fb.com

FB internal diff: D911224

folly/detail/ThreadLocalDetail.h
folly/test/ThreadLocalTest.cpp

index 6891d3e8d0f10cb53871311570ce390aa459c3a0..ad53b9c9e671490260f858a915a42697901e2f6f 100644 (file)
@@ -29,6 +29,7 @@
 
 #include <glog/logging.h>
 
+#include "folly/Exception.h"
 #include "folly/Foreach.h"
 #include "folly/Malloc.h"
 
@@ -146,9 +147,9 @@ struct StaticMeta {
   static StaticMeta<Tag>& instance() {
     // Leak it on exit, there's only one per process and we don't have to
     // worry about synchronization with exiting threads.
-    static bool constructed = (inst = new StaticMeta<Tag>());
+    static bool constructed = (inst_ = new StaticMeta<Tag>());
     (void)constructed; // suppress unused warning
-    return *inst;
+    return *inst_;
   }
 
   int nextId_;
@@ -171,33 +172,36 @@ struct StaticMeta {
   }
 
   static __thread ThreadEntry threadEntry_;
-  static StaticMeta<Tag>* inst;
+  static StaticMeta<Tag>* inst_;
 
   StaticMeta() : nextId_(1) {
     head_.next = head_.prev = &head_;
     int ret = pthread_key_create(&pthreadKey_, &onThreadExit);
-    if (ret != 0) {
-      std::string msg;
-      switch (ret) {
-        case EAGAIN:
-          char buf[100];
-          snprintf(buf, sizeof(buf), "PTHREAD_KEYS_MAX (%d) is exceeded",
-                   PTHREAD_KEYS_MAX);
-          msg = buf;
-          break;
-        case ENOMEM:
-          msg = "Out-of-memory";
-          break;
-        default:
-          msg = "(unknown error)";
-      }
-      throw std::runtime_error("pthread_key_create failed: " + msg);
-    }
+    checkPosixError(ret, "pthread_key_create failed");
+
+    ret = pthread_atfork(/*prepare*/ &StaticMeta::preFork,
+                         /*parent*/ &StaticMeta::onForkParent,
+                         /*child*/ &StaticMeta::onForkChild);
+    checkPosixError(ret, "pthread_atfork failed");
   }
   ~StaticMeta() {
     LOG(FATAL) << "StaticMeta lives forever!";
   }
 
+  static void preFork(void) {
+    instance().lock_.lock();  // Make sure it's created
+  }
+
+  static void onForkParent(void) {
+    inst_->lock_.unlock();
+  }
+
+  static void onForkChild(void) {
+    inst_->head_.next = inst_->head_.prev = &inst_->head_;
+    inst_->push_back(&threadEntry_);  // only the current thread survives
+    inst_->lock_.unlock();
+  }
+
   static void onThreadExit(void* ptr) {
     auto & meta = instance();
     DCHECK_EQ(ptr, &meta);
@@ -328,7 +332,7 @@ struct StaticMeta {
 };
 
 template <class Tag> __thread ThreadEntry StaticMeta<Tag>::threadEntry_ = {0};
-template <class Tag> StaticMeta<Tag>* StaticMeta<Tag>::inst = nullptr;
+template <class Tag> StaticMeta<Tag>* StaticMeta<Tag>::inst_ = nullptr;
 
 }  // namespace threadlocal_detail
 }  // namespace folly
index 04f4ebc4d61d14c71a635c36baf5105f687afd38..1c948257f35e360a593d888dafc974fc94fcbc03 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "folly/ThreadLocal.h"
 
+#include <sys/types.h>
+#include <sys/wait.h>
 #include <map>
 #include <unordered_map>
 #include <set>
@@ -23,6 +25,7 @@
 #include <mutex>
 #include <condition_variable>
 #include <thread>
+#include <unistd.h>
 #include <boost/thread/tss.hpp>
 #include <gtest/gtest.h>
 #include <gflags/gflags.h>
@@ -295,6 +298,103 @@ TEST(ThreadLocal, Movable2) {
   EXPECT_EQ(4, tls.size());
 }
 
+// Yes, threads and fork don't mix
+// (http://cppwisdom.quora.com/Why-threads-and-fork-dont-mix) but if you're
+// stupid or desperate enough to try, we shouldn't stand in your way.
+namespace {
+class HoldsOne {
+ public:
+  HoldsOne() : value_(1) { }
+  // Do an actual access to catch the buggy case where this == nullptr
+  int value() const { return value_; }
+ private:
+  int value_;
+};
+
+struct HoldsOneTag {};
+
+ThreadLocal<HoldsOne, HoldsOneTag> ptr;
+
+int totalValue() {
+  int value = 0;
+  for (auto& p : ptr.accessAllThreads()) {
+    value += p.value();
+  }
+  return value;
+}
+
+}  // namespace
+
+TEST(ThreadLocal, Fork) {
+  EXPECT_EQ(1, ptr->value());  // ensure created
+  EXPECT_EQ(1, totalValue());
+  // Spawn a new thread
+
+  std::mutex mutex;
+  bool started = false;
+  std::condition_variable startedCond;
+  bool stopped = false;
+  std::condition_variable stoppedCond;
+
+  std::thread t([&] () {
+    EXPECT_EQ(1, ptr->value());  // ensure created
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      started = true;
+      startedCond.notify_all();
+    }
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      while (!stopped) {
+        stoppedCond.wait(lock);
+      }
+    }
+  });
+
+  {
+    std::unique_lock<std::mutex> lock(mutex);
+    while (!started) {
+      startedCond.wait(lock);
+    }
+  }
+
+  EXPECT_EQ(2, totalValue());
+
+  pid_t pid = fork();
+  if (pid == 0) {
+    // in child
+    int v = totalValue();
+
+    // exit successfully if v == 1 (one thread)
+    // diagnostic error code otherwise :)
+    switch (v) {
+    case 1: _exit(0);
+    case 0: _exit(1);
+    }
+    _exit(2);
+  } else if (pid > 0) {
+    // in parent
+    int status;
+    EXPECT_EQ(pid, waitpid(pid, &status, 0));
+    EXPECT_TRUE(WIFEXITED(status));
+    EXPECT_EQ(0, WEXITSTATUS(status));
+  } else {
+    EXPECT_TRUE(false) << "fork failed";
+  }
+
+  EXPECT_EQ(2, totalValue());
+
+  {
+    std::unique_lock<std::mutex> lock(mutex);
+    stopped = true;
+    stoppedCond.notify_all();
+  }
+
+  t.join();
+
+  EXPECT_EQ(1, totalValue());
+}
+
 // Simple reference implementation using pthread_get_specific
 template<typename T>
 class PThreadGetSpecific {