From 2642bd3dcf9658be7da3e5b5bc622fe051200e97 Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Mon, 1 May 2017 14:51:49 -0700 Subject: [PATCH] Fix a race in Observable context destruction Summary: In the subscribe callback It's possible that we lock the Context shared_ptr and while update is running, all other shared_ptr's are released. This will result in Context to be destroyed from the wrong thread (thread runnning subcribe callback), which is not desired. Reviewed By: yfeldblum Differential Revision: D4964605 fbshipit-source-id: 285327a6873ccb7393fa3067ba7e612c29dbc454 --- folly/experimental/observer/Observable-inl.h | 84 ++++++++++++++----- folly/experimental/observer/Observer-inl.h | 4 +- folly/experimental/observer/Observer.h | 3 + .../observer/detail/ObserverManager.cpp | 35 +++++--- .../observer/detail/ObserverManager.h | 16 ++-- .../observer/test/ObserverTest.cpp | 59 +++++++++++++ 6 files changed, 158 insertions(+), 43 deletions(-) diff --git a/folly/experimental/observer/Observable-inl.h b/folly/experimental/observer/Observable-inl.h index fdf62f0b..231991e6 100644 --- a/folly/experimental/observer/Observable-inl.h +++ b/folly/experimental/observer/Observable-inl.h @@ -22,7 +22,9 @@ template class ObserverCreator::Context { public: template - Context(Args&&... args) : observable_(std::forward(args)...) {} + Context(Args&&... args) : observable_(std::forward(args)...) { + updateValue(); + } ~Context() { if (value_.copy()) { @@ -47,21 +49,11 @@ class ObserverCreator::Context { // callbacks (getting new value from observable and storing it into value_ // is not atomic). std::lock_guard lg(updateMutex_); - - { - auto newValue = Traits::get(observable_); - if (!newValue) { - throw std::logic_error("Observable returned nullptr."); - } - value_.swap(newValue); - } + updateValue(); bool expected = false; if (updateRequested_.compare_exchange_strong(expected, true)) { - if (auto core = coreWeak_.lock()) { - observer_detail::ObserverManager::scheduleRefreshNewVersion( - std::move(core)); - } + observer_detail::ObserverManager::scheduleRefreshNewVersion(coreWeak_); } } @@ -71,6 +63,14 @@ class ObserverCreator::Context { } private: + void updateValue() { + auto newValue = Traits::get(observable_); + if (!newValue) { + throw std::logic_error("Observable returned nullptr."); + } + value_.swap(newValue); + } + folly::Synchronized> value_; std::atomic updateRequested_{false}; @@ -89,24 +89,68 @@ ObserverCreator::ObserverCreator(Args&&... args) template Observer::T> ObserverCreator::getObserver()&& { - auto core = observer_detail::Core::create([context = context_]() { + // This master shared_ptr allows grabbing derived weak_ptrs, pointing to the + // the same Context object, but using a separate reference count. Master + // shared_ptr destructor then blocks until all shared_ptrs obtained from + // derived weak_ptrs are released. + class ContextMasterPointer { + public: + explicit ContextMasterPointer(std::shared_ptr context) + : contextMaster_(std::move(context)), + context_( + contextMaster_.get(), + [destroyBaton = destroyBaton_](Context*) { + destroyBaton->post(); + }) {} + ~ContextMasterPointer() { + if (context_) { + context_.reset(); + destroyBaton_->wait(); + } + } + ContextMasterPointer(const ContextMasterPointer&) = delete; + ContextMasterPointer(ContextMasterPointer&&) = default; + ContextMasterPointer& operator=(const ContextMasterPointer&) = delete; + ContextMasterPointer& operator=(ContextMasterPointer&&) = default; + + Context* operator->() const { + return contextMaster_.get(); + } + + std::weak_ptr get_weak() { + return context_; + } + + private: + std::shared_ptr> destroyBaton_{ + std::make_shared>()}; + std::shared_ptr contextMaster_; + std::shared_ptr context_; + }; + // We want to make sure that Context can only be destroyed when Core is + // destroyed. So we have to avoid the situation when subscribe callback is + // locking Context shared_ptr and remains the last to release it. + // We solve this by having Core hold the master shared_ptr and subscription + // callback gets derived weak_ptr. + ContextMasterPointer contextMaster(context_); + auto contextWeak = contextMaster.get_weak(); + auto observer = makeObserver([context = std::move(contextMaster)]() { return context->get(); }); - context_->setCore(core); - - context_->subscribe([contextWeak = std::weak_ptr(context_)] { + context_->setCore(observer.core_); + context_->subscribe([contextWeak = std::move(contextWeak)] { if (auto context = contextWeak.lock()) { context->update(); } }); + // Do an extra update in case observable was updated between observer creation + // and setting updates callback. context_->update(); context_.reset(); - DCHECK(core->getVersion() > 0); - - return Observer(std::move(core)); + return observer; } } } diff --git a/folly/experimental/observer/Observer-inl.h b/folly/experimental/observer/Observer-inl.h index 55088cdf..bdc62a09 100644 --- a/folly/experimental/observer/Observer-inl.h +++ b/folly/experimental/observer/Observer-inl.h @@ -38,10 +38,10 @@ Observer> makeObserver( F&& creator) { auto core = observer_detail::Core:: create([creator = std::forward(creator)]() mutable { - return std::static_pointer_cast(creator()); + return std::static_pointer_cast(creator()); }); - observer_detail::ObserverManager::scheduleRefreshNewVersion(core); + observer_detail::ObserverManager::initCore(core); return Observer>(core); } diff --git a/folly/experimental/observer/Observer.h b/folly/experimental/observer/Observer.h index 662a0113..192293e5 100644 --- a/folly/experimental/observer/Observer.h +++ b/folly/experimental/observer/Observer.h @@ -134,6 +134,9 @@ class Observer { } private: + template + friend class ObserverCreator; + observer_detail::Core::Ptr core_; }; diff --git a/folly/experimental/observer/detail/ObserverManager.cpp b/folly/experimental/observer/detail/ObserverManager.cpp index 7654dff5..f909ef57 100644 --- a/folly/experimental/observer/detail/ObserverManager.cpp +++ b/folly/experimental/observer/detail/ObserverManager.cpp @@ -106,28 +106,35 @@ class ObserverManager::NextQueue { explicit NextQueue(ObserverManager& manager) : manager_(manager), queue_(kNextQueueSize) { thread_ = std::thread([&]() { - Core::Ptr queueCore; + Core::WeakPtr queueCoreWeak; while (true) { - queue_.blockingRead(queueCore); - - if (!queueCore) { + queue_.blockingRead(queueCoreWeak); + if (stop_) { return; } std::vector cores; - cores.emplace_back(std::move(queueCore)); + { + auto queueCore = queueCoreWeak.lock(); + if (!queueCore) { + continue; + } + cores.emplace_back(std::move(queueCore)); + } { SharedMutexReadPriority::WriteHolder wh(manager_.versionMutex_); // We can't pick more tasks from the queue after we bumped the // version, so we have to do this while holding the lock. - while (cores.size() < kNextQueueSize && queue_.read(queueCore)) { - if (!queueCore) { + while (cores.size() < kNextQueueSize && queue_.read(queueCoreWeak)) { + if (stop_) { return; } - cores.emplace_back(std::move(queueCore)); + if (auto queueCore = queueCoreWeak.lock()) { + cores.emplace_back(std::move(queueCore)); + } } ++manager_.version_; @@ -140,20 +147,22 @@ class ObserverManager::NextQueue { }); } - void add(Core::Ptr core) { + void add(Core::WeakPtr core) { queue_.blockingWrite(std::move(core)); } ~NextQueue() { - // Emtpy element signals thread to terminate - queue_.blockingWrite(nullptr); + stop_ = true; + // Write to the queue to notify the thread. + queue_.blockingWrite(Core::WeakPtr()); thread_.join(); } private: ObserverManager& manager_; - MPMCQueue queue_; + MPMCQueue queue_; std::thread thread_; + std::atomic stop_{false}; }; ObserverManager::ObserverManager() { @@ -172,7 +181,7 @@ void ObserverManager::scheduleCurrent(Function task) { currentQueue_->add(std::move(task)); } -void ObserverManager::scheduleNext(Core::Ptr core) { +void ObserverManager::scheduleNext(Core::WeakPtr core) { nextQueue_->add(std::move(core)); } diff --git a/folly/experimental/observer/detail/ObserverManager.h b/folly/experimental/observer/detail/ObserverManager.h index 5e206dfb..cfb1e70a 100644 --- a/folly/experimental/observer/detail/ObserverManager.h +++ b/folly/experimental/observer/detail/ObserverManager.h @@ -93,19 +93,19 @@ class ObserverManager { return future; } - static void scheduleRefreshNewVersion(Core::Ptr core) { - if (core->getVersion() == 0) { - scheduleRefresh(std::move(core), 1).get(); - return; - } - + static void scheduleRefreshNewVersion(Core::WeakPtr coreWeak) { auto instance = getInstance(); if (!instance) { return; } - instance->scheduleNext(std::move(core)); + instance->scheduleNext(std::move(coreWeak)); + } + + static void initCore(Core::Ptr core) { + DCHECK(core->getVersion() == 0); + scheduleRefresh(std::move(core), 1).get(); } class DependencyRecorder { @@ -189,7 +189,7 @@ class ObserverManager { struct Singleton; void scheduleCurrent(Function); - void scheduleNext(Core::Ptr); + void scheduleNext(Core::WeakPtr); class CurrentQueue; class NextQueue; diff --git a/folly/experimental/observer/test/ObserverTest.cpp b/folly/experimental/observer/test/ObserverTest.cpp index 62fcf57b..ed372f5a 100644 --- a/folly/experimental/observer/test/ObserverTest.cpp +++ b/folly/experimental/observer/test/ObserverTest.cpp @@ -262,3 +262,62 @@ TEST(Observer, TLObserver) { k = std::make_unique>(createTLObserver(41)); EXPECT_EQ(41, ***k); } + +TEST(Observer, SubscribeCallback) { + static auto mainThreadId = std::this_thread::get_id(); + static std::function updatesCob; + static bool slowGet = false; + static std::atomic getCallsStart{0}; + static std::atomic getCallsFinish{0}; + + struct Observable { + ~Observable() { + EXPECT_EQ(mainThreadId, std::this_thread::get_id()); + } + }; + struct Traits { + using element_type = int; + static std::shared_ptr get(Observable&) { + ++getCallsStart; + if (slowGet) { + /* sleep override */ std::this_thread::sleep_for( + std::chrono::seconds{2}); + } + ++getCallsFinish; + return std::make_shared(42); + } + + static void subscribe(Observable&, std::function cob) { + updatesCob = std::move(cob); + } + + static void unsubscribe(Observable&) {} + }; + + std::thread cobThread; + { + auto observer = + folly::observer::ObserverCreator().getObserver(); + + EXPECT_TRUE(updatesCob); + EXPECT_EQ(2, getCallsStart); + EXPECT_EQ(2, getCallsFinish); + + updatesCob(); + EXPECT_EQ(3, getCallsStart); + EXPECT_EQ(3, getCallsFinish); + + slowGet = true; + cobThread = std::thread([] { updatesCob(); }); + /* sleep override */ std::this_thread::sleep_for(std::chrono::seconds{1}); + EXPECT_EQ(4, getCallsStart); + EXPECT_EQ(3, getCallsFinish); + + // Observer is destroyed here + } + + // Make sure that destroying the observer actually joined the updates callback + EXPECT_EQ(4, getCallsStart); + EXPECT_EQ(4, getCallsFinish); + cobThread.join(); +} -- 2.34.1