2 * Copyright 2015 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include <folly/wangle/concurrent/ThreadPoolExecutor.h>
19 namespace folly { namespace wangle {
// Base-class constructor: takes ownership of `threadFactory` (later used by
// addThreads() to spawn worker threads) and allocates the Subject through
// which per-task TaskStats are published.
// NOTE(review): only the `threadFactory` parameter is visible here; confirm
// the complete parameter list against the header.
21 ThreadPoolExecutor::ThreadPoolExecutor(
23 std::shared_ptr<ThreadFactory> threadFactory)
24 : threadFactory_(std::move(threadFactory)),
25 taskStatsSubject_(std::make_shared<Subject<TaskStats>>()) {}
27 ThreadPoolExecutor::~ThreadPoolExecutor() {
28 CHECK(threadList_.get().size() == 0);
// Builds a task record. It stores the task function, its expiration (runTask()
// compares this against the time spent queued) and the callback to invoke if
// the task expires. It also stamps enqueueTime_ with the current
// steady-clock time.
// NOTE(review): the declaration of the `func` parameter is not visible here.
31 ThreadPoolExecutor::Task::Task(
33 std::chrono::milliseconds expiration,
34 Func&& expireCallback)
35 : func_(std::move(func)),
36 expiration_(expiration),
37 expireCallback_(std::move(expireCallback)) {
38 // Assume that the task is enqueued on creation
// (runTask() measures stats_.waitTime from this timestamp).
39 enqueueTime_ = std::chrono::steady_clock::now();
// Runs one dequeued task on worker `thread` and publishes its TaskStats.
//  - stats_.waitTime = start time - enqueueTime_ (stamped by the Task ctor).
//  - If expiration_ is positive and the task waited at least that long, it
//    is marked expired and expireCallback_ is invoked when non-null.
//  - A std::exception escaping the task function is logged with its dynamic
//    type name and what(); a second LOG handles non-std::exception objects.
//  - stats_.runTime is measured from startTime, then stats_ is moved into
//    thread->taskStatsSubject->onNext().
// NOTE(review): the remaining parameter, the else/try branch that invokes
// the task function, and the catch (...) header are not visible in this view.
42 void ThreadPoolExecutor::runTask(
43 const ThreadPtr& thread,
46 auto startTime = std::chrono::steady_clock::now();
47 task.stats_.waitTime = startTime - task.enqueueTime_;
// A non-positive expiration_ disables expiry.
48 if (task.expiration_ > std::chrono::milliseconds(0) &&
49 task.stats_.waitTime >= task.expiration_) {
50 task.stats_.expired = true;
51 if (task.expireCallback_ != nullptr) {
52 task.expireCallback_();
57 } catch (const std::exception& e) {
58 LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " <<
59 typeid(e).name() << " exception: " << e.what();
// Presumably the catch (...) branch for non-std::exception throws.
61 LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
64 task.stats_.runTime = std::chrono::steady_clock::now() - startTime;
67 thread->taskStatsSubject->onNext(std::move(task.stats_));
70 size_t ThreadPoolExecutor::numThreads() {
71 RWSpinLock::ReadHolder{&threadListLock_};
72 return threadList_.get().size();
75 void ThreadPoolExecutor::setNumThreads(size_t n) {
76 RWSpinLock::WriteHolder{&threadListLock_};
77 const auto current = threadList_.get().size();
79 addThreads(n - current);
80 } else if (n < current) {
81 removeThreads(current - n, true);
83 CHECK(threadList_.get().size() == n);
86 // threadListLock_ is writelocked
87 void ThreadPoolExecutor::addThreads(size_t n) {
88 std::vector<ThreadPtr> newThreads;
89 for (size_t i = 0; i < n; i++) {
90 newThreads.push_back(makeThread());
92 for (auto& thread : newThreads) {
93 // TODO need a notion of failing to create the thread
94 // and then handling for that case
95 thread->handle = threadFactory_->newThread(
96 std::bind(&ThreadPoolExecutor::threadRun, this, thread));
97 threadList_.add(thread);
99 for (auto& thread : newThreads) {
100 thread->startupBaton.wait();
102 for (auto& o : observers_) {
103 for (auto& thread : newThreads) {
104 o->threadStarted(thread.get());
109 // threadListLock_ is writelocked (precondition: held by the caller)
// Removes `n` worker threads from the pool. For each of them it takes a
// stopped thread from stoppedThreads_, joins that thread's handle, and
// removes it from threadList_.
// The CHECKs enforce two invariants: n may not exceed the pool size, and the
// stopped-thread queue must be empty both on entry and on exit.
// NOTE(review): `isJoin` is not referenced in the lines visible here, and no
// visible line asks threads to stop before take() is called. Confirm that the
// stop request is issued before the loop.
110 void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) {
111 CHECK(n <= threadList_.get().size());
112 CHECK(stoppedThreads_.size() == 0);
115 for (size_t i = 0; i < n; i++) {
116 auto thread = stoppedThreads_.take();
117 thread->handle.join();
118 threadList_.remove(thread);
120 CHECK(stoppedThreads_.size() == 0);
123 void ThreadPoolExecutor::stop() {
124 RWSpinLock::WriteHolder{&threadListLock_};
125 removeThreads(threadList_.get().size(), false);
126 CHECK(threadList_.get().size() == 0);
129 void ThreadPoolExecutor::join() {
130 RWSpinLock::WriteHolder{&threadListLock_};
131 removeThreads(threadList_.get().size(), true);
132 CHECK(threadList_.get().size() == 0);
// Returns a snapshot of pool statistics:
//  - threadCount: the current size of threadList_;
//  - idleThreadCount / activeThreadCount: per-thread tally;
//  - pendingTaskCount: from getPendingTaskCount();
//  - totalTaskCount: pending + active (assumes each active thread is running
//    exactly one task).
// NOTE(review): `RWSpinLock::ReadHolder{&threadListLock_};` below creates an
// unnamed temporary that is destroyed at the end of that statement, so the
// lock is NOT held while threadList_ is read. It should be a named local,
// e.g. `RWSpinLock::ReadHolder r{&threadListLock_};`.
// NOTE(review): the idle/active branch condition is not visible here.
135 ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() {
136 RWSpinLock::ReadHolder{&threadListLock_};
137 ThreadPoolExecutor::PoolStats stats;
138 stats.threadCount = threadList_.get().size();
// `auto thread` copies each ThreadPtr; `const auto&` would avoid the copy.
139 for (auto thread : threadList_.get()) {
141 stats.idleThreadCount++;
143 stats.activeThreadCount++;
146 stats.pendingTaskCount = getPendingTaskCount();
147 stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount;
151 std::atomic<uint64_t> ThreadPoolExecutor::Thread::nextId(0);
// Appends a stopped thread to the queue, under mutex_, for a later take().
// NOTE(review): no wake-up of a caller waiting in take() is visible here;
// confirm one exists if take() waits for items to arrive.
153 void ThreadPoolExecutor::StoppedThreadQueue::add(
154 ThreadPoolExecutor::ThreadPtr item) {
155 std::lock_guard<std::mutex> guard(mutex_);
156 queue_.push(std::move(item));
// Retrieves a stopped thread from the queue. Under mutex_, when the queue is
// non-empty its front item is moved out.
// NOTE(review): the rest of this function is not visible in this view:
// popping the front, returning the item, and the empty-queue behaviour.
160 ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() {
163 std::lock_guard<std::mutex> guard(mutex_);
164 if (queue_.size() > 0) {
165 auto item = std::move(queue_.front());
174 size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
175 std::lock_guard<std::mutex> guard(mutex_);
176 return queue_.size();
179 void ThreadPoolExecutor::addObserver(std::shared_ptr<Observer> o) {
180 RWSpinLock::ReadHolder{&threadListLock_};
181 observers_.push_back(o);
182 for (auto& thread : threadList_.get()) {
183 o->threadPreviouslyStarted(thread.get());
187 void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) {
188 RWSpinLock::ReadHolder{&threadListLock_};
189 for (auto& thread : threadList_.get()) {
190 o->threadNotYetStopped(thread.get());
193 for (auto it = observers_.begin(); it != observers_.end(); it++) {
195 observers_.erase(it);