add tryReadUntil and make fixes along the way
authorJames Sedgwick <jsedgwick@fb.com>
Tue, 8 Nov 2016 07:13:30 +0000 (23:13 -0800)
committerFacebook Github Bot <facebook-github-bot-bot@fb.com>
Tue, 8 Nov 2016 07:23:29 +0000 (23:23 -0800)
Summary:
this diff adds tryReadUntil, which is a mirror of tryWriteUntil in both function and implementation.
Two bugs were exposed in the process of implementing and testing tryWriteUntil; they are fixed as well and are as follows:
  1. tryObtainPromisedPopTicket didn't assign to the passed ticket return reference in the failure case
  2. TurnSequencer::tryWaitForTurn() didn't distinguish between past turns and timeouts in the failure case; they need to be
     differentiated because SingleElementQueue::tryWaitFor{De/En}queue() should only fail in the timeout case, not if the turn has passed.

The two added unit tests are admittedly clumsy, but making the obvious simplifications to them keeps them from triggering the premature timeout race caused by bug 2 above, so I kept them as is.

Reviewed By: magedm

Differential Revision: D4050515

fbshipit-source-id: b0a3dd894d502c44be62d362ea347a1837df4c2f

folly/MPMCQueue.h
folly/detail/TurnSequencer.h
folly/experimental/LockFreeRingBuffer.h
folly/stop_watch.h
folly/test/MPMCQueueTest.cpp

index 02a2bf12ac268be2ada6916a9bd391a5d31e63b5..43e9eb210a7c4e9935306f45c03090beb1e0bab4 100644 (file)
@@ -266,10 +266,8 @@ class MPMCQueue<T,Atom,true> :
       if (!trySeqlockReadSection(state, slots, cap, stride)) {
         continue;
       }
-      offset = getOffset(state);
-      if (ticket < offset) {
+      if (maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride)) {
         // There was an expansion after this ticket was issued.
-        updateFromClosed(state, ticket, offset, slots, cap, stride);
         break;
       }
       if (slots[this->idx((ticket-offset), cap, stride)]
@@ -305,12 +303,9 @@ class MPMCQueue<T,Atom,true> :
     uint64_t state;
     uint64_t offset;
     while (!trySeqlockReadSection(state, slots, cap, stride));
-    offset = getOffset(state);
-    if (ticket < offset) {
-      // There was an expansion after the corresponding push ticket
-      // was issued.
-      updateFromClosed(state, ticket, offset, slots, cap, stride);
-    }
+    // If there was an expansion after the corresponding push ticket
+    // was issued, adjust accordingly
+    maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride);
     this->dequeueWithTicketBase(ticket-offset, slots, cap, stride, elem);
   }
 
@@ -351,11 +346,12 @@ class MPMCQueue<T,Atom,true> :
       if (!trySeqlockReadSection(state, slots, cap, stride)) {
         continue;
       }
-      uint64_t offset = getOffset(state);
-      if (ticket < offset) {
-        // There was an expansion with offset greater than this ticket
-        updateFromClosed(state, ticket, offset, slots, cap, stride);
-      }
+
+      // If there was an expansion with offset greater than this ticket,
+      // adjust accordingly
+      uint64_t offset;
+      maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride);
+
       if (slots[this->idx((ticket-offset), cap, stride)]
           .mayEnqueue(this->turn(ticket-offset, cap))) {
         // A slot is ready.
@@ -394,25 +390,31 @@ class MPMCQueue<T,Atom,true> :
       if (!trySeqlockReadSection(state, slots, cap, stride)) {
         continue;
       }
+
+      const auto oldCap = cap;
+      // If there was an expansion with offset greater than this ticket,
+      // adjust accordingly
+      uint64_t offset;
+      maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride);
+
       int64_t n = ticket - numPops;
       if (n >= static_cast<ssize_t>(this->capacity_)) {
+        ticket -= offset;
         return false;
       }
-      if ((n >= static_cast<ssize_t>(cap))) {
-        if (tryExpand(state, cap)) {
+
+      if (n >= static_cast<ssize_t>(oldCap)) {
+        if (tryExpand(state, oldCap)) {
           // This or another thread started an expansion. Start over
           // with a new state.
           continue;
         } else {
           // Can't expand.
+          ticket -= offset;
           return false;
         }
       }
-      uint64_t offset = getOffset(state);
-      if (ticket < offset) {
-        // There was an expansion with offset greater than this ticket
-        updateFromClosed(state, ticket, offset, slots, cap, stride);
-      }
+
       if (this->pushTicket_.compare_exchange_strong(ticket, ticket + 1)) {
         // Adjust ticket
         ticket -= offset;
@@ -430,12 +432,12 @@ class MPMCQueue<T,Atom,true> :
       if (!trySeqlockReadSection(state, slots, cap, stride)) {
         continue;
       }
-      uint64_t offset = getOffset(state);
-      if (ticket < offset) {
-        // There was an expansion after the corresponding push ticket
-        // was issued.
-        updateFromClosed(state, ticket, offset, slots, cap, stride);
-      }
+
+      // If there was an expansion after the corresponding push ticket
+      // was issued, adjust accordingly
+      uint64_t offset;
+      maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride);
+
       if (slots[this->idx((ticket-offset), cap, stride)]
           .mayDequeue(this->turn(ticket-offset, cap))) {
         if (this->popTicket_.compare_exchange_strong(ticket, ticket + 1)) {
@@ -459,18 +461,17 @@ class MPMCQueue<T,Atom,true> :
       if (!trySeqlockReadSection(state, slots, cap, stride)) {
         continue;
       }
+
+      uint64_t offset;
+      // If there was an expansion after the corresponding push
+      // ticket was issued, adjust accordingly
+      maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride);
+
       if (ticket >= numPushes) {
+        ticket -= offset;
         return false;
       }
       if (this->popTicket_.compare_exchange_strong(ticket, ticket + 1)) {
-        // Adjust ticket
-        uint64_t offset = getOffset(state);
-        if (ticket < offset) {
-          // There was an expansion after the corresponding push
-          // ticket was issued.
-          updateFromClosed(state, ticket, offset, slots, cap, stride);
-        }
-        // Adjust ticket
         ticket -= offset;
         return true;
       }
@@ -485,12 +486,13 @@ class MPMCQueue<T,Atom,true> :
     int stride;
     uint64_t state;
     uint64_t offset;
+
     while (!trySeqlockReadSection(state, slots, cap, stride)) {}
-    offset = getOffset(state);
-    if (ticket < offset) {
-      // There was an expansion after this ticket was issued.
-      updateFromClosed(state, ticket, offset, slots, cap, stride);
-    }
+
+    // If there was an expansion after this ticket was issued, adjust
+    // accordingly
+    maybeUpdateFromClosed(state, ticket, offset, slots, cap, stride);
+
     this->enqueueWithTicketBase(ticket-offset, slots, cap, stride,
                                 std::forward<Args>(args)...);
   }
@@ -570,23 +572,32 @@ class MPMCQueue<T,Atom,true> :
     return (state == this->dstate_.load(std::memory_order_relaxed));
   }
 
-  /// Update local variables of a lagging operation using the
-  /// most recent closed array with offset <= ticket
-  void updateFromClosed(
-    const uint64_t state, const uint64_t ticket,
-    uint64_t& offset, Slot*& slots, size_t& cap, int& stride
-  ) noexcept {
+  /// If there was an expansion after ticket was issued, update local variables
+  /// of the lagging operation using the most recent closed array with
+  /// offset <= ticket and return true. Otherwise, return false;
+  bool maybeUpdateFromClosed(
+      const uint64_t state,
+      const uint64_t ticket,
+      uint64_t& offset,
+      Slot*& slots,
+      size_t& cap,
+      int& stride) noexcept {
+    offset = getOffset(state);
+    if (ticket >= offset) {
+      return false;
+    }
     for (int i = getNumClosed(state) - 1; i >= 0; --i) {
       offset = closed_[i].offset_;
       if (offset <= ticket) {
         slots = closed_[i].slots_;
         cap = closed_[i].capacity_;
         stride = closed_[i].stride_;
-        return;;
+        return true;
       }
     }
     // A closed array with offset <= ticket should have been found
     assert(false);
+    return false;
   }
 };
 
@@ -903,6 +914,25 @@ class MPMCQueueBase<Derived<T, Atom, Dynamic>> : boost::noncopyable {
     }
   }
 
+  template <class Clock, typename... Args>
+  bool tryReadUntil(
+      const std::chrono::time_point<Clock>& when,
+      T& elem) noexcept {
+    uint64_t ticket;
+    Slot* slots;
+    size_t cap;
+    int stride;
+    if (tryObtainPromisedPopTicketUntil(ticket, slots, cap, stride, when)) {
+      // we have pre-validated that the ticket won't block, or rather that
+      // it won't block longer than it takes another thread to enqueue an
+      // element on the slot it identifies.
+      dequeueWithTicketBase(ticket, slots, cap, stride, elem);
+      return true;
+    } else {
+      return false;
+    }
+  }
+
   /// If the queue is not empty, dequeues and returns true, otherwise
   /// returns false.  If the matching write is still in progress then this
   /// method may block waiting for it.  If you don't rely on being able
@@ -1112,10 +1142,10 @@ class MPMCQueueBase<Derived<T, Atom, Dynamic>> : boost::noncopyable {
     cap = capacity_;
     stride = stride_;
     while (true) {
-      auto numPops = popTicket_.load(std::memory_order_acquire); // B
-      // n will be negative if pops are pending
-      int64_t n = numPushes - numPops;
       ticket = numPushes;
+      const auto numPops = popTicket_.load(std::memory_order_acquire); // B
+      // n will be negative if pops are pending
+      const int64_t n = numPushes - numPops;
       if (n >= static_cast<ssize_t>(capacity_)) {
         // Full, linearize at B.  We don't need to recheck the read we
         // performed at A, because if numPushes was stale at B then the
@@ -1154,6 +1184,37 @@ class MPMCQueueBase<Derived<T, Atom, Dynamic>> : boost::noncopyable {
     }
   }
 
+  /// Tries until when to obtain a pop ticket for which
+  /// SingleElementQueue::dequeue won't block.  Returns true on success, false
+  /// on failure.
+  /// ticket is filled on success AND failure.
+  template <class Clock>
+  bool tryObtainPromisedPopTicketUntil(
+      uint64_t& ticket,
+      Slot*& slots,
+      size_t& cap,
+      int& stride,
+      const std::chrono::time_point<Clock>& when) noexcept {
+    bool deadlineReached = false;
+    while (!deadlineReached) {
+      if (static_cast<Derived<T, Atom, Dynamic>*>(this)
+              ->tryObtainPromisedPopTicket(ticket, slots, cap, stride)) {
+        return true;
+      }
+      // ticket is a blocking ticket until the preceding ticket has been
+      // processed: wait until this ticket's turn arrives. We have not reserved
+      // this ticket so we will have to re-attempt to get a non-blocking ticket
+      // if we wake up before we time-out.
+      deadlineReached =
+          !slots[idx(ticket, cap, stride)].tryWaitForDequeueTurnUntil(
+              turn(ticket, cap),
+              pushSpinCutoff_,
+              (ticket % kAdaptationFreq) == 0,
+              when);
+    }
+    return false;
+  }
+
   /// Similar to tryObtainReadyPopTicket, but returns a pop ticket whose
   /// corresponding push ticket has already been handed out, rather than
   /// returning one whose corresponding push ticket has already been
@@ -1168,8 +1229,12 @@ class MPMCQueueBase<Derived<T, Atom, Dynamic>> : boost::noncopyable {
     uint64_t& ticket, Slot*& slots, size_t& cap, int& stride
   ) noexcept {
     auto numPops = popTicket_.load(std::memory_order_acquire); // A
+    slots = slots_;
+    cap = capacity_;
+    stride = stride_;
     while (true) {
-      auto numPushes = pushTicket_.load(std::memory_order_acquire); // B
+      ticket = numPops;
+      const auto numPushes = pushTicket_.load(std::memory_order_acquire); // B
       if (numPops >= numPushes) {
         // Empty, or empty with pending pops.  Linearize at B.  We don't
         // need to recheck the read we performed at A, because if numPops
@@ -1177,10 +1242,6 @@ class MPMCQueueBase<Derived<T, Atom, Dynamic>> : boost::noncopyable {
         return false;
       }
       if (popTicket_.compare_exchange_strong(numPops, numPops + 1)) {
-        ticket = numPops;
-        slots = slots_;
-        cap = capacity_;
-        stride = stride_;
         return true;
       }
     }
@@ -1275,7 +1336,8 @@ struct SingleElementQueue {
       const bool updateSpinCutoff,
       const std::chrono::time_point<Clock>& when) noexcept {
     return sequencer_.tryWaitForTurn(
-        turn * 2, spinCutoff, updateSpinCutoff, &when);
+               turn * 2, spinCutoff, updateSpinCutoff, &when) !=
+        TurnSequencer<Atom>::TryWaitResult::TIMEDOUT;
   }
 
   bool mayEnqueue(const uint32_t turn) const noexcept {
@@ -1295,6 +1357,21 @@ struct SingleElementQueue {
                                           ImplByMove>::type());
   }
 
+  /// Waits until either:
+  /// 1: the enqueue turn preceding the given dequeue turn has arrived
+  /// 2: the given deadline has arrived
+  /// Case 1 returns true, case 2 returns false.
+  template <class Clock>
+  bool tryWaitForDequeueTurnUntil(
+      const uint32_t turn,
+      Atom<uint32_t>& spinCutoff,
+      const bool updateSpinCutoff,
+      const std::chrono::time_point<Clock>& when) noexcept {
+    return sequencer_.tryWaitForTurn(
+               turn * 2 + 1, spinCutoff, updateSpinCutoff, &when) !=
+        TurnSequencer<Atom>::TryWaitResult::TIMEDOUT;
+  }
+
   bool mayDequeue(const uint32_t turn) const noexcept {
     return sequencer_.isTurn(turn * 2 + 1);
   }
index 193b40b0d8139517e46f65cd7bf25b03403c2fee..b70f258bf1d9780714c7bbfd53ec340ef161c032 100644 (file)
 #pragma once
 
 #include <algorithm>
-#include <assert.h>
 #include <limits>
 
 #include <folly/detail/Futex.h>
 #include <folly/portability/Asm.h>
 #include <folly/portability/Unistd.h>
 
+#include <glog/logging.h>
+
 namespace folly {
 
 namespace detail {
@@ -79,14 +80,15 @@ struct TurnSequencer {
     return decodeCurrentSturn(state) == (turn << kTurnShift);
   }
 
+  enum class TryWaitResult { SUCCESS, PAST, TIMEDOUT };
+
   /// See tryWaitForTurn
   /// Requires that `turn` is not a turn in the past.
   void waitForTurn(const uint32_t turn,
                    Atom<uint32_t>& spinCutoff,
                    const bool updateSpinCutoff) noexcept {
-    bool success = tryWaitForTurn(turn, spinCutoff, updateSpinCutoff);
-    (void)success;
-    assert(success);
+    const auto ret = tryWaitForTurn(turn, spinCutoff, updateSpinCutoff);
+    DCHECK(ret == TryWaitResult::SUCCESS);
   }
 
   // Internally we always work with shifted turn values, which makes the
@@ -98,16 +100,18 @@ struct TurnSequencer {
   /// updateSpinCutoff is true then this will spin for up to kMaxSpins tries
   /// before blocking and will adjust spinCutoff based on the results,
   /// otherwise it will spin for at most spinCutoff spins.
-  /// Returns true if the wait succeeded, false if the turn is in the past
-  /// or the absTime time value is not nullptr and is reached before the turn
-  /// arrives
-  template <class Clock = std::chrono::steady_clock,
-            class Duration = typename Clock::duration>
-  bool tryWaitForTurn(const uint32_t turn,
-                      Atom<uint32_t>& spinCutoff,
-                      const bool updateSpinCutoff,
-                      const std::chrono::time_point<Clock, Duration>* absTime =
-                          nullptr) noexcept {
+  /// Returns SUCCESS if the wait succeeded, PAST if the turn is in the past
+  /// or TIMEDOUT if the absTime time value is not nullptr and is reached before
+  /// the turn arrives
+  template <
+      class Clock = std::chrono::steady_clock,
+      class Duration = typename Clock::duration>
+  TryWaitResult tryWaitForTurn(
+      const uint32_t turn,
+      Atom<uint32_t>& spinCutoff,
+      const bool updateSpinCutoff,
+      const std::chrono::time_point<Clock, Duration>* absTime =
+          nullptr) noexcept {
     uint32_t prevThresh = spinCutoff.load(std::memory_order_relaxed);
     const uint32_t effectiveSpinCutoff =
         updateSpinCutoff || prevThresh == 0 ? kMaxSpins : prevThresh;
@@ -124,7 +128,7 @@ struct TurnSequencer {
       // wrap-safe version of (current_sturn >= sturn)
       if(sturn - current_sturn >= std::numeric_limits<uint32_t>::max() / 2) {
         // turn is in the past
-        return false;
+        return TryWaitResult::PAST;
       }
 
       // the first effectSpinCutoff tries are spins, after that we will
@@ -152,7 +156,7 @@ struct TurnSequencer {
         auto futexResult =
             state_.futexWaitUntil(new_state, *absTime, futexChannel(turn));
         if (futexResult == FutexResult::TIMEDOUT) {
-          return false;
+          return TryWaitResult::TIMEDOUT;
         }
       } else {
         state_.futexWait(new_state, futexChannel(turn));
@@ -184,14 +188,14 @@ struct TurnSequencer {
       }
     }
 
-    return true;
+    return TryWaitResult::SUCCESS;
   }
 
   /// Unblocks a thread running waitForTurn(turn + 1)
   void completeTurn(const uint32_t turn) noexcept {
     uint32_t state = state_.load(std::memory_order_acquire);
     while (true) {
-      assert(state == encode(turn << kTurnShift, decodeMaxWaitersDelta(state)));
+      DCHECK(state == encode(turn << kTurnShift, decodeMaxWaitersDelta(state)));
       uint32_t max_waiter_delta = decodeMaxWaitersDelta(state);
       uint32_t new_state =
           encode((turn + 1) << kTurnShift,
index 1a2bc464c7c17bc2a16a481f11b787790794a7a0..d117926a7c91f3be99f107519e2c29918764ef66 100644 (file)
@@ -204,7 +204,8 @@ public:
   bool waitAndTryRead(T& dest, uint32_t turn) noexcept {
     uint32_t desired_turn = (turn + 1) * 2;
     Atom<uint32_t> cutoff(0);
-    if(!sequencer_.tryWaitForTurn(desired_turn, cutoff, false)) {
+    if (sequencer_.tryWaitForTurn(desired_turn, cutoff, false) !=
+        TurnSequencer<Atom>::TryWaitResult::SUCCESS) {
       return false;
     }
     memcpy(&dest, &data, sizeof(T));
index bc01935a8a0156829ca8eed9a7717fe37eb5f64d..e16071a38f70dca8b158cefb2004394abc3e9b6f 100644 (file)
@@ -271,6 +271,13 @@ struct custom_stop_watch {
     return true;
   }
 
+  /**
+   * Returns the current checkpoint
+   */
+  typename clock_type::time_point getCheckpoint() const {
+    return checkpoint_;
+  }
+
  private:
   typename clock_type::time_point checkpoint_;
 };
index cc9e8c123460ddc09109f0d2c2cd579eb2edfa39..b98138a8e215c4b8ebb1a4c9a679f91a6bb88573 100644 (file)
  * limitations under the License.
  */
 
-#include <folly/MPMCQueue.h>
 #include <folly/Format.h>
+#include <folly/MPMCQueue.h>
 #include <folly/Memory.h>
 #include <folly/portability/GTest.h>
 #include <folly/portability/SysResource.h>
 #include <folly/portability/SysTime.h>
 #include <folly/portability/Unistd.h>
+#include <folly/stop_watch.h>
 #include <folly/test/DeterministicSchedule.h>
 
 #include <boost/intrusive_ptr.hpp>
-#include <memory>
+#include <boost/thread/barrier.hpp>
 #include <functional>
+#include <memory>
 #include <thread>
 #include <utility>
 
@@ -1158,3 +1160,92 @@ TEST(MPMCQueue, explicit_zero_capacity_fail) {
   using DynamicMPMCQueueInt = MPMCQueue<int, std::atomic, true>;
   ASSERT_THROW(DynamicMPMCQueueInt cq(0), std::invalid_argument);
 }
+
+template <bool Dynamic>
+void testTryReadUntil() {
+  MPMCQueue<int, std::atomic, Dynamic> q{1};
+
+  const auto wait = std::chrono::milliseconds(100);
+  stop_watch<> watch;
+  bool rets[2];
+  int vals[2];
+  std::vector<std::thread> threads;
+  boost::barrier b{3};
+  for (int i = 0; i < 2; i++) {
+    threads.emplace_back([&, i] {
+      b.wait();
+      rets[i] = q.tryReadUntil(watch.getCheckpoint() + wait, vals[i]);
+    });
+  }
+
+  b.wait();
+  EXPECT_TRUE(q.write(42));
+
+  for (int i = 0; i < 2; i++) {
+    threads[i].join();
+  }
+
+  for (int i = 0; i < 2; i++) {
+    int other = (i + 1) % 2;
+    if (rets[i]) {
+      EXPECT_EQ(42, vals[i]);
+      EXPECT_FALSE(rets[other]);
+    }
+  }
+
+  EXPECT_TRUE(watch.elapsed(wait));
+}
+
+template <bool Dynamic>
+void testTryWriteUntil() {
+  MPMCQueue<int, std::atomic, Dynamic> q{1};
+  EXPECT_TRUE(q.write(42));
+
+  const auto wait = std::chrono::milliseconds(100);
+  stop_watch<> watch;
+  bool rets[2];
+  std::vector<std::thread> threads;
+  boost::barrier b{3};
+  for (int i = 0; i < 2; i++) {
+    threads.emplace_back([&, i] {
+      b.wait();
+      rets[i] = q.tryWriteUntil(watch.getCheckpoint() + wait, i);
+    });
+  }
+
+  b.wait();
+  int x;
+  EXPECT_TRUE(q.read(x));
+  EXPECT_EQ(42, x);
+
+  for (int i = 0; i < 2; i++) {
+    threads[i].join();
+  }
+  EXPECT_TRUE(q.read(x));
+
+  for (int i = 0; i < 2; i++) {
+    int other = (i + 1) % 2;
+    if (rets[i]) {
+      EXPECT_EQ(i, x);
+      EXPECT_FALSE(rets[other]);
+    }
+  }
+
+  EXPECT_TRUE(watch.elapsed(wait));
+}
+
+TEST(MPMCQueue, try_read_until) {
+  testTryReadUntil<false>();
+}
+
+TEST(MPMCQueue, try_read_until_dynamic) {
+  testTryReadUntil<true>();
+}
+
+TEST(MPMCQueue, try_write_until) {
+  testTryWriteUntil<false>();
+}
+
+TEST(MPMCQueue, try_write_until_dynamic) {
+  testTryWriteUntil<true>();
+}