Fix TimedMutex deadlock when used both from fiber and main context
authorAndrii Grynenko <andrii@fb.com>
Fri, 16 Dec 2016 21:27:33 +0000 (13:27 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 16 Dec 2016 21:33:07 +0000 (13:33 -0800)
Summary: TimedMutex is a fair mutex, which can cause a deadlock if same mutex is requested first in a fiber, and then in main context.

Reviewed By: yfeldblum

Differential Revision: D4209155

fbshipit-source-id: 0623d9a2e6a0b5cc310fb71ad1b1cf33afd6a30e

folly/fibers/TimedMutex-inl.h
folly/fibers/TimedMutex.h
folly/fibers/test/FibersTest.cpp

index 96ee60646cffd1de988f2433e373277df83e0549..26e523c73f0f1463deb380986732ee20d5884aa1 100644 (file)
@@ -22,80 +22,127 @@ namespace fibers {
 // TimedMutex implementation
 //
 
-template <typename BatonType>
-void TimedMutex<BatonType>::lock() {
-  pthread_spin_lock(&lock_);
+template <typename WaitFunc>
+TimedMutex::LockResult TimedMutex::lockHelper(WaitFunc&& waitFunc) {
+  std::unique_lock<folly::SpinLock> lock(lock_);
   if (!locked_) {
     locked_ = true;
-    pthread_spin_unlock(&lock_);
-    return;
+    return LockResult::SUCCESS;
+  }
+
+  const auto isOnFiber = onFiber();
+
+  if (!isOnFiber && notifiedFiber_ != nullptr) {
+    // lock() was called on a thread and while some other fiber was already
+    // notified, it hasn't be run yet. We steal the lock from that fiber then
+    // to avoid potential deadlock.
+    DCHECK(threadWaiters_.empty());
+    notifiedFiber_ = nullptr;
+    return LockResult::SUCCESS;
   }
 
   // Delay constructing the waiter until it is actually required.
   // This makes a huge difference, at least in the benchmarks,
   // when the mutex isn't locked.
   MutexWaiter waiter;
-  waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
-  waiter.baton.wait();
+  if (isOnFiber) {
+    fiberWaiters_.push_back(waiter);
+  } else {
+    threadWaiters_.push_back(waiter);
+  }
+
+  lock.unlock();
+
+  if (!waitFunc(waiter)) {
+    return LockResult::TIMEOUT;
+  }
+
+  if (isOnFiber) {
+    auto lockStolen = [&] {
+      std::lock_guard<folly::SpinLock> lg(lock_);
+
+      auto lockStolen = notifiedFiber_ != &waiter;
+      notifiedFiber_ = nullptr;
+      return lockStolen;
+    }();
+
+    if (lockStolen) {
+      return LockResult::STOLEN;
+    }
+  }
+
+  return LockResult::SUCCESS;
 }
 
-template <typename BatonType>
-template <typename Rep, typename Period>
-bool TimedMutex<BatonType>::timed_lock(
-    const std::chrono::duration<Rep, Period>& duration) {
-  pthread_spin_lock(&lock_);
-  if (!locked_) {
-    locked_ = true;
-    pthread_spin_unlock(&lock_);
+inline void TimedMutex::lock() {
+  auto result = lockHelper([](MutexWaiter& waiter) {
+    waiter.baton.wait();
     return true;
+  });
+
+  DCHECK(result != LockResult::TIMEOUT);
+  if (result == LockResult::SUCCESS) {
+    return;
   }
+  lock();
+}
 
-  MutexWaiter waiter;
-  waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
+template <typename Rep, typename Period>
+bool TimedMutex::timed_lock(
+    const std::chrono::duration<Rep, Period>& duration) {
+  auto result = lockHelper([&](MutexWaiter& waiter) {
+    if (!waiter.baton.timed_wait(duration)) {
+      // We timed out. Two cases:
+      // 1. We're still in the waiter list and we truly timed out
+      // 2. We're not in the waiter list anymore. This could happen if the baton
+      //    times out but the mutex is unlocked before we reach this code. In
+      //    this
+      //    case we'll pretend we got the lock on time.
+      std::lock_guard<folly::SpinLock> lg(lock_);
+      if (waiter.hook.is_linked()) {
+        waiter.hook.unlink();
+        return false;
+      }
+    }
+    return true;
+  });
 
-  if (!waiter.baton.timed_wait(duration)) {
-    // We timed out. Two cases:
-    // 1. We're still in the waiter list and we truly timed out
-    // 2. We're not in the waiter list anymore. This could happen if the baton
-    //    times out but the mutex is unlocked before we reach this code. In this
-    //    case we'll pretend we got the lock on time.
-    pthread_spin_lock(&lock_);
-    if (waiter.hook.is_linked()) {
-      waiters_.erase(waiters_.iterator_to(waiter));
-      pthread_spin_unlock(&lock_);
+  switch (result) {
+    case LockResult::SUCCESS:
+      return true;
+    case LockResult::TIMEOUT:
       return false;
-    }
-    pthread_spin_unlock(&lock_);
+    case LockResult::STOLEN:
+      // We don't respect the duration if lock was stolen
+      lock();
+      return true;
   }
-  return true;
+  assume_unreachable();
 }
 
-template <typename BatonType>
-bool TimedMutex<BatonType>::try_lock() {
-  pthread_spin_lock(&lock_);
+inline bool TimedMutex::try_lock() {
+  std::lock_guard<folly::SpinLock> lg(lock_);
   if (locked_) {
-    pthread_spin_unlock(&lock_);
     return false;
   }
   locked_ = true;
-  pthread_spin_unlock(&lock_);
   return true;
 }
 
-template <typename BatonType>
-void TimedMutex<BatonType>::unlock() {
-  pthread_spin_lock(&lock_);
-  if (waiters_.empty()) {
+inline void TimedMutex::unlock() {
+  std::lock_guard<folly::SpinLock> lg(lock_);
+  if (!threadWaiters_.empty()) {
+    auto& to_wake = threadWaiters_.front();
+    threadWaiters_.pop_front();
+    to_wake.baton.post();
+  } else if (!fiberWaiters_.empty()) {
+    auto& to_wake = fiberWaiters_.front();
+    fiberWaiters_.pop_front();
+    notifiedFiber_ = &to_wake;
+    to_wake.baton.post();
+  } else {
     locked_ = false;
-    pthread_spin_unlock(&lock_);
-    return;
   }
-  MutexWaiter& to_wake = waiters_.front();
-  waiters_.pop_front();
-  to_wake.baton.post();
-  pthread_spin_unlock(&lock_);
 }
 
 //
index e21f246cca406358dc3f1dbab0aad9fec4655a6f..6dbe50ba0020ffdf88e1d9cd047013ea0145bed1 100644 (file)
@@ -17,6 +17,8 @@
 
 #include <pthread.h>
 
+#include <folly/IntrusiveList.h>
+#include <folly/SpinLock.h>
 #include <folly/fibers/GenericBaton.h>
 
 namespace folly {
@@ -27,15 +29,14 @@ namespace fibers {
  *
  * Like mutex but allows timed_lock in addition to lock and try_lock.
  **/
-template <typename BatonType>
 class TimedMutex {
  public:
-  TimedMutex() {
-    pthread_spin_init(&lock_, PTHREAD_PROCESS_PRIVATE);
-  }
+  TimedMutex() {}
 
   ~TimedMutex() {
-    pthread_spin_destroy(&lock_);
+    DCHECK(threadWaiters_.empty());
+    DCHECK(fiberWaiters_.empty());
+    DCHECK(notifiedFiber_ == nullptr);
   }
 
   TimedMutex(const TimedMutex& rhs) = delete;
@@ -59,28 +60,25 @@ class TimedMutex {
   void unlock();
 
  private:
-  typedef boost::intrusive::list_member_hook<> MutexWaiterHookType;
+  enum class LockResult { SUCCESS, TIMEOUT, STOLEN };
+
+  template <typename WaitFunc>
+  LockResult lockHelper(WaitFunc&& waitFunc);
 
   // represents a waiter waiting for the lock. The waiter waits on the
   // baton until it is woken up by a post or timeout expires.
   struct MutexWaiter {
-    BatonType baton;
-    MutexWaiterHookType hook;
+    Baton baton;
+    folly::IntrusiveListHook hook;
   };
 
-  typedef boost::intrusive::
-      member_hook<MutexWaiter, MutexWaiterHookType, &MutexWaiter::hook>
-          MutexWaiterHook;
-
-  typedef boost::intrusive::list<
-      MutexWaiter,
-      MutexWaiterHook,
-      boost::intrusive::constant_time_size<true>>
-      MutexWaiterList;
+  using MutexWaiterList = folly::IntrusiveList<MutexWaiter, &MutexWaiter::hook>;
 
-  pthread_spinlock_t lock_; //< lock to protect waiter list
+  folly::SpinLock lock_; //< lock to protect waiter list
   bool locked_ = false; //< is this locked by some thread?
-  MutexWaiterList waiters_; //< list of waiters
+  MutexWaiterList threadWaiters_; //< list of waiters
+  MutexWaiterList fiberWaiters_; //< list of waiters
+  MutexWaiter* notifiedFiber_{nullptr}; //< Fiber waiter which has been notified
 };
 
 /**
index d1b60cee6b13c69f9236dce3f4a20a9e49a6f064..d096839a6ef1b130e9fcdbe1cf97baf85eca3224 100644 (file)
@@ -31,6 +31,7 @@
 #include <folly/fibers/GenericBaton.h>
 #include <folly/fibers/Semaphore.h>
 #include <folly/fibers/SimpleLoopController.h>
+#include <folly/fibers/TimedMutex.h>
 #include <folly/fibers/WhenN.h>
 #include <folly/io/async/ScopedEventBaseThread.h>
 #include <folly/portability/GTest.h>
@@ -2072,6 +2073,64 @@ TEST(FiberManager, VirtualEventBase) {
   EXPECT_TRUE(done2);
 }
 
+TEST(TimedMutex, ThreadFiberDeadlockOrder) {
+  folly::EventBase evb;
+  auto& fm = getFiberManager(evb);
+  TimedMutex mutex;
+
+  mutex.lock();
+  std::thread unlockThread([&] {
+    /* sleep override */ std::this_thread::sleep_for(
+        std::chrono::milliseconds{100});
+    mutex.unlock();
+  });
+
+  fm.addTask([&] { std::lock_guard<TimedMutex> lg(mutex); });
+  fm.addTask([&] {
+    runInMainContext([&] {
+      auto locked = mutex.timed_lock(std::chrono::seconds{1});
+      EXPECT_TRUE(locked);
+      if (locked) {
+        mutex.unlock();
+      }
+    });
+  });
+
+  evb.loopOnce();
+  EXPECT_EQ(0, fm.hasTasks());
+
+  unlockThread.join();
+}
+
+TEST(TimedMutex, ThreadFiberDeadlockRace) {
+  folly::EventBase evb;
+  auto& fm = getFiberManager(evb);
+  TimedMutex mutex;
+
+  mutex.lock();
+
+  fm.addTask([&] {
+    auto locked = mutex.timed_lock(std::chrono::seconds{1});
+    EXPECT_TRUE(locked);
+    if (locked) {
+      mutex.unlock();
+    }
+  });
+  fm.addTask([&] {
+    mutex.unlock();
+    runInMainContext([&] {
+      auto locked = mutex.timed_lock(std::chrono::seconds{1});
+      EXPECT_TRUE(locked);
+      if (locked) {
+        mutex.unlock();
+      }
+    });
+  });
+
+  evb.loopOnce();
+  EXPECT_EQ(0, fm.hasTasks());
+}
+
 /**
  * Test that we can properly track fiber stack usage.
  *