4eda2d36e3039fd703aeb254295c82b6629df015
[folly.git] / folly / experimental / wangle / concurrent / ThreadPoolExecutor.h
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18 #include <folly/experimental/wangle/concurrent/Executor.h>
19 #include <folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h>
20 #include <folly/experimental/wangle/concurrent/NamedThreadFactory.h>
21 #include <folly/experimental/wangle/rx/Observable.h>
22 #include <folly/Memory.h>
23 #include <folly/RWSpinLock.h>
24
25 #include <algorithm>
26 #include <mutex>
27 #include <queue>
28
29 #include <glog/logging.h>
30
31 namespace folly { namespace wangle {
32
33 class ThreadPoolExecutor : public experimental::Executor {
34  public:
35   explicit ThreadPoolExecutor(
36       size_t numThreads,
37       std::unique_ptr<ThreadFactory> threadFactory);
38
39   ~ThreadPoolExecutor();
40
41   size_t numThreads();
42   void setNumThreads(size_t numThreads);
43   void stop();
44   void join();
45
46   struct PoolStats {
47     PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
48                   pendingTaskCount(0), totalTaskCount(0) {}
49     size_t threadCount, idleThreadCount, activeThreadCount;
50     uint64_t pendingTaskCount, totalTaskCount;
51   };
52
53   PoolStats getPoolStats();
54
55   struct TaskStats {
56     TaskStats() : expired(false), waitTime(0), runTime(0) {}
57     bool expired;
58     std::chrono::microseconds waitTime;
59     std::chrono::microseconds runTime;
60   };
61
62   Subscription subscribeToTaskStats(
63       const ObserverPtr<TaskStats>& observer) {
64     return taskStatsSubject_.subscribe(observer);
65   }
66
67  protected:
68   // Prerequisite: threadListLock_ writelocked
69   void addThreads(size_t n);
70   // Prerequisite: threadListLock_ writelocked
71   void removeThreads(size_t n, bool isJoin);
72
73   struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread {
74     virtual ~Thread() {}
75     Thread() : id(nextId++), handle(), idle(true) {};
76     static std::atomic<uint64_t> nextId;
77     uint64_t id;
78     std::thread handle;
79     bool idle;
80   };
81
82   typedef std::shared_ptr<Thread> ThreadPtr;
83
84   struct Task {
85     explicit Task(Func&& f) : func(std::move(f)) {
86       // Assume that the task in enqueued on creation
87       intervalBegin = std::chrono::steady_clock::now();
88     }
89
90     Func func;
91     TaskStats stats;
92     // TODO per-task timeouts, expirations
93
94     void started() {
95       auto now = std::chrono::steady_clock::now();
96       stats.waitTime = std::chrono::duration_cast<std::chrono::microseconds>(
97           now - intervalBegin);
98       intervalBegin = now;
99     }
100     void completed() {
101       stats.runTime = std::chrono::duration_cast<std::chrono::microseconds>(
102          std::chrono::steady_clock::now() - intervalBegin);
103     }
104
105     std::chrono::steady_clock::time_point intervalBegin;
106   };
107
108   void runTask(const ThreadPtr& thread, Task&& task);
109
110   // The function that will be bound to pool threads
111   virtual void threadRun(ThreadPtr thread) = 0;
112
113   // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
114   // Prerequisite: threadListLock_ writelocked
115   virtual void stopThreads(size_t n) = 0;
116
117   // Create a suitable Thread struct
118   virtual ThreadPtr makeThread() {
119     return std::make_shared<Thread>();
120   }
121
122   // Prerequisite: threadListLock_ readlocked
123   virtual uint64_t getPendingTaskCount() = 0;
124
125   class ThreadList {
126    public:
127     void add(const ThreadPtr& state) {
128       auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare);
129       vec_.insert(it, state);
130     }
131
132     void remove(const ThreadPtr& state) {
133       auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare);
134       CHECK(itPair.first != vec_.end());
135       CHECK(std::next(itPair.first) == itPair.second);
136       vec_.erase(itPair.first);
137     }
138
139     const std::vector<ThreadPtr>& get() const {
140       return vec_;
141     }
142
143    private:
144     static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
145       return ts1->id < ts2->id;
146     }
147
148     std::vector<ThreadPtr> vec_;
149   };
150
151   class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
152    public:
153     void add(ThreadPtr item) override;
154     ThreadPtr take() override;
155     size_t size() override;
156
157    private:
158     LifoSem sem_;
159     std::mutex mutex_;
160     std::queue<ThreadPtr> queue_;
161   };
162
163   std::unique_ptr<ThreadFactory> threadFactory_;
164   ThreadList threadList_;
165   RWSpinLock threadListLock_;
166   StoppedThreadQueue stoppedThreads_;
167   std::atomic<bool> isJoin_; // whether the current downsizing is a join
168
169   Subject<TaskStats> taskStatsSubject_;
170 };
171
172 }} // folly::wangle