From 30d69c32b8f6952ac5c0f4943ef0c5ddb1917fec Mon Sep 17 00:00:00 2001 From: James Sedgwick Date: Wed, 1 Oct 2014 20:49:37 -0700 Subject: [PATCH] subscriptions Summary: I'm not thrilled with this implementation, but it mostly works. The big bummer is enforcing that Observables are held in shared_ptrs, or rather that enforcing that condition is impossible. The protected constructor / friended dumb make_shared() pattern is clunky, and it'd be really easy for a subclasser to shoot themselves in the foot or even in the face. It does seem like maybe Observable should be made an interface again, and all these details should live in a subclass (FanoutObservable?) where the restriction are super obvious. For instance, the langtech AudioStream object doesn't need all this crap because it overrides subscribe() without using the observer list, but it subclasses anyways. I'm noodling another approach that (if it works) will not require the shared_ptr dancing, but will probably have some additional overhead... the observable would have to keep track of the subscriptions itself. I like the RAII subscriptions, though perhaps subscriptions should be optional as long as it's clear that your subscription will last forever it you opt out of them. Thoughts? Test Plan: added unit Reviewed By: hans@fb.com Subscribers: rushix, hannesr, trunkagent, fugalh, mwa, jgehring, fuegen, njormrod, bmatheny FB internal diff: D1580443 --- .../wangle/concurrent/ThreadPoolExecutor.h | 2 +- .../test/ThreadPoolExecutorTest.cpp | 10 +- folly/experimental/wangle/rx/Observable.h | 130 ++++++++++++++---- folly/experimental/wangle/rx/Observer.h | 2 +- folly/experimental/wangle/rx/Subject.h | 17 +-- folly/experimental/wangle/rx/Subscription.h | 49 ++++++- folly/experimental/wangle/rx/test/RxTest.cpp | 80 ++++++++++- 7 files changed, 240 insertions(+), 50 deletions(-) diff --git a/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h b/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h index 7cbbb321..88c25e88 100644 --- a/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h +++ b/folly/experimental/wangle/concurrent/ThreadPoolExecutor.h @@ -66,7 +66,7 @@ class ThreadPoolExecutor : public experimental::Executor { std::chrono::nanoseconds runTime; }; - Subscription subscribeToTaskStats( + Subscription subscribeToTaskStats( const ObserverPtr& observer) { return taskStatsSubject_.subscribe(observer); } diff --git a/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp b/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp index e3336615..471c8c6d 100644 --- a/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp +++ b/folly/experimental/wangle/concurrent/test/ThreadPoolExecutorTest.cpp @@ -165,8 +165,9 @@ template static void taskStats() { TPE tpe(1); std::atomic c(0); - tpe.subscribeToTaskStats(Observer::create( - [&] (ThreadPoolExecutor::TaskStats stats) { + auto s = tpe.subscribeToTaskStats( + Observer::create( + [&](ThreadPoolExecutor::TaskStats stats) { int i = c++; EXPECT_LT(milliseconds(0), stats.runTime); if (i == 1) { @@ -191,8 +192,9 @@ template static void expiration() { TPE tpe(1); std::atomic statCbCount(0); - tpe.subscribeToTaskStats(Observer::create( - [&] (ThreadPoolExecutor::TaskStats stats) { + auto s = tpe.subscribeToTaskStats( + Observer::create( + [&](ThreadPoolExecutor::TaskStats stats) { int i = statCbCount++; if (i == 0) { EXPECT_FALSE(stats.expired); diff --git a/folly/experimental/wangle/rx/Observable.h b/folly/experimental/wangle/rx/Observable.h index f85d52a2..c4bc70da 100644 --- a/folly/experimental/wangle/rx/Observable.h +++ b/folly/experimental/wangle/rx/Observable.h @@ -16,14 +16,15 @@ #pragma once -#include "types.h" -#include "Subject.h" -#include "Subscription.h" +#include +#include +#include #include +#include #include #include -#include +#include #include namespace folly { namespace wangle { @@ -31,32 +32,36 @@ namespace folly { namespace wangle { template class Observable { public: - Observable() = default; - Observable(Observable&& other) noexcept { - RWSpinLock::WriteHolder{&other.observersLock_}; - observers_ = std::move(other.observers_); + Observable() : nextSubscriptionId_{1} {} + + // TODO perhaps we want to provide this #5283229 + Observable(Observable&& other) = delete; + + virtual ~Observable() { + if (unsubscriber_) { + unsubscriber_->disable(); + } } - virtual ~Observable() = default; + typedef typename std::map> ObserverMap; - /// Subscribe the given Observer to this Observable. - // Eventually this will return a Subscription object of some kind, in order - // to support cancellation. This is kinda really important. Maybe I should - // just do it now, using an dummy Subscription object. + // Subscribe the given Observer to this Observable. // // If this is called within an Observer callback, the new observer will not // get the current update but will get subsequent updates. - virtual Subscription subscribe(ObserverPtr o) { + virtual Subscription subscribe(ObserverPtr observer) { + auto subscription = makeSubscription(); + typename ObserverMap::value_type kv{subscription.id_, std::move(observer)}; if (inCallback_ && *inCallback_) { if (!newObservers_) { - newObservers_.reset(new std::list>()); + newObservers_.reset(new ObserverMap()); } - newObservers_->push_back(o); + newObservers_->insert(std::move(kv)); } else { RWSpinLock::WriteHolder{&observersLock_}; - observers_.push_back(o); + observers_.insert(std::move(kv)); } - return Subscription(); + return subscription; } /// Returns a new Observable that will call back on the given Scheduler. @@ -76,7 +81,7 @@ class Observable { : scheduler_(scheduler), observable_(obs) {} - Subscription subscribe(ObserverPtr o) override { + Subscription subscribe(ObserverPtr o) override { return observable_->subscribe( Observer::create( [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); }, @@ -101,11 +106,11 @@ class Observable { Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) { } - Subscription subscribe(ObserverPtr o) { + Subscription subscribe(ObserverPtr o) { scheduler_->add([=] { observable_->subscribe(o); }); - return Subscription(); + return Subscription(nullptr, 0); // TODO } protected: @@ -117,7 +122,7 @@ class Observable { } protected: - const std::list>& getObservers() { + const ObserverMap& getObservers() { return observers_; } @@ -138,14 +143,23 @@ class Observable { ~ObserversGuard() { o_->observersLock_.unlock_shared(); - if (UNLIKELY(o_->newObservers_ && !o_->newObservers_->empty())) { + if (UNLIKELY((o_->newObservers_ && !o_->newObservers_->empty()) || + (o_->oldObservers_ && !o_->oldObservers_->empty()))) { { RWSpinLock::WriteHolder(o_->observersLock_); - for (auto& o : *(o_->newObservers_)) { - o_->observers_.push_back(o); + if (o_->newObservers_) { + for (auto& kv : *(o_->newObservers_)) { + o_->observers_.insert(std::move(kv)); + } + o_->newObservers_->clear(); + } + if (o_->oldObservers_) { + for (auto id : *(o_->oldObservers_)) { + o_->observers_.erase(id); + } + o_->oldObservers_->clear(); } } - o_->newObservers_->clear(); } *o_->inCallback_ = false; } @@ -155,10 +169,70 @@ class Observable { }; private: - std::list> observers_; + class Unsubscriber { + public: + explicit Unsubscriber(Observable* observable) : observable_(observable) { + CHECK(observable_); + } + + void unsubscribe(uint64_t id) { + CHECK(id > 0); + RWSpinLock::ReadHolder guard(lock_); + if (observable_) { + observable_->unsubscribe(id); + } + } + + void disable() { + RWSpinLock::WriteHolder guard(lock_); + observable_ = nullptr; + } + + private: + RWSpinLock lock_; + Observable* observable_; + }; + + std::shared_ptr unsubscriber_{nullptr}; + MicroSpinLock unsubscriberLock_{0}; + + friend class Subscription; + + void unsubscribe(uint64_t id) { + if (inCallback_ && *inCallback_) { + if (!oldObservers_) { + oldObservers_.reset(new std::vector()); + } + if (newObservers_) { + auto it = newObservers_->find(id); + if (it != newObservers_->end()) { + newObservers_->erase(it); + return; + } + } + oldObservers_->push_back(id); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + observers_.erase(id); + } + } + + Subscription makeSubscription() { + if (!unsubscriber_) { + std::lock_guard guard(unsubscriberLock_); + if (!unsubscriber_) { + unsubscriber_ = std::make_shared(this); + } + } + return Subscription(unsubscriber_, nextSubscriptionId_++); + } + + std::atomic nextSubscriptionId_; + ObserverMap observers_; RWSpinLock observersLock_; folly::ThreadLocalPtr inCallback_; - folly::ThreadLocalPtr>> newObservers_; + folly::ThreadLocalPtr newObservers_; + folly::ThreadLocalPtr> oldObservers_; }; }} diff --git a/folly/experimental/wangle/rx/Observer.h b/folly/experimental/wangle/rx/Observer.h index 8d4bbb42..e2a7486b 100644 --- a/folly/experimental/wangle/rx/Observer.h +++ b/folly/experimental/wangle/rx/Observer.h @@ -16,7 +16,7 @@ #pragma once -#include "types.h" +#include #include #include #include diff --git a/folly/experimental/wangle/rx/Subject.h b/folly/experimental/wangle/rx/Subject.h index 7d4c7cb8..67857205 100644 --- a/folly/experimental/wangle/rx/Subject.h +++ b/folly/experimental/wangle/rx/Subject.h @@ -15,8 +15,9 @@ */ #pragma once -#include "Observable.h" -#include "Observer.h" + +#include +#include namespace folly { namespace wangle { @@ -28,20 +29,20 @@ struct Subject : public Observable, public Observer { typedef typename Observable::ObserversGuard ObserversGuard; void onNext(T val) override { ObserversGuard guard(this); - for (auto& o : Observable::getObservers()) { - o->onNext(val); + for (auto& kv : Observable::getObservers()) { + kv.second->onNext(val); } } void onError(Error e) override { ObserversGuard guard(this); - for (auto& o : Observable::getObservers()) { - o->onError(e); + for (auto& kv : Observable::getObservers()) { + kv.second->onError(e); } } void onCompleted() override { ObserversGuard guard(this); - for (auto& o : Observable::getObservers()) { - o->onCompleted(); + for (auto& kv : Observable::getObservers()) { + kv.second->onCompleted(); } } }; diff --git a/folly/experimental/wangle/rx/Subscription.h b/folly/experimental/wangle/rx/Subscription.h index 16406a23..0cf667e6 100644 --- a/folly/experimental/wangle/rx/Subscription.h +++ b/folly/experimental/wangle/rx/Subscription.h @@ -16,10 +16,55 @@ #pragma once +#include + namespace folly { namespace wangle { -// TODO -struct Subscription { +template +class Subscription { + public: + Subscription() {} + + Subscription(const Subscription&) = delete; + + Subscription(Subscription&& other) noexcept { + *this = std::move(other); + } + + Subscription& operator=(Subscription&& other) noexcept { + unsubscribe(); + unsubscriber_ = std::move(other.unsubscriber_); + id_ = other.id_; + other.unsubscriber_ = nullptr; + other.id_ = 0; + return *this; + } + + ~Subscription() { + unsubscribe(); + } + + private: + typedef typename Observable::Unsubscriber Unsubscriber; + + Subscription(std::shared_ptr unsubscriber, uint64_t id) + : unsubscriber_(std::move(unsubscriber)), id_(id) { + CHECK(unsubscriber_); + CHECK(id_ > 0); + } + + void unsubscribe() { + if (unsubscriber_ && id_ > 0) { + unsubscriber_->unsubscribe(id_); + id_ = 0; + unsubscriber_ = nullptr; + } + } + + std::shared_ptr unsubscriber_; + uint64_t id_{0}; + + friend class Observable; }; }} diff --git a/folly/experimental/wangle/rx/test/RxTest.cpp b/folly/experimental/wangle/rx/test/RxTest.cpp index cf4d9dd0..a62d11ce 100644 --- a/folly/experimental/wangle/rx/test/RxTest.cpp +++ b/folly/experimental/wangle/rx/test/RxTest.cpp @@ -20,20 +20,88 @@ using namespace folly::wangle; +static std::unique_ptr> incrementer(int& counter) { + return Observer::create([&] (int x) { + counter++; + }); +} + +TEST(RxTest, Subscription) { + Subject subject; + auto count = 0; + { + auto s = subject.subscribe(incrementer(count)); + subject.onNext(1); + } + // The subscription has gone out of scope so no one should get this. + subject.onNext(2); + EXPECT_EQ(1, count); +} + +TEST(RxTest, SubscriptionMove) { + Subject subject; + auto count = 0; + auto s = subject.subscribe(incrementer(count)); + auto s2 = subject.subscribe(incrementer(count)); + s2 = std::move(s); + subject.onNext(1); + Subscription s3(std::move(s2)); + subject.onNext(2); + EXPECT_EQ(2, count); +} + +TEST(RxTest, SubscriptionOutlivesSubject) { + Subscription s; + { + Subject subject; + s = subject.subscribe(Observer::create([](int){})); + } + // Don't explode when s is destroyed +} + TEST(RxTest, SubscribeDuringCallback) { // A subscriber who was subscribed in the course of a callback should get // subsequent updates but not the current update. Subject subject; - int outerCount = 0; - int innerCount = 0; - subject.subscribe(Observer::create([&] (int x) { + int outerCount = 0, innerCount = 0; + Subscription s1, s2; + s1 = subject.subscribe(Observer::create([&] (int x) { outerCount++; - subject.subscribe(Observer::create([&] (int y) { - innerCount++; - })); + s2 = subject.subscribe(incrementer(innerCount)); })); subject.onNext(42); subject.onNext(0xDEADBEEF); EXPECT_EQ(2, outerCount); EXPECT_EQ(1, innerCount); } + +TEST(RxTest, UnsubscribeDuringCallback) { + // A subscriber who was unsubscribed in the course of a callback should get + // the current update but not subsequent ones + Subject subject; + int count1 = 0, count2 = 0; + auto s1 = subject.subscribe(incrementer(count1)); + auto s2 = subject.subscribe(Observer::create([&] (int x) { + count2++; + s1.~Subscription(); + })); + subject.onNext(1); + subject.onNext(2); + EXPECT_EQ(1, count1); + EXPECT_EQ(2, count2); +} + +TEST(RxTest, SubscribeUnsubscribeDuringCallback) { + // A subscriber who was subscribed and unsubscribed in the course of a + // callback should not get any updates + Subject subject; + int outerCount = 0, innerCount = 0; + auto s2 = subject.subscribe(Observer::create([&] (int x) { + outerCount++; + auto s2 = subject.subscribe(incrementer(innerCount)); + })); + subject.onNext(1); + subject.onNext(2); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(0, innerCount); +} -- 2.34.1