Fix ThreadCachedInt race condition
authorDan Schatzberg <dschatzberg@fb.com>
Mon, 29 Aug 2016 23:42:44 +0000 (16:42 -0700)
committerFacebook Github Bot 4 <facebook-github-bot-4-bot@fb.com>
Mon, 29 Aug 2016 23:53:29 +0000 (16:53 -0700)
Summary:
Move ThreadLocal object destruction to occur under the lock to avoid races.
This causes a few cascading changes - the Tag lock needs to be a recursive_mutex so
constructing a new object while destroying another st. Also, forking requires
a new mutex to avoid deadlocking on accessing a recursive_mutex across a fork()

Reviewed By: andriigrynenko

Differential Revision: D3755446

fbshipit-source-id: bb4c4f29bab98d763490df29b460066f124303e0

folly/ThreadCachedInt.h
folly/ThreadLocal.h
folly/detail/ThreadLocalDetail.cpp
folly/detail/ThreadLocalDetail.h
folly/test/ThreadCachedIntTest.cpp

index fcc4a7ea0e2f18e256f7c79db215d170a5e2a1de..b84697216d3065477bd02d66d2d00c8dec191a9f 100644 (file)
@@ -63,8 +63,11 @@ class ThreadCachedInt : boost::noncopyable {
   // Reads the current value plus all the cached increments.  Requires grabbing
   // a lock, so this is significantly slower than readFast().
   IntT readFull() const {
+    // This could race with thread destruction and so the access lock should be
+    // acquired before reading the current value
+    auto accessor = cache_.accessAllThreads();
     IntT ret = readFast();
-    for (const auto& cache : cache_.accessAllThreads()) {
+    for (const auto& cache : accessor) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
       }
@@ -82,8 +85,11 @@ class ThreadCachedInt : boost::noncopyable {
   // little off, however, but it should be much better than calling readFull()
   // and set(0) sequentially.
   IntT readFullAndReset() {
+    // This could race with thread destruction and so the access lock should be
+    // acquired before reading the current value
+    auto accessor = cache_.accessAllThreads();
     IntT ret = readFastAndReset();
-    for (auto& cache : cache_.accessAllThreads()) {
+    for (auto& cache : accessor) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
         cache.reset_.store(true, std::memory_order_release);
index b47742072a65a9c837872edf602b8fe4989b9137..55b7a246e9e00d94dbb45d92f7165a9ad5e4c516 100644 (file)
 
 #pragma once
 
+#include <boost/iterator/iterator_facade.hpp>
 #include <folly/Likely.h>
 #include <folly/Portability.h>
 #include <folly/ScopeGuard.h>
-#include <boost/iterator/iterator_facade.hpp>
+#include <folly/SharedMutex.h>
 #include <type_traits>
 #include <utility>
 
@@ -249,6 +250,7 @@ class ThreadLocalPtr {
     friend class ThreadLocalPtr<T,Tag>;
 
     threadlocal_detail::StaticMetaBase& meta_;
+    SharedMutex* accessAllThreadsLock_;
     std::mutex* lock_;
     uint32_t id_;
 
@@ -321,10 +323,12 @@ class ThreadLocalPtr {
     Accessor& operator=(const Accessor&) = delete;
 
     Accessor(Accessor&& other) noexcept
-      : meta_(other.meta_),
-        lock_(other.lock_),
-        id_(other.id_) {
+        : meta_(other.meta_),
+          accessAllThreadsLock_(other.accessAllThreadsLock_),
+          lock_(other.lock_),
+          id_(other.id_) {
       other.id_ = 0;
+      other.accessAllThreadsLock_ = nullptr;
       other.lock_ = nullptr;
     }
 
@@ -338,20 +342,23 @@ class ThreadLocalPtr {
       assert(&meta_ == &other.meta_);
       assert(lock_ == nullptr);
       using std::swap;
+      swap(accessAllThreadsLock_, other.accessAllThreadsLock_);
       swap(lock_, other.lock_);
       swap(id_, other.id_);
     }
 
     Accessor()
-      : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-        lock_(nullptr),
-        id_(0) {
-    }
+        : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
+          accessAllThreadsLock_(nullptr),
+          lock_(nullptr),
+          id_(0) {}
 
    private:
     explicit Accessor(uint32_t id)
-      : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-        lock_(&meta_.lock_) {
+        : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
+          accessAllThreadsLock_(&meta_.accessAllThreadsLock_),
+          lock_(&meta_.lock_) {
+      accessAllThreadsLock_->lock();
       lock_->lock();
       id_ = id;
     }
@@ -359,8 +366,11 @@ class ThreadLocalPtr {
     void release() {
       if (lock_) {
         lock_->unlock();
+        DCHECK(accessAllThreadsLock_ != nullptr);
+        accessAllThreadsLock_->unlock();
         id_ = 0;
         lock_ = nullptr;
+        accessAllThreadsLock_ = nullptr;
       }
     }
   };
index 7dcb22d4ccc0a44b00f352dae5dddfb46eb820d4..dfd17259199ed49d77aa51b35eb1ca91517379d6 100644 (file)
@@ -45,20 +45,23 @@ void StaticMetaBase::onThreadExit(void* ptr) {
   };
 
   {
-    std::lock_guard<std::mutex> g(meta.lock_);
-    meta.erase(&(*threadEntry));
-    // No need to hold the lock any longer; the ThreadEntry is private to this
-    // thread now that it's been removed from meta.
-  }
-  // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal
-  // with the same Tag, so dispose() calls below may (re)create some of the
-  // elements or even increase elementsCapacity, thus multiple cleanup rounds
-  // may be required.
-  for (bool shouldRun = true; shouldRun;) {
-    shouldRun = false;
-    FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) {
-      if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) {
-        shouldRun = true;
+    SharedMutex::ReadHolder rlock(meta.accessAllThreadsLock_);
+    {
+      std::lock_guard<std::mutex> g(meta.lock_);
+      meta.erase(&(*threadEntry));
+      // No need to hold the lock any longer; the ThreadEntry is private to this
+      // thread now that it's been removed from meta.
+    }
+    // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal
+    // with the same Tag, so dispose() calls below may (re)create some of the
+    // elements or even increase elementsCapacity, thus multiple cleanup rounds
+    // may be required.
+    for (bool shouldRun = true; shouldRun;) {
+      shouldRun = false;
+      FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) {
+        if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) {
+          shouldRun = true;
+        }
       }
     }
   }
index 68676a23433e4d3d93c3f6f1722d08dcd8f1c27f..e05acab5db38fd1379fef283fba7671f910b0edb 100644 (file)
@@ -296,6 +296,7 @@ struct StaticMetaBase {
   uint32_t nextId_;
   std::vector<uint32_t> freeIds_;
   std::mutex lock_;
+  SharedMutex accessAllThreadsLock_;
   pthread_key_t pthreadKey_;
   ThreadEntry head_;
   ThreadEntry* (*threadEntry_)();
index 4f5c377e097933b34233b888a3a49738917425c5..a17c4cb6b5c3b006a74a745789b2b5de344be1a2 100644 (file)
@@ -17,6 +17,7 @@
 #include <folly/ThreadCachedInt.h>
 
 #include <atomic>
+#include <condition_variable>
 #include <thread>
 
 #include <glog/logging.h>
 
 using namespace folly;
 
+using std::unique_ptr;
+using std::vector;
+
+using Counter = ThreadCachedInt<int64_t>;
+
+class ThreadCachedIntTest : public testing::Test {
+ public:
+  uint32_t GetDeadThreadsTotal(const Counter& counter) {
+    return counter.readFast();
+  }
+};
+
+// Multithreaded tests.  Creates a specified number of threads each of
+// which iterates a different amount and dies.
+
+namespace {
+// Set cacheSize to be large so cached data moves to target_ only when
+// thread dies.
+Counter g_counter_for_mt_slow(0, UINT32_MAX);
+Counter g_counter_for_mt_fast(0, UINT32_MAX);
+
+// Used to sync between threads.  The value of this variable is the
+// maximum iteration index upto which Runner() is allowed to go.
+uint32_t g_sync_for_mt(0);
+std::condition_variable cv;
+std::mutex cv_m;
+
+// Performs the specified number of iterations.  Within each
+// iteration, it increments counter 10 times.  At the beginning of
+// each iteration it checks g_sync_for_mt to see if it can proceed,
+// otherwise goes into a loop sleeping and rechecking.
+void Runner(Counter* counter, uint32_t iterations) {
+  for (uint32_t i = 0; i < iterations; ++i) {
+    std::unique_lock<std::mutex> lk(cv_m);
+    cv.wait(lk, [i] { return i < g_sync_for_mt; });
+    for (uint32_t j = 0; j < 10; ++j) {
+      counter->increment(1);
+    }
+  }
+}
+}
+
+// Slow test with fewer threads where there are more busy waits and
+// many calls to readFull().  This attempts to test as many of the
+// code paths in Counter as possible to ensure that counter values are
+// properly passed from thread local state, both at calls to
+// readFull() and at thread death.
+TEST_F(ThreadCachedIntTest, MultithreadedSlow) {
+  static constexpr uint32_t kNumThreads = 20;
+  g_sync_for_mt = 0;
+  vector<unique_ptr<std::thread>> threads(kNumThreads);
+  // Creates kNumThreads threads.  Each thread performs a different
+  // number of iterations in Runner() - threads[0] performs 1
+  // iteration, threads[1] performs 2 iterations, threads[2] performs
+  // 3 iterations, and so on.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_slow, i + 1));
+  }
+  // Variable to grab current counter value.
+  int32_t counter_value;
+  // The expected value of the counter.
+  int32_t total = 0;
+  // The expected value of GetDeadThreadsTotal().
+  int32_t dead_total = 0;
+  // Each iteration of the following thread allows one additional
+  // iteration of the threads.  Given that the threads perform
+  // different number of iterations from 1 through kNumThreads, one
+  // thread will complete in each of the iterations of the loop below.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    // Allow upto iteration i on all threads.
+    {
+      std::lock_guard<std::mutex> lk(cv_m);
+      g_sync_for_mt = i + 1;
+    }
+    cv.notify_all();
+    total += (kNumThreads - i) * 10;
+    // Loop until the counter reaches its expected value.
+    do {
+      counter_value = g_counter_for_mt_slow.readFull();
+    } while (counter_value < total);
+    // All threads have done what they can until iteration i, now make
+    // sure they don't go further by checking 10 more times in the
+    // following loop.
+    for (uint32_t j = 0; j < 10; ++j) {
+      counter_value = g_counter_for_mt_slow.readFull();
+      EXPECT_EQ(total, counter_value);
+    }
+    dead_total += (i + 1) * 10;
+    EXPECT_GE(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
+  }
+  // All threads are done.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i]->join();
+  }
+  counter_value = g_counter_for_mt_slow.readFull();
+  EXPECT_EQ(total, counter_value);
+  EXPECT_EQ(total, dead_total);
+  EXPECT_EQ(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
+}
+
+// Fast test with lots of threads and only one call to readFull()
+// at the end.
+TEST_F(ThreadCachedIntTest, MultithreadedFast) {
+  static constexpr uint32_t kNumThreads = 1000;
+  g_sync_for_mt = 0;
+  vector<unique_ptr<std::thread>> threads(kNumThreads);
+  // Creates kNumThreads threads.  Each thread performs a different
+  // number of iterations in Runner() - threads[0] performs 1
+  // iteration, threads[1] performs 2 iterations, threads[2] performs
+  // 3 iterations, and so on.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_fast, i + 1));
+  }
+  // Let the threads run to completion.
+  {
+    std::lock_guard<std::mutex> lk(cv_m);
+    g_sync_for_mt = kNumThreads;
+  }
+  cv.notify_all();
+  // The expected value of the counter.
+  uint32_t total = 0;
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    total += (kNumThreads - i) * 10;
+  }
+  // Wait for all threads to complete.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i]->join();
+  }
+  int32_t counter_value = g_counter_for_mt_fast.readFull();
+  EXPECT_EQ(total, counter_value);
+  EXPECT_EQ(total, GetDeadThreadsTotal(g_counter_for_mt_fast));
+}
+
 TEST(ThreadCachedInt, SingleThreadedNotCached) {
   ThreadCachedInt<int64_t> val(0, 0);
   EXPECT_EQ(0, val.readFast());