Prefer bool literals rather than integers in boolean contexts
[folly.git] / folly / executors / ThreadPoolExecutor.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/executors/ThreadPoolExecutor.h>
18
19 #include <folly/executors/GlobalThreadPoolList.h>
20
21 namespace folly {
22
23 ThreadPoolExecutor::ThreadPoolExecutor(
24     size_t /* numThreads */,
25     std::shared_ptr<ThreadFactory> threadFactory,
26     bool isWaitForAll)
27     : threadFactory_(std::move(threadFactory)),
28       isWaitForAll_(isWaitForAll),
29       taskStatsCallbacks_(std::make_shared<TaskStatsCallbackRegistry>()),
30       threadPoolHook_("Wangle::ThreadPoolExecutor") {}
31
32 ThreadPoolExecutor::~ThreadPoolExecutor() {
33   CHECK_EQ(0, threadList_.get().size());
34 }
35
36 ThreadPoolExecutor::Task::Task(
37     Func&& func,
38     std::chrono::milliseconds expiration,
39     Func&& expireCallback)
40     : func_(std::move(func)),
41       expiration_(expiration),
42       expireCallback_(std::move(expireCallback)),
43       context_(folly::RequestContext::saveContext()) {
44   // Assume that the task in enqueued on creation
45   enqueueTime_ = std::chrono::steady_clock::now();
46 }
47
48 void ThreadPoolExecutor::runTask(const ThreadPtr& thread, Task&& task) {
49   thread->idle = false;
50   auto startTime = std::chrono::steady_clock::now();
51   task.stats_.waitTime = startTime - task.enqueueTime_;
52   if (task.expiration_ > std::chrono::milliseconds(0) &&
53       task.stats_.waitTime >= task.expiration_) {
54     task.stats_.expired = true;
55     if (task.expireCallback_ != nullptr) {
56       task.expireCallback_();
57     }
58   } else {
59     folly::RequestContextScopeGuard rctx(task.context_);
60     try {
61       task.func_();
62     } catch (const std::exception& e) {
63       LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled "
64                  << typeid(e).name() << " exception: " << e.what();
65     } catch (...) {
66       LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception "
67                     "object";
68     }
69     task.stats_.runTime = std::chrono::steady_clock::now() - startTime;
70   }
71   thread->idle = true;
72   thread->lastActiveTime = std::chrono::steady_clock::now();
73   thread->taskStatsCallbacks->callbackList.withRLock([&](auto& callbacks) {
74     *thread->taskStatsCallbacks->inCallback = true;
75     SCOPE_EXIT {
76       *thread->taskStatsCallbacks->inCallback = false;
77     };
78     try {
79       for (auto& callback : callbacks) {
80         callback(task.stats_);
81       }
82     } catch (const std::exception& e) {
83       LOG(ERROR) << "ThreadPoolExecutor: task stats callback threw "
84                     "unhandled "
85                  << typeid(e).name() << " exception: " << e.what();
86     } catch (...) {
87       LOG(ERROR) << "ThreadPoolExecutor: task stats callback threw "
88                     "unhandled non-exception object";
89     }
90   });
91 }
92
93 size_t ThreadPoolExecutor::numThreads() {
94   RWSpinLock::ReadHolder r{&threadListLock_};
95   return threadList_.get().size();
96 }
97
98 void ThreadPoolExecutor::setNumThreads(size_t n) {
99   size_t numThreadsToJoin = 0;
100   {
101     RWSpinLock::WriteHolder w{&threadListLock_};
102     const auto current = threadList_.get().size();
103     if (n > current) {
104       addThreads(n - current);
105     } else if (n < current) {
106       numThreadsToJoin = current - n;
107       removeThreads(numThreadsToJoin, true);
108     }
109   }
110   joinStoppedThreads(numThreadsToJoin);
111   CHECK_EQ(n, threadList_.get().size());
112   CHECK_EQ(0, stoppedThreads_.size());
113 }
114
115 // threadListLock_ is writelocked
116 void ThreadPoolExecutor::addThreads(size_t n) {
117   std::vector<ThreadPtr> newThreads;
118   for (size_t i = 0; i < n; i++) {
119     newThreads.push_back(makeThread());
120   }
121   for (auto& thread : newThreads) {
122     // TODO need a notion of failing to create the thread
123     // and then handling for that case
124     thread->handle = threadFactory_->newThread(
125         std::bind(&ThreadPoolExecutor::threadRun, this, thread));
126     threadList_.add(thread);
127   }
128   for (auto& thread : newThreads) {
129     thread->startupBaton.wait();
130   }
131   for (auto& o : observers_) {
132     for (auto& thread : newThreads) {
133       o->threadStarted(thread.get());
134     }
135   }
136 }
137
138 // threadListLock_ is writelocked
139 void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) {
140   CHECK_LE(n, threadList_.get().size());
141   isJoin_ = isJoin;
142   stopThreads(n);
143 }
144
145 void ThreadPoolExecutor::joinStoppedThreads(size_t n) {
146   for (size_t i = 0; i < n; i++) {
147     auto thread = stoppedThreads_.take();
148     thread->handle.join();
149   }
150 }
151
152 void ThreadPoolExecutor::stop() {
153   size_t n = 0;
154   {
155     RWSpinLock::WriteHolder w{&threadListLock_};
156     n = threadList_.get().size();
157     removeThreads(n, false);
158   }
159   joinStoppedThreads(n);
160   CHECK_EQ(0, threadList_.get().size());
161   CHECK_EQ(0, stoppedThreads_.size());
162 }
163
164 void ThreadPoolExecutor::join() {
165   size_t n = 0;
166   {
167     RWSpinLock::WriteHolder w{&threadListLock_};
168     n = threadList_.get().size();
169     removeThreads(n, true);
170   }
171   joinStoppedThreads(n);
172   CHECK_EQ(0, threadList_.get().size());
173   CHECK_EQ(0, stoppedThreads_.size());
174 }
175
176 ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() {
177   const auto now = std::chrono::steady_clock::now();
178   RWSpinLock::ReadHolder r{&threadListLock_};
179   ThreadPoolExecutor::PoolStats stats;
180   stats.threadCount = threadList_.get().size();
181   for (auto thread : threadList_.get()) {
182     if (thread->idle) {
183       stats.idleThreadCount++;
184       const std::chrono::nanoseconds idleTime = now - thread->lastActiveTime;
185       stats.maxIdleTime = std::max(stats.maxIdleTime, idleTime);
186     } else {
187       stats.activeThreadCount++;
188     }
189   }
190   stats.pendingTaskCount = getPendingTaskCountImpl(r);
191   stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount;
192   return stats;
193 }
194
195 uint64_t ThreadPoolExecutor::getPendingTaskCount() {
196   RWSpinLock::ReadHolder r{&threadListLock_};
197   return getPendingTaskCountImpl(r);
198 }
199
200 std::atomic<uint64_t> ThreadPoolExecutor::Thread::nextId(0);
201
202 void ThreadPoolExecutor::subscribeToTaskStats(TaskStatsCallback cb) {
203   if (*taskStatsCallbacks_->inCallback) {
204     throw std::runtime_error("cannot subscribe in task stats callback");
205   }
206   taskStatsCallbacks_->callbackList.wlock()->push_back(std::move(cb));
207 }
208
209 void ThreadPoolExecutor::StoppedThreadQueue::add(
210     ThreadPoolExecutor::ThreadPtr item) {
211   std::lock_guard<std::mutex> guard(mutex_);
212   queue_.push(std::move(item));
213   sem_.post();
214 }
215
216 ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() {
217   while (true) {
218     {
219       std::lock_guard<std::mutex> guard(mutex_);
220       if (queue_.size() > 0) {
221         auto item = std::move(queue_.front());
222         queue_.pop();
223         return item;
224       }
225     }
226     sem_.wait();
227   }
228 }
229
230 size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
231   std::lock_guard<std::mutex> guard(mutex_);
232   return queue_.size();
233 }
234
235 void ThreadPoolExecutor::addObserver(std::shared_ptr<Observer> o) {
236   RWSpinLock::ReadHolder r{&threadListLock_};
237   observers_.push_back(o);
238   for (auto& thread : threadList_.get()) {
239     o->threadPreviouslyStarted(thread.get());
240   }
241 }
242
243 void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) {
244   RWSpinLock::ReadHolder r{&threadListLock_};
245   for (auto& thread : threadList_.get()) {
246     o->threadNotYetStopped(thread.get());
247   }
248
249   for (auto it = observers_.begin(); it != observers_.end(); it++) {
250     if (*it == o) {
251       observers_.erase(it);
252       return;
253     }
254   }
255   DCHECK(false);
256 }
257
258 } // namespace folly