in thread pools, take factory as shared ptr
[folly.git] / folly / experimental / wangle / concurrent / ThreadPoolExecutor.h
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18 #include <folly/experimental/wangle/concurrent/Executor.h>
19 #include <folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h>
20 #include <folly/experimental/wangle/concurrent/NamedThreadFactory.h>
21 #include <folly/experimental/wangle/rx/Observable.h>
22 #include <folly/Baton.h>
23 #include <folly/Memory.h>
24 #include <folly/RWSpinLock.h>
25
26 #include <algorithm>
27 #include <mutex>
28 #include <queue>
29
30 #include <glog/logging.h>
31
32 namespace folly { namespace wangle {
33
34 class ThreadPoolExecutor : public experimental::Executor {
35  public:
36   explicit ThreadPoolExecutor(
37       size_t numThreads,
38       std::shared_ptr<ThreadFactory> threadFactory);
39
40   ~ThreadPoolExecutor();
41
42   virtual void add(Func func) override = 0;
43   virtual void add(
44       Func func,
45       std::chrono::milliseconds expiration,
46       Func expireCallback) = 0;
47
48   size_t numThreads();
49   void setNumThreads(size_t numThreads);
50   void stop();
51   void join();
52
53   struct PoolStats {
54     PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
55                   pendingTaskCount(0), totalTaskCount(0) {}
56     size_t threadCount, idleThreadCount, activeThreadCount;
57     uint64_t pendingTaskCount, totalTaskCount;
58   };
59
60   PoolStats getPoolStats();
61
62   struct TaskStats {
63     TaskStats() : expired(false), waitTime(0), runTime(0) {}
64     bool expired;
65     std::chrono::nanoseconds waitTime;
66     std::chrono::nanoseconds runTime;
67   };
68
69   Subscription subscribeToTaskStats(
70       const ObserverPtr<TaskStats>& observer) {
71     return taskStatsSubject_.subscribe(observer);
72   }
73
74  protected:
75   // Prerequisite: threadListLock_ writelocked
76   void addThreads(size_t n);
77   // Prerequisite: threadListLock_ writelocked
78   void removeThreads(size_t n, bool isJoin);
79
80   struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread {
81     virtual ~Thread() {}
82     Thread() : id(nextId++), handle(), idle(true) {};
83     static std::atomic<uint64_t> nextId;
84     uint64_t id;
85     std::thread handle;
86     bool idle;
87     Baton<> startupBaton;
88   };
89
90   typedef std::shared_ptr<Thread> ThreadPtr;
91
92   struct Task {
93     explicit Task(
94         Func&& func,
95         std::chrono::milliseconds expiration,
96         Func&& expireCallback);
97     Func func_;
98     TaskStats stats_;
99     std::chrono::steady_clock::time_point enqueueTime_;
100     std::chrono::milliseconds expiration_;
101     Func expireCallback_;
102   };
103
104   void runTask(const ThreadPtr& thread, Task&& task);
105
106   // The function that will be bound to pool threads. It must call
107   // thread->startupBaton.post() when it's ready to consume work.
108   virtual void threadRun(ThreadPtr thread) = 0;
109
110   // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
111   // Prerequisite: threadListLock_ writelocked
112   virtual void stopThreads(size_t n) = 0;
113
114   // Create a suitable Thread struct
115   virtual ThreadPtr makeThread() {
116     return std::make_shared<Thread>();
117   }
118
119   // Prerequisite: threadListLock_ readlocked
120   virtual uint64_t getPendingTaskCount() = 0;
121
122   class ThreadList {
123    public:
124     void add(const ThreadPtr& state) {
125       auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare);
126       vec_.insert(it, state);
127     }
128
129     void remove(const ThreadPtr& state) {
130       auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare);
131       CHECK(itPair.first != vec_.end());
132       CHECK(std::next(itPair.first) == itPair.second);
133       vec_.erase(itPair.first);
134     }
135
136     const std::vector<ThreadPtr>& get() const {
137       return vec_;
138     }
139
140    private:
141     static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
142       return ts1->id < ts2->id;
143     }
144
145     std::vector<ThreadPtr> vec_;
146   };
147
148   class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
149    public:
150     void add(ThreadPtr item) override;
151     ThreadPtr take() override;
152     size_t size() override;
153
154    private:
155     LifoSem sem_;
156     std::mutex mutex_;
157     std::queue<ThreadPtr> queue_;
158   };
159
160   std::shared_ptr<ThreadFactory> threadFactory_;
161   ThreadList threadList_;
162   RWSpinLock threadListLock_;
163   StoppedThreadQueue stoppedThreads_;
164   std::atomic<bool> isJoin_; // whether the current downsizing is a join
165
166   Subject<TaskStats> taskStatsSubject_;
167 };
168
169 }} // folly::wangle