Fix a race in Observable context destruction
[folly.git] / folly / experimental / observer / Observable-inl.h
index fdf62f0bc1c123d287cd05f63a44b87eed3005e3..231991e6adafab0c83827bd82b3061a73a4b7fb8 100644 (file)
@@ -22,7 +22,9 @@ template <typename Observable, typename Traits>
 class ObserverCreator<Observable, Traits>::Context {
  public:
   template <typename... Args>
-  Context(Args&&... args) : observable_(std::forward<Args>(args)...) {}
+  Context(Args&&... args) : observable_(std::forward<Args>(args)...) {
+    updateValue();
+  }
 
   ~Context() {
     if (value_.copy()) {
@@ -47,21 +49,11 @@ class ObserverCreator<Observable, Traits>::Context {
     // callbacks (getting new value from observable and storing it into value_
     // is not atomic).
     std::lock_guard<std::mutex> lg(updateMutex_);
-
-    {
-      auto newValue = Traits::get(observable_);
-      if (!newValue) {
-        throw std::logic_error("Observable returned nullptr.");
-      }
-      value_.swap(newValue);
-    }
+    updateValue();
 
     bool expected = false;
     if (updateRequested_.compare_exchange_strong(expected, true)) {
-      if (auto core = coreWeak_.lock()) {
-        observer_detail::ObserverManager::scheduleRefreshNewVersion(
-            std::move(core));
-      }
+      observer_detail::ObserverManager::scheduleRefreshNewVersion(coreWeak_);
     }
   }
 
@@ -71,6 +63,14 @@ class ObserverCreator<Observable, Traits>::Context {
   }
 
  private:
+  void updateValue() {
+    auto newValue = Traits::get(observable_);
+    if (!newValue) {
+      throw std::logic_error("Observable returned nullptr.");
+    }
+    value_.swap(newValue);
+  }
+
   folly::Synchronized<std::shared_ptr<const T>> value_;
   std::atomic<bool> updateRequested_{false};
 
@@ -89,24 +89,68 @@ ObserverCreator<Observable, Traits>::ObserverCreator(Args&&... args)
 template <typename Observable, typename Traits>
 Observer<typename ObserverCreator<Observable, Traits>::T>
 ObserverCreator<Observable, Traits>::getObserver()&& {
-  auto core = observer_detail::Core::create([context = context_]() {
+  // This master shared_ptr allows grabbing derived weak_ptrs, pointing to the
+  // the same Context object, but using a separate reference count. Master
+  // shared_ptr destructor then blocks until all shared_ptrs obtained from
+  // derived weak_ptrs are released.
+  class ContextMasterPointer {
+   public:
+    explicit ContextMasterPointer(std::shared_ptr<Context> context)
+        : contextMaster_(std::move(context)),
+          context_(
+              contextMaster_.get(),
+              [destroyBaton = destroyBaton_](Context*) {
+                destroyBaton->post();
+              }) {}
+    ~ContextMasterPointer() {
+      if (context_) {
+        context_.reset();
+        destroyBaton_->wait();
+      }
+    }
+    ContextMasterPointer(const ContextMasterPointer&) = delete;
+    ContextMasterPointer(ContextMasterPointer&&) = default;
+    ContextMasterPointer& operator=(const ContextMasterPointer&) = delete;
+    ContextMasterPointer& operator=(ContextMasterPointer&&) = default;
+
+    Context* operator->() const {
+      return contextMaster_.get();
+    }
+
+    std::weak_ptr<Context> get_weak() {
+      return context_;
+    }
+
+   private:
+    std::shared_ptr<folly::Baton<>> destroyBaton_{
+        std::make_shared<folly::Baton<>>()};
+    std::shared_ptr<Context> contextMaster_;
+    std::shared_ptr<Context> context_;
+  };
+  // We want to make sure that Context can only be destroyed when Core is
+  // destroyed. So we have to avoid the situation when subscribe callback is
+  // locking Context shared_ptr and remains the last to release it.
+  // We solve this by having Core hold the master shared_ptr and subscription
+  // callback gets derived weak_ptr.
+  ContextMasterPointer contextMaster(context_);
+  auto contextWeak = contextMaster.get_weak();
+  auto observer = makeObserver([context = std::move(contextMaster)]() {
     return context->get();
   });
 
-  context_->setCore(core);
-
-  context_->subscribe([contextWeak = std::weak_ptr<Context>(context_)] {
+  context_->setCore(observer.core_);
+  context_->subscribe([contextWeak = std::move(contextWeak)] {
     if (auto context = contextWeak.lock()) {
       context->update();
     }
   });
 
+  // Do an extra update in case observable was updated between observer creation
+  // and setting updates callback.
   context_->update();
   context_.reset();
 
-  DCHECK(core->getVersion() > 0);
-
-  return Observer<T>(std::move(core));
+  return observer;
 }
 }
 }