subscriptions
authorJames Sedgwick <jsedgwick@fb.com>
Thu, 2 Oct 2014 03:49:37 +0000 (20:49 -0700)
committerAndrii Grynenko <andrii@fb.com>
Wed, 15 Oct 2014 00:48:03 +0000 (17:48 -0700)
Summary:
I'm not thrilled with this implementation, but it mostly works. The big bummer is enforcing that Observables are held in shared_ptrs, or rather that enforcing that condition is impossible.
The protected constructor / friended dumb make_shared() pattern is clunky, and it'd be really easy for a subclasser to shoot themselves in the foot or even in the face.

It does seem like maybe Observable should be made an interface again, and all these details should live in a subclass (FanoutObservable?) where the restriction are super obvious. For instance, the langtech AudioStream object doesn't need all this crap because it overrides subscribe() without using the observer list, but it subclasses anyways.

I'm noodling another approach that (if it works) will not require the shared_ptr dancing, but will probably have some additional overhead... the observable would have to keep track of the subscriptions itself.

I like the RAII subscriptions, though perhaps subscriptions should be optional as long as it's clear that your subscription will last forever it you opt out of them.

Thoughts?

Test Plan: added unit

Reviewed By: hans@fb.com

Subscribers: rushix, hannesr, trunkagent, fugalh, mwa, jgehring, fuegen, njormrod, bmatheny

FB internal diff: D1580443

folly/experimental/wangle/concurrent/ThreadPoolExecutor.h
folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp
folly/experimental/wangle/rx/Observable.h
folly/experimental/wangle/rx/Observer.h
folly/experimental/wangle/rx/Subject.h
folly/experimental/wangle/rx/Subscription.h
folly/experimental/wangle/rx/test/RxTest.cpp

index 7cbbb3215ea6aa97959a2922c6675c35c7779b36..88c25e88efa8e2fec62b066ea998048061b6a4b5 100644 (file)
@@ -66,7 +66,7 @@ class ThreadPoolExecutor : public experimental::Executor {
     std::chrono::nanoseconds runTime;
   };
 
-  Subscription subscribeToTaskStats(
+  Subscription<TaskStats> subscribeToTaskStats(
       const ObserverPtr<TaskStats>& observer) {
     return taskStatsSubject_.subscribe(observer);
   }
index e3336615484ef1aa636f2a562995fbf7fda8163a..471c8c6dc63d057311ed3aaa114bb500c5361147 100644 (file)
@@ -165,8 +165,9 @@ template <class TPE>
 static void taskStats() {
   TPE tpe(1);
   std::atomic<int> c(0);
-  tpe.subscribeToTaskStats(Observer<ThreadPoolExecutor::TaskStats>::create(
-      [&] (ThreadPoolExecutor::TaskStats stats) {
+  auto s = tpe.subscribeToTaskStats(
+      Observer<ThreadPoolExecutor::TaskStats>::create(
+          [&](ThreadPoolExecutor::TaskStats stats) {
         int i = c++;
         EXPECT_LT(milliseconds(0), stats.runTime);
         if (i == 1) {
@@ -191,8 +192,9 @@ template <class TPE>
 static void expiration() {
   TPE tpe(1);
   std::atomic<int> statCbCount(0);
-  tpe.subscribeToTaskStats(Observer<ThreadPoolExecutor::TaskStats>::create(
-      [&] (ThreadPoolExecutor::TaskStats stats) {
+  auto s = tpe.subscribeToTaskStats(
+      Observer<ThreadPoolExecutor::TaskStats>::create(
+          [&](ThreadPoolExecutor::TaskStats stats) {
         int i = statCbCount++;
         if (i == 0) {
           EXPECT_FALSE(stats.expired);
index f85d52a252c42e39734169a0f7ad79981da6338a..c4bc70da85214acbeb78afc1841e37e96e5fca8f 100644 (file)
 
 #pragma once
 
-#include "types.h"
-#include "Subject.h"
-#include "Subscription.h"
+#include <folly/experimental/wangle/rx/Subject.h>
+#include <folly/experimental/wangle/rx/Subscription.h>
+#include <folly/experimental/wangle/rx/types.h>
 
 #include <folly/RWSpinLock.h>
+#include <folly/SmallLocks.h>
 #include <folly/ThreadLocal.h>
 #include <folly/wangle/Executor.h>
-#include <list>
+#include <map>
 #include <memory>
 
 namespace folly { namespace wangle {
@@ -31,32 +32,36 @@ namespace folly { namespace wangle {
 template <class T>
 class Observable {
  public:
-  Observable() = default;
-  Observable(Observable&& other) noexcept {
-    RWSpinLock::WriteHolder{&other.observersLock_};
-    observers_ = std::move(other.observers_);
+  Observable() : nextSubscriptionId_{1} {}
+
+  // TODO perhaps we want to provide this #5283229
+  Observable(Observable&& other) = delete;
+
+  virtual ~Observable() {
+    if (unsubscriber_) {
+      unsubscriber_->disable();
+    }
   }
 
-  virtual ~Observable() = default;
+  typedef typename std::map<uint64_t, ObserverPtr<T>> ObserverMap;
 
-  /// Subscribe the given Observer to this Observable.
-  // Eventually this will return a Subscription object of some kind, in order
-  // to support cancellation. This is kinda really important. Maybe I should
-  // just do it now, using an dummy Subscription object.
+  // Subscribe the given Observer to this Observable.
   //
   // If this is called within an Observer callback, the new observer will not
   // get the current update but will get subsequent updates.
-  virtual Subscription subscribe(ObserverPtr<T> o) {
+  virtual Subscription<T> subscribe(ObserverPtr<T> observer) {
+    auto subscription = makeSubscription();
+    typename ObserverMap::value_type kv{subscription.id_, std::move(observer)};
     if (inCallback_ && *inCallback_) {
       if (!newObservers_) {
-        newObservers_.reset(new std::list<ObserverPtr<T>>());
+        newObservers_.reset(new ObserverMap());
       }
-      newObservers_->push_back(o);
+      newObservers_->insert(std::move(kv));
     } else {
       RWSpinLock::WriteHolder{&observersLock_};
-      observers_.push_back(o);
+      observers_.insert(std::move(kv));
     }
-    return Subscription();
+    return subscription;
   }
 
   /// Returns a new Observable that will call back on the given Scheduler.
@@ -76,7 +81,7 @@ class Observable {
         : scheduler_(scheduler), observable_(obs)
       {}
 
-      Subscription subscribe(ObserverPtr<T> o) override {
+      Subscription<T> subscribe(ObserverPtr<T> o) override {
         return observable_->subscribe(
           Observer<T>::create(
             [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); },
@@ -101,11 +106,11 @@ class Observable {
       Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) {
       }
 
-      Subscription subscribe(ObserverPtr<T> o) {
+      Subscription<T> subscribe(ObserverPtr<T> o) {
         scheduler_->add([=] {
           observable_->subscribe(o);
         });
-        return Subscription();
+        return Subscription<T>(nullptr, 0); // TODO
       }
 
      protected:
@@ -117,7 +122,7 @@ class Observable {
   }
 
  protected:
-  const std::list<ObserverPtr<T>>& getObservers() {
+  const ObserverMap& getObservers() {
     return observers_;
   }
 
@@ -138,14 +143,23 @@ class Observable {
 
     ~ObserversGuard() {
       o_->observersLock_.unlock_shared();
-      if (UNLIKELY(o_->newObservers_ && !o_->newObservers_->empty())) {
+      if (UNLIKELY((o_->newObservers_ && !o_->newObservers_->empty()) ||
+                   (o_->oldObservers_ && !o_->oldObservers_->empty()))) {
         {
           RWSpinLock::WriteHolder(o_->observersLock_);
-          for (auto& o : *(o_->newObservers_)) {
-            o_->observers_.push_back(o);
+          if (o_->newObservers_) {
+            for (auto& kv : *(o_->newObservers_)) {
+              o_->observers_.insert(std::move(kv));
+            }
+            o_->newObservers_->clear();
+          }
+          if (o_->oldObservers_) {
+            for (auto id : *(o_->oldObservers_)) {
+              o_->observers_.erase(id);
+            }
+            o_->oldObservers_->clear();
           }
         }
-        o_->newObservers_->clear();
       }
       *o_->inCallback_ = false;
     }
@@ -155,10 +169,70 @@ class Observable {
   };
 
  private:
-  std::list<ObserverPtr<T>> observers_;
+  class Unsubscriber {
+   public:
+    explicit Unsubscriber(Observable* observable) : observable_(observable) {
+      CHECK(observable_);
+    }
+
+    void unsubscribe(uint64_t id) {
+      CHECK(id > 0);
+      RWSpinLock::ReadHolder guard(lock_);
+      if (observable_) {
+        observable_->unsubscribe(id);
+      }
+    }
+
+    void disable() {
+      RWSpinLock::WriteHolder guard(lock_);
+      observable_ = nullptr;
+    }
+
+   private:
+    RWSpinLock lock_;
+    Observable* observable_;
+  };
+
+  std::shared_ptr<Unsubscriber> unsubscriber_{nullptr};
+  MicroSpinLock unsubscriberLock_{0};
+
+  friend class Subscription<T>;
+
+  void unsubscribe(uint64_t id) {
+    if (inCallback_ && *inCallback_) {
+      if (!oldObservers_) {
+        oldObservers_.reset(new std::vector<uint64_t>());
+      }
+      if (newObservers_) {
+        auto it = newObservers_->find(id);
+        if (it != newObservers_->end()) {
+          newObservers_->erase(it);
+          return;
+        }
+      }
+      oldObservers_->push_back(id);
+    } else {
+      RWSpinLock::WriteHolder{&observersLock_};
+      observers_.erase(id);
+    }
+  }
+
+  Subscription<T> makeSubscription() {
+    if (!unsubscriber_) {
+      std::lock_guard<MicroSpinLock> guard(unsubscriberLock_);
+      if (!unsubscriber_) {
+        unsubscriber_ = std::make_shared<Unsubscriber>(this);
+      }
+    }
+    return Subscription<T>(unsubscriber_, nextSubscriptionId_++);
+  }
+
+  std::atomic<uint64_t> nextSubscriptionId_;
+  ObserverMap observers_;
   RWSpinLock observersLock_;
   folly::ThreadLocalPtr<bool> inCallback_;
-  folly::ThreadLocalPtr<std::list<ObserverPtr<T>>> newObservers_;
+  folly::ThreadLocalPtr<ObserverMap> newObservers_;
+  folly::ThreadLocalPtr<std::vector<uint64_t>> oldObservers_;
 };
 
 }}
index 8d4bbb42f3bcc443e5395f9a2b05980163c51232..e2a7486b715c44c51ff11e589c1fbb197fc13fcd 100644 (file)
@@ -16,7 +16,7 @@
 
 #pragma once
 
-#include "types.h"
+#include <folly/experimental/wangle/rx/types.h>
 #include <functional>
 #include <memory>
 #include <stdexcept>
index 7d4c7cb8da46464cfaa900171dc2fed80e78beb1..67857205c6c6c683e5cced378fbde858842bff7c 100644 (file)
@@ -15,8 +15,9 @@
  */
 
 #pragma once
-#include "Observable.h"
-#include "Observer.h"
+
+#include <folly/experimental/wangle/rx/Observable.h>
+#include <folly/experimental/wangle/rx/Observer.h>
 
 namespace folly { namespace wangle {
 
@@ -28,20 +29,20 @@ struct Subject : public Observable<T>, public Observer<T> {
   typedef typename Observable<T>::ObserversGuard ObserversGuard;
   void onNext(T val) override {
     ObserversGuard guard(this);
-    for (auto& o : Observable<T>::getObservers()) {
-      o->onNext(val);
+    for (auto& kv : Observable<T>::getObservers()) {
+      kv.second->onNext(val);
     }
   }
   void onError(Error e) override {
     ObserversGuard guard(this);
-    for (auto& o : Observable<T>::getObservers()) {
-      o->onError(e);
+    for (auto& kv : Observable<T>::getObservers()) {
+      kv.second->onError(e);
     }
   }
   void onCompleted() override {
     ObserversGuard guard(this);
-    for (auto& o : Observable<T>::getObservers()) {
-      o->onCompleted();
+    for (auto& kv : Observable<T>::getObservers()) {
+      kv.second->onCompleted();
     }
   }
 };
index 16406a23f474fb361e90b0376f398dbcb6620fbf..0cf667e6445b1ad620d2f854cd5183ebef246e9e 100644 (file)
 
 #pragma once
 
+#include <folly/experimental/wangle/rx/Observable.h>
+
 namespace folly { namespace wangle {
 
-// TODO
-struct Subscription {
+template <class T>
+class Subscription {
+ public:
+  Subscription() {}
+
+  Subscription(const Subscription&) = delete;
+
+  Subscription(Subscription&& other) noexcept {
+    *this = std::move(other);
+  }
+
+  Subscription& operator=(Subscription&& other) noexcept {
+    unsubscribe();
+    unsubscriber_ = std::move(other.unsubscriber_);
+    id_ = other.id_;
+    other.unsubscriber_ = nullptr;
+    other.id_ = 0;
+    return *this;
+  }
+
+  ~Subscription() {
+    unsubscribe();
+  }
+
+ private:
+  typedef typename Observable<T>::Unsubscriber Unsubscriber;
+
+  Subscription(std::shared_ptr<Unsubscriber> unsubscriber, uint64_t id)
+    : unsubscriber_(std::move(unsubscriber)), id_(id) {
+    CHECK(unsubscriber_);
+    CHECK(id_ > 0);
+  }
+
+  void unsubscribe() {
+    if (unsubscriber_ && id_ > 0) {
+      unsubscriber_->unsubscribe(id_);
+      id_ = 0;
+      unsubscriber_ = nullptr;
+    }
+  }
+
+  std::shared_ptr<Unsubscriber> unsubscriber_;
+  uint64_t id_{0};
+
+  friend class Observable<T>;
 };
 
 }}
index cf4d9dd006b6debb3fa3edde02efd8e8e22048a4..a62d11ce476f043002f18ff40530f35756739385 100644 (file)
 
 using namespace folly::wangle;
 
+static std::unique_ptr<Observer<int>> incrementer(int& counter) {
+  return Observer<int>::create([&] (int x) {
+    counter++;
+  });
+}
+
+TEST(RxTest, Subscription) {
+  Subject<int> subject;
+  auto count = 0;
+  {
+    auto s = subject.subscribe(incrementer(count));
+    subject.onNext(1);
+  }
+  // The subscription has gone out of scope so no one should get this.
+  subject.onNext(2);
+  EXPECT_EQ(1, count);
+}
+
+TEST(RxTest, SubscriptionMove) {
+  Subject<int> subject;
+  auto count = 0;
+  auto s = subject.subscribe(incrementer(count));
+  auto s2 = subject.subscribe(incrementer(count));
+  s2 = std::move(s);
+  subject.onNext(1);
+  Subscription<int> s3(std::move(s2));
+  subject.onNext(2);
+  EXPECT_EQ(2, count);
+}
+
+TEST(RxTest, SubscriptionOutlivesSubject) {
+  Subscription<int> s;
+  {
+    Subject<int> subject;
+    s = subject.subscribe(Observer<int>::create([](int){}));
+  }
+  // Don't explode when s is destroyed
+}
+
 TEST(RxTest, SubscribeDuringCallback) {
   // A subscriber who was subscribed in the course of a callback should get
   // subsequent updates but not the current update.
   Subject<int> subject;
-  int outerCount = 0;
-  int innerCount = 0;
-  subject.subscribe(Observer<int>::create([&] (int x) {
+  int outerCount = 0, innerCount = 0;
+  Subscription<int> s1, s2;
+  s1 = subject.subscribe(Observer<int>::create([&] (int x) {
     outerCount++;
-    subject.subscribe(Observer<int>::create([&] (int y) {
-      innerCount++;
-    }));
+    s2 = subject.subscribe(incrementer(innerCount));
   }));
   subject.onNext(42);
   subject.onNext(0xDEADBEEF);
   EXPECT_EQ(2, outerCount);
   EXPECT_EQ(1, innerCount);
 }
+
+TEST(RxTest, UnsubscribeDuringCallback) {
+  // A subscriber who was unsubscribed in the course of a callback should get
+  // the current update but not subsequent ones
+  Subject<int> subject;
+  int count1 = 0, count2 = 0;
+  auto s1 = subject.subscribe(incrementer(count1));
+  auto s2 = subject.subscribe(Observer<int>::create([&] (int x) {
+    count2++;
+    s1.~Subscription();
+  }));
+  subject.onNext(1);
+  subject.onNext(2);
+  EXPECT_EQ(1, count1);
+  EXPECT_EQ(2, count2);
+}
+
+TEST(RxTest, SubscribeUnsubscribeDuringCallback) {
+  // A subscriber who was subscribed and unsubscribed in the course of a
+  // callback should not get any updates
+  Subject<int> subject;
+  int outerCount = 0, innerCount = 0;
+  auto s2 = subject.subscribe(Observer<int>::create([&] (int x) {
+    outerCount++;
+    auto s2 = subject.subscribe(incrementer(innerCount));
+  }));
+  subject.onNext(1);
+  subject.onNext(2);
+  EXPECT_EQ(2, outerCount);
+  EXPECT_EQ(0, innerCount);
+}