Fix TimedMutex deadlock when used both from fiber and main context
[folly.git] / folly / fibers / TimedMutex.h
index e21f246cca406358dc3f1dbab0aad9fec4655a6f..6dbe50ba0020ffdf88e1d9cd047013ea0145bed1 100644 (file)
@@ -17,6 +17,8 @@
 
 #include <pthread.h>
 
+#include <folly/IntrusiveList.h>
+#include <folly/SpinLock.h>
 #include <folly/fibers/GenericBaton.h>
 
 namespace folly {
@@ -27,15 +29,14 @@ namespace fibers {
  *
  * Like mutex but allows timed_lock in addition to lock and try_lock.
  **/
-template <typename BatonType>
 class TimedMutex {
  public:
-  TimedMutex() {
-    pthread_spin_init(&lock_, PTHREAD_PROCESS_PRIVATE);
-  }
+  TimedMutex() {}
 
   ~TimedMutex() {
-    pthread_spin_destroy(&lock_);
+    DCHECK(threadWaiters_.empty());
+    DCHECK(fiberWaiters_.empty());
+    DCHECK(notifiedFiber_ == nullptr);
   }
 
   TimedMutex(const TimedMutex& rhs) = delete;
@@ -59,28 +60,25 @@ class TimedMutex {
   void unlock();
 
  private:
-  typedef boost::intrusive::list_member_hook<> MutexWaiterHookType;
+  enum class LockResult { SUCCESS, TIMEOUT, STOLEN };
+
+  template <typename WaitFunc>
+  LockResult lockHelper(WaitFunc&& waitFunc);
 
   // represents a waiter waiting for the lock. The waiter waits on the
   // baton until it is woken up by a post or timeout expires.
   struct MutexWaiter {
-    BatonType baton;
-    MutexWaiterHookType hook;
+    Baton baton;
+    folly::IntrusiveListHook hook;
   };
 
-  typedef boost::intrusive::
-      member_hook<MutexWaiter, MutexWaiterHookType, &MutexWaiter::hook>
-          MutexWaiterHook;
-
-  typedef boost::intrusive::list<
-      MutexWaiter,
-      MutexWaiterHook,
-      boost::intrusive::constant_time_size<true>>
-      MutexWaiterList;
+  using MutexWaiterList = folly::IntrusiveList<MutexWaiter, &MutexWaiter::hook>;
 
-  pthread_spinlock_t lock_; //< lock to protect waiter list
+  folly::SpinLock lock_; //< lock to protect waiter list
   bool locked_ = false; //< is this locked by some thread?
-  MutexWaiterList waiters_; //< list of waiters
+  MutexWaiterList threadWaiters_; //< list of waiters
+  MutexWaiterList fiberWaiters_; //< list of waiters
+  MutexWaiter* notifiedFiber_{nullptr}; //< Fiber waiter which has been notified
 };
 
 /**