07adf7b458e36173d9776615e56687d770e4ef0a
[folly.git] / folly / fibers / TimedMutex-inl.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <mutex>
19
20 namespace folly {
21 namespace fibers {
22
23 //
24 // TimedMutex implementation
25 //
26
27 template <typename WaitFunc>
28 TimedMutex::LockResult TimedMutex::lockHelper(WaitFunc&& waitFunc) {
29   std::unique_lock<folly::SpinLock> lock(lock_);
30   if (!locked_) {
31     locked_ = true;
32     return LockResult::SUCCESS;
33   }
34
35   const auto isOnFiber = onFiber();
36
37   if (!isOnFiber && notifiedFiber_ != nullptr) {
38     // lock() was called on a thread and while some other fiber was already
39     // notified, it hasn't be run yet. We steal the lock from that fiber then
40     // to avoid potential deadlock.
41     DCHECK(threadWaiters_.empty());
42     notifiedFiber_ = nullptr;
43     return LockResult::SUCCESS;
44   }
45
46   // Delay constructing the waiter until it is actually required.
47   // This makes a huge difference, at least in the benchmarks,
48   // when the mutex isn't locked.
49   MutexWaiter waiter;
50   if (isOnFiber) {
51     fiberWaiters_.push_back(waiter);
52   } else {
53     threadWaiters_.push_back(waiter);
54   }
55
56   lock.unlock();
57
58   if (!waitFunc(waiter)) {
59     return LockResult::TIMEOUT;
60   }
61
62   if (isOnFiber) {
63     auto lockStolen = [&] {
64       std::lock_guard<folly::SpinLock> lg(lock_);
65
66       auto stolen = notifiedFiber_ != &waiter;
67       if (!stolen) {
68         notifiedFiber_ = nullptr;
69       }
70       return stolen;
71     }();
72
73     if (lockStolen) {
74       return LockResult::STOLEN;
75     }
76   }
77
78   return LockResult::SUCCESS;
79 }
80
81 inline void TimedMutex::lock() {
82   auto result = lockHelper([](MutexWaiter& waiter) {
83     waiter.baton.wait();
84     return true;
85   });
86
87   DCHECK(result != LockResult::TIMEOUT);
88   if (result == LockResult::SUCCESS) {
89     return;
90   }
91   lock();
92 }
93
94 template <typename Rep, typename Period>
95 bool TimedMutex::timed_lock(
96     const std::chrono::duration<Rep, Period>& duration) {
97   auto result = lockHelper([&](MutexWaiter& waiter) {
98     if (!waiter.baton.timed_wait(duration)) {
99       // We timed out. Two cases:
100       // 1. We're still in the waiter list and we truly timed out
101       // 2. We're not in the waiter list anymore. This could happen if the baton
102       //    times out but the mutex is unlocked before we reach this code. In
103       //    this
104       //    case we'll pretend we got the lock on time.
105       std::lock_guard<folly::SpinLock> lg(lock_);
106       if (waiter.hook.is_linked()) {
107         waiter.hook.unlink();
108         return false;
109       }
110     }
111     return true;
112   });
113
114   switch (result) {
115     case LockResult::SUCCESS:
116       return true;
117     case LockResult::TIMEOUT:
118       return false;
119     case LockResult::STOLEN:
120       // We don't respect the duration if lock was stolen
121       lock();
122       return true;
123   }
124   assume_unreachable();
125 }
126
127 inline bool TimedMutex::try_lock() {
128   std::lock_guard<folly::SpinLock> lg(lock_);
129   if (locked_) {
130     return false;
131   }
132   locked_ = true;
133   return true;
134 }
135
136 inline void TimedMutex::unlock() {
137   std::lock_guard<folly::SpinLock> lg(lock_);
138   if (!threadWaiters_.empty()) {
139     auto& to_wake = threadWaiters_.front();
140     threadWaiters_.pop_front();
141     to_wake.baton.post();
142   } else if (!fiberWaiters_.empty()) {
143     auto& to_wake = fiberWaiters_.front();
144     fiberWaiters_.pop_front();
145     notifiedFiber_ = &to_wake;
146     to_wake.baton.post();
147   } else {
148     locked_ = false;
149   }
150 }
151
152 //
153 // TimedRWMutex implementation
154 //
155
156 template <typename BatonType>
157 void TimedRWMutex<BatonType>::read_lock() {
158   std::unique_lock<folly::SpinLock> lock{lock_};
159   if (state_ == State::WRITE_LOCKED) {
160     MutexWaiter waiter;
161     read_waiters_.push_back(waiter);
162     lock.unlock();
163     waiter.baton.wait();
164     assert(state_ == State::READ_LOCKED);
165     return;
166   }
167   assert(
168       (state_ == State::UNLOCKED && readers_ == 0) ||
169       (state_ == State::READ_LOCKED && readers_ > 0));
170   assert(read_waiters_.empty());
171   state_ = State::READ_LOCKED;
172   readers_ += 1;
173 }
174
175 template <typename BatonType>
176 template <typename Rep, typename Period>
177 bool TimedRWMutex<BatonType>::timed_read_lock(
178     const std::chrono::duration<Rep, Period>& duration) {
179   std::unique_lock<folly::SpinLock> lock{lock_};
180   if (state_ == State::WRITE_LOCKED) {
181     MutexWaiter waiter;
182     read_waiters_.push_back(waiter);
183     lock.unlock();
184
185     if (!waiter.baton.timed_wait(duration)) {
186       // We timed out. Two cases:
187       // 1. We're still in the waiter list and we truly timed out
188       // 2. We're not in the waiter list anymore. This could happen if the baton
189       //    times out but the mutex is unlocked before we reach this code. In
190       //    this case we'll pretend we got the lock on time.
191       std::lock_guard<SpinLock> guard{lock_};
192       if (waiter.hook.is_linked()) {
193         read_waiters_.erase(read_waiters_.iterator_to(waiter));
194         return false;
195       }
196     }
197     return true;
198   }
199   assert(
200       (state_ == State::UNLOCKED && readers_ == 0) ||
201       (state_ == State::READ_LOCKED && readers_ > 0));
202   assert(read_waiters_.empty());
203   state_ = State::READ_LOCKED;
204   readers_ += 1;
205   return true;
206 }
207
208 template <typename BatonType>
209 bool TimedRWMutex<BatonType>::try_read_lock() {
210   std::lock_guard<SpinLock> guard{lock_};
211   if (state_ != State::WRITE_LOCKED) {
212     assert(
213         (state_ == State::UNLOCKED && readers_ == 0) ||
214         (state_ == State::READ_LOCKED && readers_ > 0));
215     assert(read_waiters_.empty());
216     state_ = State::READ_LOCKED;
217     readers_ += 1;
218     return true;
219   }
220   return false;
221 }
222
223 template <typename BatonType>
224 void TimedRWMutex<BatonType>::write_lock() {
225   std::unique_lock<folly::SpinLock> lock{lock_};
226   if (state_ == State::UNLOCKED) {
227     verify_unlocked_properties();
228     state_ = State::WRITE_LOCKED;
229     return;
230   }
231   MutexWaiter waiter;
232   write_waiters_.push_back(waiter);
233   lock.unlock();
234   waiter.baton.wait();
235 }
236
237 template <typename BatonType>
238 template <typename Rep, typename Period>
239 bool TimedRWMutex<BatonType>::timed_write_lock(
240     const std::chrono::duration<Rep, Period>& duration) {
241   std::unique_lock<folly::SpinLock> lock{lock_};
242   if (state_ == State::UNLOCKED) {
243     verify_unlocked_properties();
244     state_ = State::WRITE_LOCKED;
245     return true;
246   }
247   MutexWaiter waiter;
248   write_waiters_.push_back(waiter);
249   lock.unlock();
250
251   if (!waiter.baton.timed_wait(duration)) {
252     // We timed out. Two cases:
253     // 1. We're still in the waiter list and we truly timed out
254     // 2. We're not in the waiter list anymore. This could happen if the baton
255     //    times out but the mutex is unlocked before we reach this code. In
256     //    this case we'll pretend we got the lock on time.
257     std::lock_guard<SpinLock> guard{lock_};
258     if (waiter.hook.is_linked()) {
259       write_waiters_.erase(write_waiters_.iterator_to(waiter));
260       return false;
261     }
262   }
263   assert(state_ == State::WRITE_LOCKED);
264   return true;
265 }
266
267 template <typename BatonType>
268 bool TimedRWMutex<BatonType>::try_write_lock() {
269   std::lock_guard<SpinLock> guard{lock_};
270   if (state_ == State::UNLOCKED) {
271     verify_unlocked_properties();
272     state_ = State::WRITE_LOCKED;
273     return true;
274   }
275   return false;
276 }
277
278 template <typename BatonType>
279 void TimedRWMutex<BatonType>::unlock() {
280   std::lock_guard<SpinLock> guard{lock_};
281   assert(state_ != State::UNLOCKED);
282   assert(
283       (state_ == State::READ_LOCKED && readers_ > 0) ||
284       (state_ == State::WRITE_LOCKED && readers_ == 0));
285   if (state_ == State::READ_LOCKED) {
286     readers_ -= 1;
287   }
288
289   if (!read_waiters_.empty()) {
290     assert(
291         state_ == State::WRITE_LOCKED && readers_ == 0 &&
292         "read waiters can only accumulate while write locked");
293     state_ = State::READ_LOCKED;
294     readers_ = read_waiters_.size();
295
296     while (!read_waiters_.empty()) {
297       MutexWaiter& to_wake = read_waiters_.front();
298       read_waiters_.pop_front();
299       to_wake.baton.post();
300     }
301   } else if (readers_ == 0) {
302     if (!write_waiters_.empty()) {
303       assert(read_waiters_.empty());
304       state_ = State::WRITE_LOCKED;
305
306       // Wake a single writer (after releasing the spin lock)
307       MutexWaiter& to_wake = write_waiters_.front();
308       write_waiters_.pop_front();
309       to_wake.baton.post();
310     } else {
311       verify_unlocked_properties();
312       state_ = State::UNLOCKED;
313     }
314   } else {
315     assert(state_ == State::READ_LOCKED);
316   }
317 }
318
319 template <typename BatonType>
320 void TimedRWMutex<BatonType>::downgrade() {
321   std::lock_guard<SpinLock> guard{lock_};
322   assert(state_ == State::WRITE_LOCKED && readers_ == 0);
323   state_ = State::READ_LOCKED;
324   readers_ += 1;
325
326   if (!read_waiters_.empty()) {
327     readers_ += read_waiters_.size();
328
329     while (!read_waiters_.empty()) {
330       MutexWaiter& to_wake = read_waiters_.front();
331       read_waiters_.pop_front();
332       to_wake.baton.post();
333     }
334   }
335 }
336 }
337 }