2 * Copyright 2014 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 #include <folly/experimental/wangle/concurrent/Executor.h>
19 #include <folly/experimental/wangle/concurrent/LifoSemMPMCQueue.h>
20 #include <folly/experimental/wangle/concurrent/NamedThreadFactory.h>
21 #include <folly/experimental/wangle/rx/Observable.h>
22 #include <folly/Baton.h>
23 #include <folly/Memory.h>
24 #include <folly/RWSpinLock.h>
30 #include <glog/logging.h>
32 namespace folly { namespace wangle {
34 class ThreadPoolExecutor : public experimental::Executor {
36 explicit ThreadPoolExecutor(
38 std::shared_ptr<ThreadFactory> threadFactory);
40 ~ThreadPoolExecutor();
42 virtual void add(Func func) override = 0;
45 std::chrono::milliseconds expiration,
46 Func expireCallback) = 0;
// Resizes the pool to `numThreads` worker threads.
// NOTE(review): presumably grows/shrinks through addThreads()/removeThreads();
// confirm in the .cpp.
void setNumThreads(size_t numThreads);
54 PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0),
55 pendingTaskCount(0), totalTaskCount(0) {}
56 size_t threadCount, idleThreadCount, activeThreadCount;
57 uint64_t pendingTaskCount, totalTaskCount;
60 PoolStats getPoolStats();
63 TaskStats() : expired(false), waitTime(0), runTime(0) {}
65 std::chrono::nanoseconds waitTime;
66 std::chrono::nanoseconds runTime;
69 Subscription subscribeToTaskStats(
70 const ObserverPtr<TaskStats>& observer) {
71 return taskStatsSubject_.subscribe(observer);
// Grows the pool by n threads (presumably).
// Caller must hold threadListLock_ in write mode.
void addThreads(size_t n);

// Shrinks the pool by n threads (presumably). `isJoin` presumably ends up in
// isJoin_, i.e. "whether the current downsizing is a join".
// Caller must hold threadListLock_ in write mode.
void removeThreads(size_t n, bool isJoin);
80 struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread {
82 Thread() : id(nextId++), handle(), idle(true) {};
83 static std::atomic<uint64_t> nextId;
90 typedef std::shared_ptr<Thread> ThreadPtr;
95 std::chrono::milliseconds expiration,
96 Func&& expireCallback);
99 std::chrono::steady_clock::time_point enqueueTime_;
100 std::chrono::milliseconds expiration_;
101 Func expireCallback_;
104 void runTask(const ThreadPtr& thread, Task&& task);
106 // The function that will be bound to pool threads. It must call
107 // thread->startupBaton.post() when it's ready to consume work.
108 virtual void threadRun(ThreadPtr thread) = 0;
110 // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue
111 // Prerequisite: threadListLock_ writelocked
112 virtual void stopThreads(size_t n) = 0;
114 // Create a suitable Thread struct
115 virtual ThreadPtr makeThread() {
116 return std::make_shared<Thread>();
119 // Prerequisite: threadListLock_ readlocked
120 virtual uint64_t getPendingTaskCount() = 0;
124 void add(const ThreadPtr& state) {
125 auto it = std::lower_bound(vec_.begin(), vec_.end(), state, compare);
126 vec_.insert(it, state);
129 void remove(const ThreadPtr& state) {
130 auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, compare);
131 CHECK(itPair.first != vec_.end());
132 CHECK(std::next(itPair.first) == itPair.second);
133 vec_.erase(itPair.first);
136 const std::vector<ThreadPtr>& get() const {
141 static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) {
142 return ts1->id < ts2->id;
145 std::vector<ThreadPtr> vec_;
148 class StoppedThreadQueue : public BlockingQueue<ThreadPtr> {
150 void add(ThreadPtr item) override;
151 ThreadPtr take() override;
152 size_t size() override;
157 std::queue<ThreadPtr> queue_;
160 std::shared_ptr<ThreadFactory> threadFactory_;
161 ThreadList threadList_;
162 RWSpinLock threadListLock_;
163 StoppedThreadQueue stoppedThreads_;
164 std::atomic<bool> isJoin_; // whether the current downsizing is a join
166 Subject<TaskStats> taskStatsSubject_;