Add TimedDrivableExecutor to folly.
[folly.git] / folly / executors / ThreadPoolExecutor.h
1 /*
2  * Copyright 2017-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <folly/Executor.h>
18 #include <folly/Memory.h>
19 #include <folly/executors/GlobalThreadPoolList.h>
20 #include <folly/executors/task_queue/LifoSemMPMCQueue.h>
21 #include <folly/executors/thread_factory/NamedThreadFactory.h>
22 #include <folly/io/async/Request.h>
23 #include <folly/synchronization/Baton.h>
24 #include <folly/synchronization/RWSpinLock.h>
25
26 #include <algorithm>
27 #include <mutex>
28 #include <queue>
29
30 #include <glog/logging.h>
31
32 namespace folly {
33
34 class ThreadPoolExecutor : public virtual folly::Executor {
35  public:
36   explicit ThreadPoolExecutor(
37       size_t numThreads,
38       std::shared_ptr<ThreadFactory> threadFactory,
39       bool isWaitForAll = false);
40
41   ~ThreadPoolExecutor() override;
42
43   void add(Func func) override = 0;
44   virtual void
45   add(Func func, std::chrono::milliseconds expiration, Func expireCallback) = 0;
46
47   void setThreadFactory(std::shared_ptr<ThreadFactory> threadFactory) {
48     CHECK(numThreads() == 0);
49     threadFactory_ = std::move(threadFactory);
50   }
51
52   std::shared_ptr<ThreadFactory> getThreadFactory() {
53     return threadFactory_;
54   }
55
56   size_t numThreads();
57   void setNumThreads(size_t numThreads);
58   /*
59    * stop() is best effort - there is no guarantee that unexecuted tasks won't
60    * be executed before it returns. Specifically, IOThreadPoolExecutor's stop()
61    * behaves like join().
62    */
63   void stop();
64   void join();
65
66   struct PoolStats {
67     PoolStats()
68         : threadCount(0),
69           idleThreadCount(0),
70           activeThreadCount(0),
71           pendingTaskCount(0),
72           totalTaskCount(0),
73           maxIdleTime(0) {}
74     size_t threadCount, idleThreadCount, activeThreadCount;
75     uint64_t pendingTaskCount, totalTaskCount;
76     std::chrono::nanoseconds maxIdleTime;
77   };
78
79   PoolStats getPoolStats();
80   uint64_t getPendingTaskCount();
81
82   struct TaskStats {
83     TaskStats() : expired(false), waitTime(0), runTime(0) {}
84     bool expired;
85     std::chrono::nanoseconds waitTime;
86     std::chrono::nanoseconds runTime;
87   };
88
89   using TaskStatsCallback = std::function<void(TaskStats)>;
90   void subscribeToTaskStats(TaskStatsCallback cb);
91
92   /**
93    * Base class for threads created with ThreadPoolExecutor.
94    * Some subclasses have methods that operate on these
95    * handles.
96    */
97   class ThreadHandle {
98    public:
99     virtual ~ThreadHandle() = default;
100   };
101
102   /**
103    * Observer interface for thread start/stop.
104    * Provides hooks so actions can be taken when
105    * threads are created
106    */
107   class Observer {
108    public:
109     virtual void threadStarted(ThreadHandle*) = 0;
110     virtual void threadStopped(ThreadHandle*) = 0;
111     virtual void threadPreviouslyStarted(ThreadHandle* h) {
112       threadStarted(h);
113     }
114     virtual void threadNotYetStopped(ThreadHandle* h) {
115       threadStopped(h);
116     }
117     virtual ~Observer() = default;
118   };
119
120   void addObserver(std::shared_ptr<Observer>);
121   void removeObserver(std::shared_ptr<Observer>);
122
123  protected:
124   // Prerequisite: threadListLock_ writelocked
125   void addThreads(size_t n);
126   // Prerequisite: threadListLock_ writelocked
127   void removeThreads(size_t n, bool isJoin);
128
129   struct TaskStatsCallbackRegistry;
130
131   struct alignas(hardware_destructive_interference_size) Thread
132       : public ThreadHandle {
133     explicit Thread(ThreadPoolExecutor* pool)
134         : id(nextId++),
135           handle(),
136           idle(true),
137           lastActiveTime(std::chrono::steady_clock::now()),
138           taskStatsCallbacks(pool->taskStatsCallbacks_) {}
139
140     ~Thread() override = default;
141
142     static std::atomic<uint64_t> nextId;
143     uint64_t id;
144     std::thread handle;
145     bool idle;
146     std::chrono::steady_clock::time_point lastActiveTime;
147     folly::Baton<> startupBaton;
148     std::shared_ptr<TaskStatsCallbackRegistry> taskStatsCallbacks;
149   };
150
151   typedef std::shared_ptr<Thread> ThreadPtr;
152
153   struct Task {
154     explicit Task(
155         Func&& func,
156         std::chrono::milliseconds expiration,
157         Func&& expireCallback);
158     Func func_;
159     TaskStats stats_;
160     std::chrono::steady_clock::time_point enqueueTime_;
161     std::chrono::milliseconds expiration_;
162     Func expireCallback_;
163     std::shared_ptr<folly::RequestContext> context_;
164   };
165
166   static void runTask(const ThreadPtr& thread, Task&& task);
167
168   // The function that will be bound to pool threads. It must call
169   // thread->startupBaton.post() when it's ready to consume work.
170   virtual void threadRun(ThreadPtr thread) = 0;
171
172   // Stop n threads and put their ThreadPtrs in the stoppedThreads_ queue
173   // and remove them from threadList_, either synchronize or asynchronize
174   // Prerequisite: threadListLock_ writelocked
175   virtual void stopThreads(size_t n) = 0;
176
177   // Join n stopped threads and remove them from waitingForJoinThreads_ queue.
178   // Should not hold a lock because joining thread operation may invoke some
179   // cleanup operations on the thread, and those cleanup operations may
180   // require a lock on ThreadPoolExecutor.
181   void joinStoppedThreads(size_t n);
182
183   // Create a suitable Thread struct
184   virtual ThreadPtr makeThread() {
185     return std::make_shared<Thread>(this);
186   }
187
188   // Prerequisite: threadListLock_ readlocked
189   virtual uint64_t getPendingTaskCountImpl(const RWSpinLock::ReadHolder&) = 0;
190
191   class ThreadList {
192    public:
193     void add(const ThreadPtr& state) {
194       auto it = std::lower_bound(
195           vec_.begin(),
196           vec_.end(),
197           state,
198           // compare method is a static method of class
199           // and therefore cannot be inlined by compiler
200           // as a template predicate of the STL algorithm
201           // but wrapped up with the lambda function (lambda will be inlined)
202           // compiler can inline compare method as well
203           [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
204             return compare(ts1, ts2);
205           });
206       vec_.insert(it, state);
207     }
208
209     void remove(const ThreadPtr& state) {
210       auto itPair = std::equal_range(
211           vec_.begin(),
212           vec_.end(),
213           state,
214           // the same as above
215           [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline
216             return compare(ts1, ts2);
217           });
218       CHECK(itPair.first != vec_.end());
219       CHECK(std::next(itPair.first) == itPair.second);
220       vec_.erase(itPair.first);
221     }
222
223     const std::vector<ThreadPtr>& get() const {
224       return vec_;
225     }
226
227    private:
228     static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
229       return ts1->id < ts2->id;
230     }
231
232     std::vector<ThreadPtr> vec_;
233   };
234
235   class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
236    public:
237     void add(ThreadPtr item) override;
238     ThreadPtr take() override;
239     size_t size() override;
240
241    private:
242     folly::LifoSem sem_;
243     std::mutex mutex_;
244     std::queue<ThreadPtr> queue_;
245   };
246
247   std::shared_ptr<ThreadFactory> threadFactory_;
248   const bool isWaitForAll_; // whether to wait till event base loop exits
249
250   ThreadList threadList_;
251   folly::RWSpinLock threadListLock_;
252   StoppedThreadQueue stoppedThreads_;
253   std::atomic<bool> isJoin_; // whether the current downsizing is a join
254
255   struct TaskStatsCallbackRegistry {
256     folly::ThreadLocal<bool> inCallback;
257     folly::Synchronized<std::vector<TaskStatsCallback>> callbackList;
258   };
259   std::shared_ptr<TaskStatsCallbackRegistry> taskStatsCallbacks_;
260   std::vector<std::shared_ptr<Observer>> observers_;
261   folly::ThreadPoolListHook threadPoolHook_;
262 };
263
264 } // namespace folly