Thread Observer
authorDave Watson <davejwatson@fb.com>
Tue, 6 Jan 2015 18:36:31 +0000 (10:36 -0800)
committerViswanath Sivakumar <viswanath@fb.com>
Tue, 13 Jan 2015 19:01:04 +0000 (11:01 -0800)
Summary: Observer methods, so users of IOThreadPoolExecutor can do stuff when threads are added/removed.  As a use case, previously the thrift server only used the threads already started when it started up, and assumed iothreadpool was never resized.

Test Plan: Added several unittests

Reviewed By: jsedgwick@fb.com

Subscribers: trunkagent, doug, fugalh, alandau, bmatheny, mshneer, folly-diffs@

FB internal diff: D1753861

Signature: t1:1753861:1420236825:54cbdfee0efb3b97dea35faba29c134f2b10a480

folly/wangle/concurrent/CPUThreadPoolExecutor.cpp
folly/wangle/concurrent/IOThreadPoolExecutor.cpp
folly/wangle/concurrent/IOThreadPoolExecutor.h
folly/wangle/concurrent/ThreadPoolExecutor.cpp
folly/wangle/concurrent/ThreadPoolExecutor.h
folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp

index a03c6151040e240b76ab8f71ced395ff59ba9389..e0ad08c83ce8a76379a147a5a52c0206ed8872d1 100644 (file)
@@ -105,6 +105,10 @@ void CPUThreadPoolExecutor::threadRun(std::shared_ptr<Thread> thread) {
     auto task = taskQueue_->take();
     if (UNLIKELY(task.poison)) {
       CHECK(threadsToStop_-- > 0);
+      for (auto& o : observers_) {
+        o->threadStopped(thread.get());
+      }
+
       stoppedThreads_.add(thread);
       return;
     } else {
index 5c97bf442641ff4c571d142421e2c2181359388d..721083a10f0b6d4ad1c129ab79c91b6214104677 100644 (file)
@@ -117,6 +117,17 @@ EventBase* IOThreadPoolExecutor::getEventBase() {
   return pickThread()->eventBase;
 }
 
+EventBase* IOThreadPoolExecutor::getEventBase(
+    ThreadPoolExecutor::ThreadHandle* h) {
+  auto thread = dynamic_cast<IOThread*>(h);
+
+  if (thread) {
+    return thread->eventBase;
+  }
+
+  return nullptr;
+}
+
 std::shared_ptr<ThreadPoolExecutor::Thread>
 IOThreadPoolExecutor::makeThread() {
   return std::make_shared<IOThread>(this);
@@ -148,21 +159,14 @@ void IOThreadPoolExecutor::stopThreads(size_t n) {
   for (size_t i = 0; i < n; i++) {
     const auto ioThread = std::static_pointer_cast<IOThread>(
         threadList_.get()[i]);
+    for (auto& o : observers_) {
+      o->threadStopped(ioThread.get());
+    }
     ioThread->shouldRun = false;
     ioThread->eventBase->terminateLoopSoon();
   }
 }
 
-std::vector<EventBase*> IOThreadPoolExecutor::getEventBases() {
-  std::vector<EventBase*> bases;
-  RWSpinLock::ReadHolder{&threadListLock_};
-  for (const auto& thread : threadList_.get()) {
-    auto ioThread = std::static_pointer_cast<IOThread>(thread);
-    bases.push_back(ioThread->eventBase);
-  }
-  return bases;
-}
-
 // threadListLock_ is readlocked
 uint64_t IOThreadPoolExecutor::getPendingTaskCount() {
   uint64_t count = 0;
index 7c919d1e6df4209b2cf2c5a50211c34b94710a5e..1e4e8b0734bc46346d65465e5c8dda64cb93b3a2 100644 (file)
@@ -41,7 +41,7 @@ class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor {
 
   EventBase* getEventBase() override;
 
-  std::vector<EventBase*> getEventBases();
+  EventBase* getEventBase(ThreadPoolExecutor::ThreadHandle*);
 
  private:
   struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread {
index 4069475487aa2dbcf81e775552381f39039f5514..25660db13fd5956ab388fa41b8539b165bd7bcd8 100644 (file)
@@ -99,6 +99,11 @@ void ThreadPoolExecutor::addThreads(size_t n) {
   for (auto& thread : newThreads) {
     thread->startupBaton.wait();
   }
+  for (auto& o : observers_) {
+    for (auto& thread : newThreads) {
+      o->threadStarted(thread.get());
+    }
+  }
 }
 
 // threadListLock_ is writelocked
@@ -171,4 +176,27 @@ size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
   return queue_.size();
 }
 
+void ThreadPoolExecutor::addObserver(std::shared_ptr<Observer> o) {
+  RWSpinLock::ReadHolder{&threadListLock_};
+  observers_.push_back(o);
+  for (auto& thread : threadList_.get()) {
+    o->threadPreviouslyStarted(thread.get());
+  }
+}
+
+void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) {
+  RWSpinLock::ReadHolder{&threadListLock_};
+  for (auto& thread : threadList_.get()) {
+    o->threadNotYetStopped(thread.get());
+  }
+
+  for (auto it = observers_.begin(); it != observers_.end(); it++) {
+    if (*it == o) {
+      observers_.erase(it);
+      return;
+    }
+  }
+  DCHECK(false);
+}
+
 }} // folly::wangle
index be8f796869839712777643dae63048d02a108a90..f978a5e372a16d261e86c9ca570afe04cf75a06d 100644 (file)
@@ -85,13 +85,40 @@ class ThreadPoolExecutor : public virtual Executor {
     return taskStatsSubject_->subscribe(observer);
   }
 
+  /**
+   * Base class for threads created with ThreadPoolExecutor.
+   * Some subclasses have methods that operate on these
+   * handles.
+   */
+  class ThreadHandle {
+   public:
+    virtual ~ThreadHandle() = default;
+  };
+
+  /**
+   * Observer interface for thread start/stop.
+   * Provides hooks so actions can be taken when
+   * threads are created
+   */
+  class Observer {
+   public:
+    virtual void threadStarted(ThreadHandle*) = 0;
+    virtual void threadStopped(ThreadHandle*) = 0;
+    virtual void threadPreviouslyStarted(ThreadHandle*) = 0;
+    virtual void threadNotYetStopped(ThreadHandle*) = 0;
+    virtual ~Observer() = default;
+  };
+
+  void addObserver(std::shared_ptr<Observer>);
+  void removeObserver(std::shared_ptr<Observer>);
+
  protected:
   // Prerequisite: threadListLock_ writelocked
   void addThreads(size_t n);
   // Prerequisite: threadListLock_ writelocked
   void removeThreads(size_t n, bool isJoin);
 
-  struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread {
+  struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread : public ThreadHandle {
     explicit Thread(ThreadPoolExecutor* pool)
       : id(nextId++),
         handle(),
@@ -185,6 +212,7 @@ class ThreadPoolExecutor : public virtual Executor {
   std::atomic<bool> isJoin_; // whether the current downsizing is a join
 
   std::shared_ptr<Subject<TaskStats>> taskStatsSubject_;
+  std::vector<std::shared_ptr<Observer>> observers_;
 };
 
 }} // folly::wangle
index 596e27849c0dedbfb2cb34262766133d19d9fe8d..385d2b0e984748c287f96e46d43263dae24c5627 100644 (file)
@@ -318,3 +318,56 @@ TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) {
   pool.join();
   EXPECT_EQ(100, completed);
 }
+
+class TestObserver : public ThreadPoolExecutor::Observer {
+ public:
+  void threadStarted(ThreadPoolExecutor::ThreadHandle*) {
+    threads_++;
+  }
+  void threadStopped(ThreadPoolExecutor::ThreadHandle*) {
+    threads_--;
+  }
+  void threadPreviouslyStarted(ThreadPoolExecutor::ThreadHandle*) {
+    threads_++;
+  }
+  void threadNotYetStopped(ThreadPoolExecutor::ThreadHandle*) {
+    threads_--;
+  }
+  void checkCalls() {
+    ASSERT_EQ(threads_, 0);
+  }
+ private:
+  int threads_{0};
+};
+
+TEST(ThreadPoolExecutorTest, IOObserver) {
+  auto observer = std::make_shared<TestObserver>();
+
+  {
+    IOThreadPoolExecutor exe(10);
+    exe.addObserver(observer);
+    exe.setNumThreads(3);
+    exe.setNumThreads(0);
+    exe.setNumThreads(7);
+    exe.removeObserver(observer);
+    exe.setNumThreads(10);
+  }
+
+  observer->checkCalls();
+}
+
+TEST(ThreadPoolExecutorTest, CPUObserver) {
+  auto observer = std::make_shared<TestObserver>();
+
+  {
+    CPUThreadPoolExecutor exe(10);
+    exe.addObserver(observer);
+    exe.setNumThreads(3);
+    exe.setNumThreads(0);
+    exe.setNumThreads(7);
+    exe.removeObserver(observer);
+    exe.setNumThreads(10);
+  }
+
+  observer->checkCalls();
+}