Don't use pthread_spinlock_t in TimedRWMutex
[folly.git] / folly / fibers / TimedMutex-inl.h
index 96ee60646cffd1de988f2433e373277df83e0549..1c64a0ebb0646fcda77413ade6af8d2fca8a791f 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,80 +22,127 @@ namespace fibers {
 // TimedMutex implementation
 //
 
-template <typename BatonType>
-void TimedMutex<BatonType>::lock() {
-  pthread_spin_lock(&lock_);
+template <typename WaitFunc>
+TimedMutex::LockResult TimedMutex::lockHelper(WaitFunc&& waitFunc) {
+  std::unique_lock<folly::SpinLock> lock(lock_);
   if (!locked_) {
     locked_ = true;
-    pthread_spin_unlock(&lock_);
-    return;
+    return LockResult::SUCCESS;
+  }
+
+  const auto isOnFiber = onFiber();
+
+  if (!isOnFiber && notifiedFiber_ != nullptr) {
+    // lock() was called on a thread and while some other fiber was already
+    // notified, it hasn't be run yet. We steal the lock from that fiber then
+    // to avoid potential deadlock.
+    DCHECK(threadWaiters_.empty());
+    notifiedFiber_ = nullptr;
+    return LockResult::SUCCESS;
   }
 
   // Delay constructing the waiter until it is actually required.
   // This makes a huge difference, at least in the benchmarks,
   // when the mutex isn't locked.
   MutexWaiter waiter;
-  waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
-  waiter.baton.wait();
+  if (isOnFiber) {
+    fiberWaiters_.push_back(waiter);
+  } else {
+    threadWaiters_.push_back(waiter);
+  }
+
+  lock.unlock();
+
+  if (!waitFunc(waiter)) {
+    return LockResult::TIMEOUT;
+  }
+
+  if (isOnFiber) {
+    auto lockStolen = [&] {
+      std::lock_guard<folly::SpinLock> lg(lock_);
+
+      auto stolen = notifiedFiber_ != &waiter;
+      notifiedFiber_ = nullptr;
+      return stolen;
+    }();
+
+    if (lockStolen) {
+      return LockResult::STOLEN;
+    }
+  }
+
+  return LockResult::SUCCESS;
 }
 
-template <typename BatonType>
-template <typename Rep, typename Period>
-bool TimedMutex<BatonType>::timed_lock(
-    const std::chrono::duration<Rep, Period>& duration) {
-  pthread_spin_lock(&lock_);
-  if (!locked_) {
-    locked_ = true;
-    pthread_spin_unlock(&lock_);
+inline void TimedMutex::lock() {
+  auto result = lockHelper([](MutexWaiter& waiter) {
+    waiter.baton.wait();
     return true;
+  });
+
+  DCHECK(result != LockResult::TIMEOUT);
+  if (result == LockResult::SUCCESS) {
+    return;
   }
+  lock();
+}
 
-  MutexWaiter waiter;
-  waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
+template <typename Rep, typename Period>
+bool TimedMutex::timed_lock(
+    const std::chrono::duration<Rep, Period>& duration) {
+  auto result = lockHelper([&](MutexWaiter& waiter) {
+    if (!waiter.baton.timed_wait(duration)) {
+      // We timed out. Two cases:
+      // 1. We're still in the waiter list and we truly timed out
+      // 2. We're not in the waiter list anymore. This could happen if the baton
+      //    times out but the mutex is unlocked before we reach this code. In
+      //    this
+      //    case we'll pretend we got the lock on time.
+      std::lock_guard<folly::SpinLock> lg(lock_);
+      if (waiter.hook.is_linked()) {
+        waiter.hook.unlink();
+        return false;
+      }
+    }
+    return true;
+  });
 
-  if (!waiter.baton.timed_wait(duration)) {
-    // We timed out. Two cases:
-    // 1. We're still in the waiter list and we truly timed out
-    // 2. We're not in the waiter list anymore. This could happen if the baton
-    //    times out but the mutex is unlocked before we reach this code. In this
-    //    case we'll pretend we got the lock on time.
-    pthread_spin_lock(&lock_);
-    if (waiter.hook.is_linked()) {
-      waiters_.erase(waiters_.iterator_to(waiter));
-      pthread_spin_unlock(&lock_);
+  switch (result) {
+    case LockResult::SUCCESS:
+      return true;
+    case LockResult::TIMEOUT:
       return false;
-    }
-    pthread_spin_unlock(&lock_);
+    case LockResult::STOLEN:
+      // We don't respect the duration if lock was stolen
+      lock();
+      return true;
   }
-  return true;
+  assume_unreachable();
 }
 
-template <typename BatonType>
-bool TimedMutex<BatonType>::try_lock() {
-  pthread_spin_lock(&lock_);
+inline bool TimedMutex::try_lock() {
+  std::lock_guard<folly::SpinLock> lg(lock_);
   if (locked_) {
-    pthread_spin_unlock(&lock_);
     return false;
   }
   locked_ = true;
-  pthread_spin_unlock(&lock_);
   return true;
 }
 
-template <typename BatonType>
-void TimedMutex<BatonType>::unlock() {
-  pthread_spin_lock(&lock_);
-  if (waiters_.empty()) {
+inline void TimedMutex::unlock() {
+  std::lock_guard<folly::SpinLock> lg(lock_);
+  if (!threadWaiters_.empty()) {
+    auto& to_wake = threadWaiters_.front();
+    threadWaiters_.pop_front();
+    to_wake.baton.post();
+  } else if (!fiberWaiters_.empty()) {
+    auto& to_wake = fiberWaiters_.front();
+    fiberWaiters_.pop_front();
+    notifiedFiber_ = &to_wake;
+    to_wake.baton.post();
+  } else {
     locked_ = false;
-    pthread_spin_unlock(&lock_);
-    return;
   }
-  MutexWaiter& to_wake = waiters_.front();
-  waiters_.pop_front();
-  to_wake.baton.post();
-  pthread_spin_unlock(&lock_);
 }
 
 //
@@ -104,11 +151,11 @@ void TimedMutex<BatonType>::unlock() {
 
 template <typename BatonType>
 void TimedRWMutex<BatonType>::read_lock() {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   if (state_ == State::WRITE_LOCKED) {
     MutexWaiter waiter;
     read_waiters_.push_back(waiter);
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
     waiter.baton.wait();
     assert(state_ == State::READ_LOCKED);
     return;
@@ -119,18 +166,18 @@ void TimedRWMutex<BatonType>::read_lock() {
   assert(read_waiters_.empty());
   state_ = State::READ_LOCKED;
   readers_ += 1;
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
 }
 
 template <typename BatonType>
 template <typename Rep, typename Period>
 bool TimedRWMutex<BatonType>::timed_read_lock(
     const std::chrono::duration<Rep, Period>& duration) {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   if (state_ == State::WRITE_LOCKED) {
     MutexWaiter waiter;
     read_waiters_.push_back(waiter);
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
 
     if (!waiter.baton.timed_wait(duration)) {
       // We timed out. Two cases:
@@ -138,13 +185,13 @@ bool TimedRWMutex<BatonType>::timed_read_lock(
       // 2. We're not in the waiter list anymore. This could happen if the baton
       //    times out but the mutex is unlocked before we reach this code. In
       //    this case we'll pretend we got the lock on time.
-      pthread_spin_lock(&lock_);
+      lock_.lock();
       if (waiter.hook.is_linked()) {
         read_waiters_.erase(read_waiters_.iterator_to(waiter));
-        pthread_spin_unlock(&lock_);
+        lock_.unlock();
         return false;
       }
-      pthread_spin_unlock(&lock_);
+      lock_.unlock();
     }
     return true;
   }
@@ -154,13 +201,13 @@ bool TimedRWMutex<BatonType>::timed_read_lock(
   assert(read_waiters_.empty());
   state_ = State::READ_LOCKED;
   readers_ += 1;
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
   return true;
 }
 
 template <typename BatonType>
 bool TimedRWMutex<BatonType>::try_read_lock() {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   if (state_ != State::WRITE_LOCKED) {
     assert(
         (state_ == State::UNLOCKED && readers_ == 0) ||
@@ -168,25 +215,25 @@ bool TimedRWMutex<BatonType>::try_read_lock() {
     assert(read_waiters_.empty());
     state_ = State::READ_LOCKED;
     readers_ += 1;
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
     return true;
   }
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
   return false;
 }
 
 template <typename BatonType>
 void TimedRWMutex<BatonType>::write_lock() {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   if (state_ == State::UNLOCKED) {
     verify_unlocked_properties();
     state_ = State::WRITE_LOCKED;
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
     return;
   }
   MutexWaiter waiter;
   write_waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
   waiter.baton.wait();
 }
 
@@ -194,16 +241,16 @@ template <typename BatonType>
 template <typename Rep, typename Period>
 bool TimedRWMutex<BatonType>::timed_write_lock(
     const std::chrono::duration<Rep, Period>& duration) {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   if (state_ == State::UNLOCKED) {
     verify_unlocked_properties();
     state_ = State::WRITE_LOCKED;
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
     return true;
   }
   MutexWaiter waiter;
   write_waiters_.push_back(waiter);
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
 
   if (!waiter.baton.timed_wait(duration)) {
     // We timed out. Two cases:
@@ -211,13 +258,13 @@ bool TimedRWMutex<BatonType>::timed_write_lock(
     // 2. We're not in the waiter list anymore. This could happen if the baton
     //    times out but the mutex is unlocked before we reach this code. In
     //    this case we'll pretend we got the lock on time.
-    pthread_spin_lock(&lock_);
+    lock_.lock();
     if (waiter.hook.is_linked()) {
       write_waiters_.erase(write_waiters_.iterator_to(waiter));
-      pthread_spin_unlock(&lock_);
+      lock_.unlock();
       return false;
     }
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
   }
   assert(state_ == State::WRITE_LOCKED);
   return true;
@@ -225,20 +272,20 @@ bool TimedRWMutex<BatonType>::timed_write_lock(
 
 template <typename BatonType>
 bool TimedRWMutex<BatonType>::try_write_lock() {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   if (state_ == State::UNLOCKED) {
     verify_unlocked_properties();
     state_ = State::WRITE_LOCKED;
-    pthread_spin_unlock(&lock_);
+    lock_.unlock();
     return true;
   }
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
   return false;
 }
 
 template <typename BatonType>
 void TimedRWMutex<BatonType>::unlock() {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   assert(state_ != State::UNLOCKED);
   assert(
       (state_ == State::READ_LOCKED && readers_ > 0) ||
@@ -275,12 +322,12 @@ void TimedRWMutex<BatonType>::unlock() {
   } else {
     assert(state_ == State::READ_LOCKED);
   }
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
 }
 
 template <typename BatonType>
 void TimedRWMutex<BatonType>::downgrade() {
-  pthread_spin_lock(&lock_);
+  lock_.lock();
   assert(state_ == State::WRITE_LOCKED && readers_ == 0);
   state_ = State::READ_LOCKED;
   readers_ += 1;
@@ -294,7 +341,7 @@ void TimedRWMutex<BatonType>::downgrade() {
       to_wake.baton.post();
     }
   }
-  pthread_spin_unlock(&lock_);
+  lock_.unlock();
 }
 }
 }