Fix TimedMutex deadlock when used both from fiber and main context
[folly.git] / folly / fibers / TimedMutex-inl.h
index 96ee60646cffd1de988f2433e373277df83e0549..26e523c73f0f1463deb380986732ee20d5884aa1 100644 (file)
@@ -22,80 +22,127 @@ namespace fibers {
 // TimedMutex implementation
 //
 
-template <typename BatonType>
-void TimedMutex<BatonType>::lock() {
-  pthread_spin_lock(&lock_);
+template <typename WaitFunc>
+TimedMutex::LockResult TimedMutex::lockHelper(WaitFunc&& waitFunc) {
+  std::unique_lock<folly::SpinLock> lock(lock_);
   if (!locked_) {
     locked_ = true;
-    pthread_spin_unlock(&lock_);
-    return;
+    return LockResult::SUCCESS;
+  }
+
+  const auto isOnFiber = onFiber();
+
+  if (!isOnFiber && notifiedFiber_ != nullptr) {
+    // lock() was called on a thread and while some other fiber was already
+    // notified, it hasn't be run yet. We steal the lock from that fiber then
+    // to avoid potential deadlock.
+    DCHECK(threadWaiters_.empty());
+    notifiedFiber_ = nullptr;
+    return LockResult::SUCCESS;
   }
 
   // Delay constructing the waiter until it is actually required.
   // This makes a huge difference, at least in the benchmarks,
   // when the mutex isn't locked.
   MutexWaiter waiter;
-  waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
-  waiter.baton.wait();
+  if (isOnFiber) {
+    fiberWaiters_.push_back(waiter);
+  } else {
+    threadWaiters_.push_back(waiter);
+  }
+
+  lock.unlock();
+
+  if (!waitFunc(waiter)) {
+    return LockResult::TIMEOUT;
+  }
+
+  if (isOnFiber) {
+    auto lockStolen = [&] {
+      std::lock_guard<folly::SpinLock> lg(lock_);
+
+      auto lockStolen = notifiedFiber_ != &waiter;
+      notifiedFiber_ = nullptr;
+      return lockStolen;
+    }();
+
+    if (lockStolen) {
+      return LockResult::STOLEN;
+    }
+  }
+
+  return LockResult::SUCCESS;
 }
 
-template <typename BatonType>
-template <typename Rep, typename Period>
-bool TimedMutex<BatonType>::timed_lock(
-    const std::chrono::duration<Rep, Period>& duration) {
-  pthread_spin_lock(&lock_);
-  if (!locked_) {
-    locked_ = true;
-    pthread_spin_unlock(&lock_);
+inline void TimedMutex::lock() {
+  auto result = lockHelper([](MutexWaiter& waiter) {
+    waiter.baton.wait();
     return true;
+  });
+
+  DCHECK(result != LockResult::TIMEOUT);
+  if (result == LockResult::SUCCESS) {
+    return;
   }
+  lock();
+}
 
-  MutexWaiter waiter;
-  waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
+template <typename Rep, typename Period>
+bool TimedMutex::timed_lock(
+    const std::chrono::duration<Rep, Period>& duration) {
+  auto result = lockHelper([&](MutexWaiter& waiter) {
+    if (!waiter.baton.timed_wait(duration)) {
+      // We timed out. Two cases:
+      // 1. We're still in the waiter list and we truly timed out
+      // 2. We're not in the waiter list anymore. This could happen if the baton
+      //    times out but the mutex is unlocked before we reach this code. In
+      //    this
+      //    case we'll pretend we got the lock on time.
+      std::lock_guard<folly::SpinLock> lg(lock_);
+      if (waiter.hook.is_linked()) {
+        waiter.hook.unlink();
+        return false;
+      }
+    }
+    return true;
+  });
 
-  if (!waiter.baton.timed_wait(duration)) {
-    // We timed out. Two cases:
-    // 1. We're still in the waiter list and we truly timed out
-    // 2. We're not in the waiter list anymore. This could happen if the baton
-    //    times out but the mutex is unlocked before we reach this code. In this
-    //    case we'll pretend we got the lock on time.
-    pthread_spin_lock(&lock_);
-    if (waiter.hook.is_linked()) {
-      waiters_.erase(waiters_.iterator_to(waiter));
-      pthread_spin_unlock(&lock_);
+  switch (result) {
+    case LockResult::SUCCESS:
+      return true;
+    case LockResult::TIMEOUT:
       return false;
-    }
-    pthread_spin_unlock(&lock_);
+    case LockResult::STOLEN:
+      // We don't respect the duration if lock was stolen
+      lock();
+      return true;
   }
-  return true;
+  assume_unreachable();
 }
 
-template <typename BatonType>
-bool TimedMutex<BatonType>::try_lock() {
-  pthread_spin_lock(&lock_);
+inline bool TimedMutex::try_lock() {
+  std::lock_guard<folly::SpinLock> lg(lock_);
   if (locked_) {
-    pthread_spin_unlock(&lock_);
     return false;
   }
   locked_ = true;
-  pthread_spin_unlock(&lock_);
   return true;
 }
 
-template <typename BatonType>
-void TimedMutex<BatonType>::unlock() {
-  pthread_spin_lock(&lock_);
-  if (waiters_.empty()) {
+inline void TimedMutex::unlock() {
+  std::lock_guard<folly::SpinLock> lg(lock_);
+  if (!threadWaiters_.empty()) {
+    auto& to_wake = threadWaiters_.front();
+    threadWaiters_.pop_front();
+    to_wake.baton.post();
+  } else if (!fiberWaiters_.empty()) {
+    auto& to_wake = fiberWaiters_.front();
+    fiberWaiters_.pop_front();
+    notifiedFiber_ = &to_wake;
+    to_wake.baton.post();
+  } else {
     locked_ = false;
-    pthread_spin_unlock(&lock_);
-    return;
   }
-  MutexWaiter& to_wake = waiters_.front();
-  waiters_.pop_front();
-  to_wake.baton.post();
-  pthread_spin_unlock(&lock_);
 }
 
 //