Get *=default*ed default constructors
[folly.git] / folly / wangle / concurrent / ThreadPoolExecutor.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18 #include <folly/Executor.h>
19 #include <folly/wangle/concurrent/LifoSemMPMCQueue.h>
20 #include <folly/wangle/concurrent/NamedThreadFactory.h>
21 #include <folly/wangle/rx/Observable.h>
22 #include <folly/Baton.h>
23 #include <folly/Memory.h>
24 #include <folly/RWSpinLock.h>
25
26 #include <algorithm>
27 #include <mutex>
28 #include <queue>
29
30 #include <glog/logging.h>
31
32 namespace folly { namespace wangle {
33
34 class ThreadPoolExecutor : public virtual Executor {
35  public:
36   explicit ThreadPoolExecutor(
37       size_t numThreads,
38       std::shared_ptr<ThreadFactory> threadFactory);
39
40   ~ThreadPoolExecutor();
41
42   virtual void add(Func func) override = 0;
43   virtual void add(
44       Func func,
45       std::chrono::milliseconds expiration,
46       Func expireCallback) = 0;
47
48   void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
49     CHECK(numThreads() == 0);
50     threadFactory_ = std::move(threadFactory);
51   }
52
53   std::shared_ptr<ThreadFactory> getThreadFactory(void) {
54     return threadFactory_;
55   }
56
57   size_t numThreads();
58   void setNumThreads(size_t numThreads);
59   /*
60    * stop() is best effort - there is no guarantee that unexecuted tasks won't
61    * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
62    * behaves like join().
63    */
64   void stop();
65   void join();
66
67   struct PoolStats {
68     PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
69                   pendingTaskCount(0), totalTaskCount(0) {}
70     size_t threadCount, idleThreadCount, activeThreadCount;
71     uint64_t pendingTaskCount, totalTaskCount;
72   };
73
74   PoolStats getPoolStats();
75
76   struct TaskStats {
77     TaskStats() : expired(false), waitTime(0), runTime(0) {}
78     bool expired;
79     std::chrono::nanoseconds waitTime;
80     std::chrono::nanoseconds runTime;
81   };
82
83   Subscription<TaskStats> subscribeToTaskStats(
84       const ObserverPtr<TaskStats>& observer) {
85     return taskStatsSubject_->subscribe(observer);
86   }
87
88   /**
89    * Base class for threads created with ThreadPoolExecutor.
90    * Some subclasses have methods that operate on these
91    * handles.
92    */
93   class ThreadHandle {
94    public:
95     virtual ~ThreadHandle() = default;
96   };
97
98   /**
99    * Observer interface for thread start/stop.
100    * Provides hooks so actions can be taken when
101    * threads are created
102    */
103   class Observer {
104    public:
105     virtual void threadStarted(ThreadHandle*) = 0;
106     virtual void threadStopped(ThreadHandle*) = 0;
107     virtual void threadPreviouslyStarted(ThreadHandle* h) {
108       threadStarted(h);
109     }
110     virtual void threadNotYetStopped(ThreadHandle* h) {
111       threadStopped(h);
112     }
113     virtual ~Observer() = default;
114   };
115
116   void addObserver(std::shared_ptr<Observer>);
117   void removeObserver(std::shared_ptr<Observer>);
118
119  protected:
120   // Prerequisite: threadListLock_ writelocked
121   void addThreads(size_t n);
122   // Prerequisite: threadListLock_ writelocked
123   void removeThreads(size_t n, bool isJoin);
124
125   struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread : public ThreadHandle {
126     explicit Thread(ThreadPoolExecutor* pool)
127       : id(nextId++),
128         handle(),
129         idle(true),
130         taskStatsSubject(pool->taskStatsSubject_) {}
131
132     virtual ~Thread() = default;
133
134     static std::atomic<uint64_t> nextId;
135     uint64_t id;
136     std::thread handle;
137     bool idle;
138     Baton<> startupBaton;
139     std::shared_ptr<Subject<TaskStats>> taskStatsSubject;
140   };
141
142   typedef std::shared_ptr<Thread> ThreadPtr;
143
144   struct Task {
145     explicit Task(
146         Func&& func,
147         std::chrono::milliseconds expiration,
148         Func&& expireCallback);
149     Func func_;
150     TaskStats stats_;
151     std::chrono::steady_clock::time_point enqueueTime_;
152     std::chrono::milliseconds expiration_;
153     Func expireCallback_;
154   };
155
156   static void runTask(const ThreadPtr& thread, Task&& task);
157
158   // The function that will be bound to pool threads. It must call
159   // thread->startupBaton.post() when it's ready to consume work.
160   virtual void threadRun(ThreadPtr thread) = 0;
161
162   // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
163   // Prerequisite: threadListLock_ writelocked
164   virtual void stopThreads(size_t n) = 0;
165
166   // Create a suitable Thread struct
167   virtual ThreadPtr makeThread() {
168     return std::make_shared<Thread>(this);
169   }
170
171   // Prerequisite: threadListLock_ readlocked
172   virtual uint64_t getPendingTaskCount() = 0;
173
174   class ThreadList {
175    public:
176     void add(const ThreadPtr& state) {
177       auto it = std::lower_bound(vec_.begin(), vec_.end(), state,
178           // compare method is a static method of class
179           // and therefore cannot be inlined by compiler
180           // as a template predicate of the STL algorithm
181           // but wrapped up with the lambda function (lambda will be inlined)
182           // compiler can inline compare method as well
183           [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
184             return compare(ts1, ts2);
185           });
186       vec_.insert(it, state);
187     }
188
189     void remove(const ThreadPtr& state) {
190       auto itPair = std::equal_range(vec_.begin(), vec_.end(), state,
191           // the same as above
192           [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
193             return compare(ts1, ts2);
194           });
195       CHECK(itPair.first != vec_.end());
196       CHECK(std::next(itPair.first) == itPair.second);
197       vec_.erase(itPair.first);
198     }
199
200     const std::vector<ThreadPtr>& get() const {
201       return vec_;
202     }
203
204    private:
205     static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
206       return ts1->id < ts2->id;
207     }
208
209     std::vector<ThreadPtr> vec_;
210   };
211
212   class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
213    public:
214     void add(ThreadPtr item) override;
215     ThreadPtr take() override;
216     size_t size() override;
217
218    private:
219     LifoSem sem_;
220     std::mutex mutex_;
221     std::queue<ThreadPtr> queue_;
222   };
223
224   std::shared_ptr<ThreadFactory> threadFactory_;
225   ThreadList threadList_;
226   RWSpinLock threadListLock_;
227   StoppedThreadQueue stoppedThreads_;
228   std::atomic<bool> isJoin_; // whether the current downsizing is a join
229
230   std::shared_ptr<Subject<TaskStats>> taskStatsSubject_;
231   std::vector<std::shared_ptr<Observer>> observers_;
232 };
233
234 }} // folly::wangle