From 6039ee419aa6cf043e010297b29348267ab5be0d Mon Sep 17 00:00:00 2001
From: Dan Schatzberg
Date: Thu, 29 Sep 2016 08:22:50 -0700
Subject: [PATCH] Fix ThreadCachedInt race condition

Summary: Acquire a SharedMutex at ThreadExit to ensure that after unlinking
the ThreadEntry from the list, future accessAllThreads() won't miss a
destroying thread.

This is quite a dangerous fix as it changes some lock ordering semantics.
ThreadLocal elements are now destroyed while holding a lock, so if the
destruction function acquires a different lock, ordering must be consistent
with other uses of accessAllThreads(). I've made accessAllThreads() an opt-in
feature via a template parameter and changed all existing uses. I've also
fixed a few lock ordering issues that arose due to this change.

Reviewed By: andriigrynenko

Differential Revision: D3931072

fbshipit-source-id: 4d464408713184080079698df453b95873bb1a6c
---
 folly/Singleton.cpp                |   2 +-
 folly/ThreadCachedInt.h            |  13 ++-
 folly/ThreadLocal.h                |  51 +++++++----
 folly/detail/ThreadLocalDetail.cpp |  38 ++++----
 folly/detail/ThreadLocalDetail.h   |  16 ++--
 folly/test/ThreadCachedIntTest.cpp | 134 +++++++++++++++++++++++++++++
 6 files changed, 212 insertions(+), 42 deletions(-)

diff --git a/folly/Singleton.cpp b/folly/Singleton.cpp
index c1031944..4a690c72 100644
--- a/folly/Singleton.cpp
+++ b/folly/Singleton.cpp
@@ -234,7 +234,7 @@ void SingletonVault::reenableInstances() {
 void SingletonVault::scheduleDestroyInstances() {
   // Add a dependency on folly::ThreadLocal to make sure all its static
   // singletons are initalized first.
-  threadlocal_detail::StaticMeta<void>::instance();
+  threadlocal_detail::StaticMeta<void, void>::instance();
 
   class SingletonVaultDestructor {
    public:
diff --git a/folly/ThreadCachedInt.h b/folly/ThreadCachedInt.h
index fcc4a7ea..10ccbbc5 100644
--- a/folly/ThreadCachedInt.h
+++ b/folly/ThreadCachedInt.h
@@ -63,8 +63,11 @@ class ThreadCachedInt : boost::noncopyable {
   // Reads the current value plus all the cached increments. Requires grabbing
   // a lock, so this is significantly slower than readFast().
   IntT readFull() const {
+    // This could race with thread destruction and so the access lock should be
+    // acquired before reading the current value
+    auto accessor = cache_.accessAllThreads();
     IntT ret = readFast();
-    for (const auto& cache : cache_.accessAllThreads()) {
+    for (const auto& cache : accessor) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
       }
@@ -82,8 +85,11 @@ class ThreadCachedInt : boost::noncopyable {
   // little off, however, but it should be much better than calling readFull()
   // and set(0) sequentially.
   IntT readFullAndReset() {
+    // This could race with thread destruction and so the access lock should be
+    // acquired before reading the current value
+    auto accessor = cache_.accessAllThreads();
     IntT ret = readFastAndReset();
-    for (auto& cache : cache_.accessAllThreads()) {
+    for (auto& cache : accessor) {
       if (!cache.reset_.load(std::memory_order_acquire)) {
         ret += cache.val_.load(std::memory_order_relaxed);
         cache.reset_.store(true, std::memory_order_release);
@@ -128,7 +134,8 @@ class ThreadCachedInt : boost::noncopyable {
  private:
   std::atomic<IntT> target_;
   std::atomic<uint32_t> cacheSize_;
-  ThreadLocalPtr<IntCache, Tag> cache_; // Must be last for dtor ordering
+  ThreadLocalPtr<IntCache, Tag, AccessModeStrict>
+      cache_; // Must be last for dtor ordering
 
   // This should only ever be modified by one thread
   struct IntCache {
diff --git a/folly/ThreadLocal.h b/folly/ThreadLocal.h
index b4774207..1e28daf0 100644
--- a/folly/ThreadLocal.h
+++ b/folly/ThreadLocal.h
@@ -23,6 +23,10 @@
  * objects of a parent. accessAllThreads() initializes an accessor which holds
  * a global lock *that blocks all creation and destruction of ThreadLocal
  * objects with the same Tag* and can be used as an iterable container.
+ * accessAllThreads() can race with destruction of thread-local elements. We
+ * provide a strict mode which is dangerous because it requires the access lock
+ * to be held while destroying thread-local elements which could cause
+ * deadlocks. We gate this mode behind the AccessModeStrict template parameter.
  *
  * Intended use is for frequent write, infrequent read data access patterns such
  * as counters.
@@ -36,10 +40,11 @@
 #pragma once
 
+#include <folly/SharedMutex.h>
 #include
 #include
 #include
-#include
 #include
 #include
@@ -48,15 +53,17 @@
 enum class TLPDestructionMode { THIS_THREAD, ALL_THREADS };
+struct AccessModeStrict {};
 } // namespace
 
 #include <folly/detail/ThreadLocalDetail.h>
 
 namespace folly {
 
-template <class T, class Tag> class ThreadLocalPtr;
+template <class T, class Tag, class AccessMode>
+class ThreadLocalPtr;
 
-template <class T, class Tag = void>
+template <class T, class Tag = void, class AccessMode = void>
 class ThreadLocal {
  public:
   constexpr ThreadLocal() : constructor_([]() {
@@ -89,7 +96,7 @@ class ThreadLocal {
     tlp_.reset(newPtr);
   }
 
-  typedef typename ThreadLocalPtr<T, Tag>::Accessor Accessor;
+  typedef typename ThreadLocalPtr<T, Tag, AccessMode>::Accessor Accessor;
   Accessor accessAllThreads() const {
     return tlp_.accessAllThreads();
   }
@@ -109,7 +116,7 @@ class ThreadLocal {
     return ptr;
   }
 
-  mutable ThreadLocalPtr<T, Tag> tlp_;
+  mutable ThreadLocalPtr<T, Tag, AccessMode> tlp_;
   std::function<T*()> constructor_;
 };
@@ -139,10 +146,11 @@ class ThreadLocal {
  * with __declspec(thread)
 */
 
-template <class T, class Tag = void>
+template <class T, class Tag = void, class AccessMode = void>
 class ThreadLocalPtr {
  private:
-  typedef threadlocal_detail::StaticMeta<Tag> StaticMeta;
+  typedef threadlocal_detail::StaticMeta<Tag, AccessMode> StaticMeta;
+
  public:
   constexpr ThreadLocalPtr() : id_() {}
@@ -246,9 +254,10 @@ class ThreadLocalPtr {
   // Can be used as an iterable container.
   // Use accessAllThreads() to obtain one.
   class Accessor {
-    friend class ThreadLocalPtr<T, Tag>;
+    friend class ThreadLocalPtr<T, Tag, AccessMode>;
 
     threadlocal_detail::StaticMetaBase& meta_;
+    SharedMutex* accessAllThreadsLock_;
     std::mutex* lock_;
     uint32_t id_;
 
@@ -321,10 +330,12 @@ class ThreadLocalPtr {
     Accessor& operator=(const Accessor&) = delete;
 
     Accessor(Accessor&& other) noexcept
-      : meta_(other.meta_),
-        lock_(other.lock_),
-        id_(other.id_) {
+        : meta_(other.meta_),
+          accessAllThreadsLock_(other.accessAllThreadsLock_),
+          lock_(other.lock_),
+          id_(other.id_) {
       other.id_ = 0;
+      other.accessAllThreadsLock_ = nullptr;
       other.lock_ = nullptr;
     }
 
@@ -338,20 +349,23 @@ class ThreadLocalPtr {
       assert(&meta_ == &other.meta_);
       assert(lock_ == nullptr);
       using std::swap;
+      swap(accessAllThreadsLock_, other.accessAllThreadsLock_);
       swap(lock_, other.lock_);
       swap(id_, other.id_);
     }
 
     Accessor()
-      : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-        lock_(nullptr),
-        id_(0) {
-    }
+        : meta_(threadlocal_detail::StaticMeta<Tag, AccessMode>::instance()),
+          accessAllThreadsLock_(nullptr),
+          lock_(nullptr),
+          id_(0) {}
 
    private:
     explicit Accessor(uint32_t id)
-      : meta_(threadlocal_detail::StaticMeta<Tag>::instance()),
-        lock_(&meta_.lock_) {
+        : meta_(threadlocal_detail::StaticMeta<Tag, AccessMode>::instance()),
+          accessAllThreadsLock_(&meta_.accessAllThreadsLock_),
+          lock_(&meta_.lock_) {
+      accessAllThreadsLock_->lock();
       lock_->lock();
       id_ = id;
     }
@@ -359,8 +373,11 @@ class ThreadLocalPtr {
     void release() {
       if (lock_) {
         lock_->unlock();
+        DCHECK(accessAllThreadsLock_ != nullptr);
+        accessAllThreadsLock_->unlock();
         id_ = 0;
         lock_ = nullptr;
+        accessAllThreadsLock_ = nullptr;
       }
     }
   };
diff --git a/folly/detail/ThreadLocalDetail.cpp b/folly/detail/ThreadLocalDetail.cpp
index 7dcb22d4..6be41ed4 100644
--- a/folly/detail/ThreadLocalDetail.cpp
+++ b/folly/detail/ThreadLocalDetail.cpp
@@ -20,8 +20,8 @@
 namespace folly {
 namespace threadlocal_detail {
 
-StaticMetaBase::StaticMetaBase(ThreadEntry* (*threadEntry)())
-    : nextId_(1), threadEntry_(threadEntry) {
+StaticMetaBase::StaticMetaBase(ThreadEntry* (*threadEntry)(), bool strict)
+    : nextId_(1), threadEntry_(threadEntry), strict_(strict) {
   head_.next = head_.prev = &head_;
   int ret = pthread_key_create(&pthreadKey_, &onThreadExit);
   checkPosixError(ret, "pthread_key_create failed");
@@ -45,20 +45,26 @@ void StaticMetaBase::onThreadExit(void* ptr) {
   };
 
   {
-    std::lock_guard<std::mutex> g(meta.lock_);
-    meta.erase(&(*threadEntry));
-    // No need to hold the lock any longer; the ThreadEntry is private to this
-    // thread now that it's been removed from meta.
-  }
-  // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal
-  // with the same Tag, so dispose() calls below may (re)create some of the
-  // elements or even increase elementsCapacity, thus multiple cleanup rounds
-  // may be required.
-  for (bool shouldRun = true; shouldRun;) {
-    shouldRun = false;
-    FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) {
-      if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) {
-        shouldRun = true;
+    SharedMutex::ReadHolder rlock;
+    if (meta.strict_) {
+      rlock = SharedMutex::ReadHolder(meta.accessAllThreadsLock_);
+    }
+    {
+      std::lock_guard<std::mutex> g(meta.lock_);
+      meta.erase(&(*threadEntry));
+      // No need to hold the lock any longer; the ThreadEntry is private to this
+      // thread now that it's been removed from meta.
+    }
+    // NOTE: User-provided deleter / object dtor itself may be using ThreadLocal
+    // with the same Tag, so dispose() calls below may (re)create some of the
+    // elements or even increase elementsCapacity, thus multiple cleanup rounds
+    // may be required.
+    for (bool shouldRun = true; shouldRun;) {
+      shouldRun = false;
+      FOR_EACH_RANGE (i, 0, threadEntry->elementsCapacity) {
+        if (threadEntry->elements[i].dispose(TLPDestructionMode::THIS_THREAD)) {
+          shouldRun = true;
+        }
       }
     }
   }
diff --git a/folly/detail/ThreadLocalDetail.h b/folly/detail/ThreadLocalDetail.h
index 68676a23..fcd1a90f 100644
--- a/folly/detail/ThreadLocalDetail.h
+++ b/folly/detail/ThreadLocalDetail.h
@@ -254,7 +254,7 @@ struct StaticMetaBase {
     }
   };
 
-  explicit StaticMetaBase(ThreadEntry* (*threadEntry)());
+  StaticMetaBase(ThreadEntry* (*threadEntry)(), bool strict);
 
   ~StaticMetaBase() {
     LOG(FATAL) << "StaticMeta lives forever!";
@@ -296,9 +296,11 @@ struct StaticMetaBase {
   uint32_t nextId_;
   std::vector<uint32_t> freeIds_;
   std::mutex lock_;
+  SharedMutex accessAllThreadsLock_;
   pthread_key_t pthreadKey_;
   ThreadEntry head_;
   ThreadEntry* (*threadEntry_)();
+  bool strict_;
 };
 
 // Held in a singleton to track our global instances.
@@ -308,19 +310,23 @@ struct StaticMetaBase {
 // Creating and destroying ThreadLocalPtr objects, as well as thread exit
 // for threads that use ThreadLocalPtr objects collide on a lock inside
 // StaticMeta; you can specify multiple Tag types to break that lock.
-template <class Tag>
+template <class Tag, class AccessMode>
 struct StaticMeta : StaticMetaBase {
-  StaticMeta() : StaticMetaBase(&StaticMeta::getThreadEntrySlow) {
+  StaticMeta()
+      : StaticMetaBase(
+            &StaticMeta::getThreadEntrySlow,
+            std::is_same<AccessMode, AccessModeStrict>::value) {
     registerAtFork(
         /*prepare*/ &StaticMeta::preFork,
         /*parent*/ &StaticMeta::onForkParent,
         /*child*/ &StaticMeta::onForkChild);
   }
 
-  static StaticMeta<Tag>& instance() {
+  static StaticMeta<Tag, AccessMode>& instance() {
     // Leak it on exit, there's only one per process and we don't have to
     // worry about synchronization with exiting threads.
-    static auto instance = detail::createGlobal<StaticMeta<Tag>, void>();
+    static auto instance =
+        detail::createGlobal<StaticMeta<Tag, AccessMode>, void>();
     return *instance;
   }
diff --git a/folly/test/ThreadCachedIntTest.cpp b/folly/test/ThreadCachedIntTest.cpp
index 74f6c4dc..97ce60c2 100644
--- a/folly/test/ThreadCachedIntTest.cpp
+++ b/folly/test/ThreadCachedIntTest.cpp
@@ -17,6 +17,7 @@
 #include <folly/ThreadCachedInt.h>
 
 #include <atomic>
+#include <condition_variable>
 #include <thread>
 
@@ -28,6 +29,139 @@
 using namespace folly;
 
+using std::unique_ptr;
+using std::vector;
+
+using Counter = ThreadCachedInt<int32_t>;
+
+class ThreadCachedIntTest : public testing::Test {
+ public:
+  uint32_t GetDeadThreadsTotal(const Counter& counter) {
+    return counter.readFast();
+  }
+};
+
+// Multithreaded tests. Creates a specified number of threads each of
+// which iterates a different amount and dies.
+
+namespace {
+// Set cacheSize to be large so cached data moves to target_ only when
+// thread dies.
+Counter g_counter_for_mt_slow(0, UINT32_MAX);
+Counter g_counter_for_mt_fast(0, UINT32_MAX);
+
+// Used to sync between threads. The value of this variable is the
+// maximum iteration index upto which Runner() is allowed to go.
+uint32_t g_sync_for_mt(0);
+std::condition_variable cv;
+std::mutex cv_m;
+
+// Performs the specified number of iterations. Within each
+// iteration, it increments counter 10 times. At the beginning of
+// each iteration it checks g_sync_for_mt to see if it can proceed,
+// otherwise goes into a loop sleeping and rechecking.
+void Runner(Counter* counter, uint32_t iterations) {
+  for (uint32_t i = 0; i < iterations; ++i) {
+    std::unique_lock<std::mutex> lk(cv_m);
+    cv.wait(lk, [i] { return i < g_sync_for_mt; });
+    for (uint32_t j = 0; j < 10; ++j) {
+      counter->increment(1);
+    }
+  }
+}
+}
+
+// Slow test with fewer threads where there are more busy waits and
+// many calls to readFull(). This attempts to test as many of the
+// code paths in Counter as possible to ensure that counter values are
+// properly passed from thread local state, both at calls to
+// readFull() and at thread death.
+TEST_F(ThreadCachedIntTest, MultithreadedSlow) {
+  static constexpr uint32_t kNumThreads = 20;
+  g_sync_for_mt = 0;
+  vector<unique_ptr<std::thread>> threads(kNumThreads);
+  // Creates kNumThreads threads. Each thread performs a different
+  // number of iterations in Runner() - threads[0] performs 1
+  // iteration, threads[1] performs 2 iterations, threads[2] performs
+  // 3 iterations, and so on.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_slow, i + 1));
+  }
+  // Variable to grab current counter value.
+  int32_t counter_value;
+  // The expected value of the counter.
+  int32_t total = 0;
+  // The expected value of GetDeadThreadsTotal().
+  int32_t dead_total = 0;
+  // Each iteration of the following thread allows one additional
+  // iteration of the threads. Given that the threads perform
+  // different number of iterations from 1 through kNumThreads, one
+  // thread will complete in each of the iterations of the loop below.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    // Allow upto iteration i on all threads.
+    {
+      std::lock_guard<std::mutex> lk(cv_m);
+      g_sync_for_mt = i + 1;
+    }
+    cv.notify_all();
+    total += (kNumThreads - i) * 10;
+    // Loop until the counter reaches its expected value.
+    do {
+      counter_value = g_counter_for_mt_slow.readFull();
+    } while (counter_value < total);
+    // All threads have done what they can until iteration i, now make
+    // sure they don't go further by checking 10 more times in the
+    // following loop.
+    for (uint32_t j = 0; j < 10; ++j) {
+      counter_value = g_counter_for_mt_slow.readFull();
+      EXPECT_EQ(total, counter_value);
+    }
+    dead_total += (i + 1) * 10;
+    EXPECT_GE(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
+  }
+  // All threads are done.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i]->join();
+  }
+  counter_value = g_counter_for_mt_slow.readFull();
+  EXPECT_EQ(total, counter_value);
+  EXPECT_EQ(total, dead_total);
+  EXPECT_EQ(dead_total, GetDeadThreadsTotal(g_counter_for_mt_slow));
+}
+
+// Fast test with lots of threads and only one call to readFull()
+// at the end.
+TEST_F(ThreadCachedIntTest, MultithreadedFast) {
+  static constexpr uint32_t kNumThreads = 1000;
+  g_sync_for_mt = 0;
+  vector<unique_ptr<std::thread>> threads(kNumThreads);
+  // Creates kNumThreads threads. Each thread performs a different
+  // number of iterations in Runner() - threads[0] performs 1
+  // iteration, threads[1] performs 2 iterations, threads[2] performs
+  // 3 iterations, and so on.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i].reset(new std::thread(Runner, &g_counter_for_mt_fast, i + 1));
+  }
+  // Let the threads run to completion.
+  {
+    std::lock_guard<std::mutex> lk(cv_m);
+    g_sync_for_mt = kNumThreads;
+  }
+  cv.notify_all();
+  // The expected value of the counter.
+  uint32_t total = 0;
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    total += (kNumThreads - i) * 10;
+  }
+  // Wait for all threads to complete.
+  for (uint32_t i = 0; i < kNumThreads; ++i) {
+    threads[i]->join();
+  }
+  int32_t counter_value = g_counter_for_mt_fast.readFull();
+  EXPECT_EQ(total, counter_value);
+  EXPECT_EQ(total, GetDeadThreadsTotal(g_counter_for_mt_fast));
+}
+
 TEST(ThreadCachedInt, SingleThreadedNotCached) {
   ThreadCachedInt<int64_t> val(0, 0);
   EXPECT_EQ(0, val.readFast());
-- 
2.34.1
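
To see the opt-in mode from the caller's side, the sketch below mirrors what the patch itself does inside ThreadCachedInt (which passes AccessModeStrict as the third template argument of its ThreadLocalPtr): a ThreadLocal is declared with AccessModeStrict, and accessAllThreads() then cannot miss or race with a thread that is in the middle of exiting. This is an illustrative sketch against the API added by this patch; the CounterTag type, perThreadCount variable, and helper functions are hypothetical names, not part of the patch.

```cpp
#include <cstdint>

#include <folly/ThreadLocal.h>

// Hypothetical tag type; tags partition the global ThreadLocal metadata/locks.
struct CounterTag {};

// Opting in to strict mode via the third template parameter added by this
// patch. Without AccessModeStrict, accessAllThreads() keeps the old behavior
// and can race with a thread that is currently being destroyed.
folly::ThreadLocal<int64_t, CounterTag, folly::AccessModeStrict> perThreadCount;

void bumpThisThread() {
  // operator*() returns this thread's element, creating it on first use.
  *perThreadCount += 1;
}

int64_t sumAllThreads() {
  int64_t sum = 0;
  // The accessor holds the global access lock for its lifetime. In strict
  // mode, an exiting thread takes the same lock (shared) in onThreadExit
  // before disposing of its elements, so no element disappears mid-iteration.
  for (const int64_t& value : perThreadCount.accessAllThreads()) {
    sum += value;
  }
  return sum;
}
```

The commit message's warning applies to exactly this pattern: in strict mode, thread-local element destructors run while the access lock is held, so a destructor that acquires some other mutex M must never coexist with code that holds M while calling accessAllThreads(), or the two lock orders can deadlock.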