Fix ThreadCachedInt race condition
authorDan Schatzberg <dschatzberg@fb.com>
Thu, 29 Sep 2016 15:22:50 +0000 (08:22 -0700)
committerFacebook Github Bot <facebook-github-bot-bot@fb.com>
Thu, 29 Sep 2016 15:23:59 +0000 (08:23 -0700)
Summary:
Acquire a SharedMutex at ThreadExit to ensure that after unlinking the ThreadEntry from
the list, future accessAllThreads() won't miss a destroying thread.

This is quite a dangerous fix as it changes some lock ordering semantics. ThreadLocal
elements are now destroyed while holding a lock, so if the destruction function
acquires a different lock, ordering must be consistent with other
uses of accessAllThreads().

I've made accessAllThreads() an opt-in feature via a template parameter and changed
all existing uses. I've also fixed a few lock ordering issues that arose due to this
change.

Reviewed By: andriigrynenko

Differential Revision: D3931072

fbshipit-source-id: 4d464408713184080079698df453b95873bb1a6c

folly/Singleton.cpp
folly/ThreadCachedInt.h
folly/ThreadLocal.h
folly/detail/ThreadLocalDetail.cpp
folly/detail/ThreadLocalDetail.h
folly/test/ThreadCachedIntTest.cpp

index c1031944313c610dfbfc33e97062119295e89c40..4a690c72099d87b92433c49b20cc12a7c76bbcfd 100644 (file)
@@ -234,7 +234,7 @@ void SingletonVault::reenableInstances() {
 void SingletonVault::scheduleDestroyInstances() {
   // Add a dependency on folly::ThreadLocal to make sure all its static
   // singletons are initalized first.
-  threadlocal_detail::StaticMeta<void>::instance();
+  threadlocal_detail::StaticMeta<void, void>::instance();
 
   class SingletonVaultDestructor {
    public:
index fcc4a7ea0e2f18e256f7c79db215d170a5e2a1de..10ccbbc5479b022da1e574ba0202893be625b22c 100644 (file)
@@ -63,8 +63,11 @@ class ThreadCachedInt : boost::noncopyable {
   // Reads the current value plus all the cached increments.  Requires grabbing
   // a lock, so this is significantly slower than readFast().
   IntT readFull() const {
+    // This could race with thread destruction and so the access lock should be
+    // acquired before reading the current value
+    auto accessor = cache_.accessAllThreads();
     IntT ret = readFast();
-    for (const auto& cache : cache_.accessAllThreads()) {
+    for (const auto& cache : accessor) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
       }
@@ -82,8 +85,11 @@ class ThreadCachedInt : boost::noncopyable {
   // little off, however, but it should be much better than calling readFull()
   // and set(0) sequentially.
   IntT readFullAndReset() {
+    // This could race with thread destruction and so the access lock should be
+    // acquired before reading the current value
+    auto accessor = cache_.accessAllThreads();
     IntT ret = readFastAndReset();
-    for (auto& cache : cache_.accessAllThreads()) {
+    for (auto& cache : accessor) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
         cache.reset_.store(true, std::memory_order_release);
@@ -128,7 +134,8 @@ class ThreadCachedInt : boost::noncopyable {
  private:
   std::atomic<IntT> target_;
   std::atomic<uint32_t> cacheSize_;
-  ThreadLocalPtr<IntCache,Tag> cache_; // Must be last for dtor ordering
+  ThreadLocalPtr<IntCache, Tag, AccessModeStrict>
+      cache_; // Must be last for dtor ordering
 
   // This should only ever be modified by one thread
   struct IntCache {
index b47742072a65a9c837872edf602b8fe4989b9137..1e28daf01aa93b9c7d78e993d89158cda66d7476 100644 (file)
  * objects of a parent.  accessAllThreads() initializes an accessor which holds
  * a global lock *that blocks all creation and destruction of ThreadLocal
  * objects with the same Tag* and can be used as an iterable container.
+ * accessAllThreads() can race with destruction of thread-local elements. We
+ * provide a strict mode which is dangerous because it requires the access lock
+ * to be held while destroying thread-local elements which could cause
+ * deadlocks. We gate this mode behind the AccessModeStrict template parameter.
  *
  * Intended use is for frequent write, infrequent read data access patterns such
  * as counters.
 
 #pragma once
 
+#include <boost/iterator/iterator_facade.hpp>
 #include <folly/Likely.h>
 #include <folly/Portability.h>
 #include <folly/ScopeGuard.h>
-#include <boost/iterator/iterator_facade.hpp>
+#include <folly/SharedMutex.h>
 #include <type_traits>
 #include <utility>
 
@@ -48,15 +53,17 @@ enum class TLPDestructionMode {
   THIS_THREAD,
   ALL_THREADS
 };
+struct AccessModeStrict {};
 }  // namespace
 
 #include <folly/detail/ThreadLocalDetail.h>
 
 namespace folly {
 
-template<class T, class Tag> class ThreadLocalPtr;
+template <class T, class Tag, class AccessMode>
+class ThreadLocalPtr;
 
-template<class T, class Tag=void>
+template <class T, class Tag = void, class AccessMode = void>
 class ThreadLocal {
  public:
   constexpr ThreadLocal() : constructor_([]() {
@@ -89,7 +96,7 @@ class ThreadLocal {
     tlp_.reset(newPtr);
   }
 
-  typedef typename ThreadLocalPtr<T,Tag>::Accessor Accessor;
+  typedef typename ThreadLocalPtr<T, Tag, AccessMode>::Accessor Accessor;
   Accessor accessAllThreads() const {
     return tlp_.accessAllThreads();
   }
@@ -109,7 +116,7 @@ class ThreadLocal {
     return ptr;
   }
 
-  mutable ThreadLocalPtr<T,Tag> tlp_;
+  mutable ThreadLocalPtr<T, Tag, AccessMode> tlp_;
   std::function<T*()> constructor_;
 };
 
@@ -139,10 +146,11 @@ class ThreadLocal {
  *       with __declspec(thread)
  */
 
-template<class T, class Tag=void>
+template <class T, class Tag = void, class AccessMode = void>
 class ThreadLocalPtr {
  private:
-  typedef threadlocal_detail::StaticMeta<Tag> StaticMeta;
+  typedef threadlocal_detail::StaticMeta<Tag, AccessMode> StaticMeta;
+
  public:
   constexpr ThreadLocalPtr() : id_() {}
 
@@ -246,9 +254,10 @@ class ThreadLocalPtr {
   // Can be used as an iterable container.
   // Use accessAllThreads() to obtain one.
   class Accessor {
-    friend class ThreadLocalPtr<T,Tag>;
+    friend class ThreadLocalPtr<T, Tag, AccessMode>;
 
     threadlocal_detail::StaticMetaBase& meta_;
+    SharedMutex* accessAllThreadsLock_;
     std::mutex* lock_;
     uint32_t id_;
 
@@ -321,10 +330,12 @@ class ThreadLocalPtr {
     Accessor& operator=(const Accessor&) = delete;
 
     Accessor(Accessor&& other) noexcept
-      : meta_(other.meta_),
-        lock_(other.lock_),
-        id_(other.id_) {
+        : meta_(other.meta_),
+          accessAllThreadsLock_(other.accessAllThreadsLock_),
+          lock_(other.lock_),
+          id_(other.id_) {
       other.id_ = 0;
+      other.accessAllThreadsLock_ = nullptr;
       other.lock_ = nullptr;
     }
 
@@ -338,20 +349,23 @@ class ThreadLocalPtr {
       assert(&meta_ == &other.meta_);
       assert(lock_ == nullptr);
       using std::swap;
+      swap(accessAllThreadsLock_, other.accessAllThreadsLock_);
       swap(lock_, other.lock_);
       swap(id_, other.id_);
     }
 
     Accessor()
-      : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-        lock_(nullptr),
-        id_(0) {
-    }
+        : meta_(threadlocal_detail::StaticMeta<Tag, AccessMode>::instance()),
+          accessAllThreadsLock_(nullptr),
+          lock_(nullptr),
+          id_(0) {}
 
    private:
     explicit Accessor(uint32_t id)
-      : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-        lock_(&meta_.lock_) {
+        : meta_(threadlocal_detail::StaticMeta<Tag, AccessMode>::instance()),
+          accessAllThreadsLock_(&meta_.accessAllThreadsLock_),
+          lock_(&meta_.lock_) {
+      accessAllThreadsLock_->lock();
       lock_->lock();
       id_ = id;
     }
@@ -359,8 +373,11 @@ class ThreadLocalPtr {
     void release() {
       if (lock_) {
         lock_->unlock();
+        DCHECK(accessAllThreadsLock_ != nullptr);
+        accessAllThreadsLock_->unlock();
         id_ = 0;
         lock_ = nullptr;
+        accessAllThreadsLock_ = nullptr;
       }
     }
   };
index 7dcb22d4ccc0a44b00f352dae5dddfb46eb820d4..6be41ed485e3a51307dc8296230f2ce7e63b46cd 100644 (file)
@@ -20,8 +20,8 @@
 
 namespace folly { namespace threadlocal_detail {
 
-StaticMetaBase::StaticMetaBase(ThreadEntry* (*threadEntry)())
-    : nextId_(1), threadEntry_(threadEntry) {
+StaticMetaBase::StaticMetaBase(ThreadEntry* (*threadEntry)(), bool strict)
+    : nextId_(1), threadEntry_(threadEntry), strict_(strict) {
   head_.next = head_.prev = &head_;
   int ret = pthread_key_create(&pthreadKey_, &onThreadExit);
   checkPosixError(ret, "pthread_key_create failed");
@@ -45,20 +45,26 @@ void StaticMetaBase::onThreadExit(void* ptr) {
   };
 
   {
-    std::lock_guard<std::mutex> g(meta.lock_);
-    meta.erase(&(*threadEntry));
-    // No need to hold the lock any longer; the ThreadEntry is private to this
-    // thread now that it's been removed from meta.
-  }
-  // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal
-  // with the same Tag, so dispose() calls below may (re)create some of the
-  // elements or even increase elementsCapacity, thus multiple cleanup rounds
-  // may be required.
-  for (bool shouldRun = true; shouldRun;) {
-    shouldRun = false;
-    FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) {
-      if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) {
-        shouldRun = true;
+    SharedMutex::ReadHolder rlock;
+    if (meta.strict_) {
+      rlock = SharedMutex::ReadHolder(meta.accessAllThreadsLock_);
+    }
+    {
+      std::lock_guard<std::mutex> g(meta.lock_);
+      meta.erase(&(*threadEntry));
+      // No need to hold the lock any longer; the ThreadEntry is private to this
+      // thread now that it's been removed from meta.
+    }
+    // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal
+    // with the same Tag, so dispose() calls below may (re)create some of the
+    // elements or even increase elementsCapacity, thus multiple cleanup rounds
+    // may be required.
+    for (bool shouldRun = true; shouldRun;) {
+      shouldRun = false;
+      FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) {
+        if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) {
+          shouldRun = true;
+        }
       }
     }
   }
index 68676a23433e4d3d93c3f6f1722d08dcd8f1c27f..fcd1a90f57c5ea0ce76be84b94d6849b9d074f72 100644 (file)
@@ -254,7 +254,7 @@ struct StaticMetaBase {
     }
   };
 
-  explicit StaticMetaBase(ThreadEntry* (*threadEntry)());
+  StaticMetaBase(ThreadEntry* (*threadEntry)(), bool strict);
 
   ~StaticMetaBase() {
     LOG(FATAL) << "StaticMeta lives forever!";
@@ -296,9 +296,11 @@ struct StaticMetaBase {
   uint32_t nextId_;
   std::vector<uint32_t> freeIds_;
   std::mutex lock_;
+  SharedMutex accessAllThreadsLock_;
   pthread_key_t pthreadKey_;
   ThreadEntry head_;
   ThreadEntry* (*threadEntry_)();
+  bool strict_;
 };
 
 // Held in a singleton to track our global instances.
@@ -308,19 +310,23 @@ struct StaticMetaBase {
 // Creating and destroying ThreadLocalPtr objects, as well as thread exit
 // for threads that use ThreadLocalPtr objects collide on a lock inside
 // StaticMeta; you can specify multiple Tag types to break that lock.
-template <class Tag>
+template <class Tag, class AccessMode>
 struct StaticMeta : StaticMetaBase {
-  StaticMeta() : StaticMetaBase(&StaticMeta::getThreadEntrySlow) {
+  StaticMeta()
+      : StaticMetaBase(
+            &StaticMeta::getThreadEntrySlow,
+            std::is_same<AccessMode, AccessModeStrict>::value) {
     registerAtFork(
         /*prepare*/ &StaticMeta::preFork,
         /*parent*/ &StaticMeta::onForkParent,
         /*child*/ &StaticMeta::onForkChild);
   }
 
-  static StaticMeta<Tag>& instance() {
+  static StaticMeta<Tag, AccessMode>& instance() {
     // Leak it on exit, there's only one per process and we don't have to
     // worry about synchronization with exiting threads.
-    static auto instance = detail::createGlobal<StaticMeta<Tag>, void>();
+    static auto instance =
+        detail::createGlobal<StaticMeta<Tag, AccessMode>, void>();
     return *instance;
   }
 
index 74f6c4dc8d316d12ee15ae2f2e9548ac4c65c36c..97ce60c207587e53a4412f80c8b55f208696d898 100644 (file)
@@ -17,6 +17,7 @@
 #include <folly/ThreadCachedInt.h>
 
 #include <atomic>
+#include <condition_variable>
 #include <thread>
 
 #include <glog/logging.h>
 
 using namespace folly;
 
+using std::unique_ptr;
+using std::vector;
+
+using Counter = ThreadCachedInt<int64_t>;
+
+class ThreadCachedIntTest : public testing::Test {
+ public:
+  uint32_t GetDeadThreadsTotal(const Counter& counter) {
+    return counter.readFast();
+  }
+};
+
+// Multithreaded tests.  Creates a specified number of threads each of
+// which iterates a different amount and dies.
+
+namespace {
+// Set cacheSize to be large so cached data moves to target_ only when
+// thread dies.
+Counter g_counter_for_mt_slow(0, UINT32_MAX);
+Counter g_counter_for_mt_fast(0, UINT32_MAX);
+
+// Used to sync between threads.  The value of this variable is the
+// maximum iteration index upto which Runner() is allowed to go.
+uint32_t g_sync_for_mt(0);
+std::condition_variable cv;
+std::mutex cv_m;
+
+// Performs the specified number of iterations.  Within each
+// iteration, it increments counter 10 times.  At the beginning of
+// each iteration it checks g_sync_for_mt to see if it can proceed,
+// otherwise goes into a loop sleeping and rechecking.
+void Runner(Counter* counter, uint32_t iterations) {
+  for (uint32_t i = 0; i < iterations; ++i) {
+    std::unique_lock<std::mutex> lk(cv_m);
+    cv.wait(lk, [i] { return i < g_sync_for_mt; });
+    for (uint32_t j = 0; j < 10; ++j) {
+      counter->increment(1);
+    }
+  }
+}
+}
+
+// Slow test with fewer threads where there are more busy waits and
+// many calls to readFull().  This attempts to test as many of the
+// code paths in Counter as possible to ensure that counter values are
+// properly passed from thread local state, both at calls to
+// readFull() and at thread death.
+TEST_F(ThreadCachedIntTest, MultithreadedSlow) {
+  static constexpr uint32_t kNumThreads = 20;
+  g_sync_for_mt = 0;
+  vector<unique_ptr<std::thread>> threads(kNumThreads);
+  // Creates kNumThreads threads.  Each thread performs a different
+  // number of iterations in Runner() - threads[0] performs 1
+  // iteration, threads[1] performs 2 iterations, threads[2] performs
+  // 3 iterations, and so on.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_slow, i + 1));
+  }
+  // Variable to grab current counter value.
+  int32_t counter_value;
+  // The expected value of the counter.
+  int32_t total = 0;
+  // The expected value of GetDeadThreadsTotal().
+  int32_t dead_total = 0;
+  // Each iteration of the following thread allows one additional
+  // iteration of the threads.  Given that the threads perform
+  // different number of iterations from 1 through kNumThreads, one
+  // thread will complete in each of the iterations of the loop below.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    // Allow upto iteration i on all threads.
+    {
+      std::lock_guard<std::mutex> lk(cv_m);
+      g_sync_for_mt = i + 1;
+    }
+    cv.notify_all();
+    total += (kNumThreads - i) * 10;
+    // Loop until the counter reaches its expected value.
+    do {
+      counter_value = g_counter_for_mt_slow.readFull();
+    } while (counter_value < total);
+    // All threads have done what they can until iteration i, now make
+    // sure they don't go further by checking 10 more times in the
+    // following loop.
+    for (uint32_t j = 0; j < 10; ++j) {
+      counter_value = g_counter_for_mt_slow.readFull();
+      EXPECT_EQ(total, counter_value);
+    }
+    dead_total += (i + 1) * 10;
+    EXPECT_GE(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
+  }
+  // All threads are done.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i]->join();
+  }
+  counter_value = g_counter_for_mt_slow.readFull();
+  EXPECT_EQ(total, counter_value);
+  EXPECT_EQ(total, dead_total);
+  EXPECT_EQ(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
+}
+
+// Fast test with lots of threads and only one call to readFull()
+// at the end.
+TEST_F(ThreadCachedIntTest, MultithreadedFast) {
+  static constexpr uint32_t kNumThreads = 1000;
+  g_sync_for_mt = 0;
+  vector<unique_ptr<std::thread>> threads(kNumThreads);
+  // Creates kNumThreads threads.  Each thread performs a different
+  // number of iterations in Runner() - threads[0] performs 1
+  // iteration, threads[1] performs 2 iterations, threads[2] performs
+  // 3 iterations, and so on.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_fast, i + 1));
+  }
+  // Let the threads run to completion.
+  {
+    std::lock_guard<std::mutex> lk(cv_m);
+    g_sync_for_mt = kNumThreads;
+  }
+  cv.notify_all();
+  // The expected value of the counter.
+  uint32_t total = 0;
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    total += (kNumThreads - i) * 10;
+  }
+  // Wait for all threads to complete.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i]->join();
+  }
+  int32_t counter_value = g_counter_for_mt_fast.readFull();
+  EXPECT_EQ(total, counter_value);
+  EXPECT_EQ(total, GetDeadThreadsTotal(g_counter_for_mt_fast));
+}
+
 TEST(ThreadCachedInt, SingleThreadedNotCached) {
   ThreadCachedInt<int64_t> val(0, 0);
   EXPECT_EQ(0, val.readFast());