Add unit test for timeout=0
[folly.git] / folly / io / async / NotificationQueue.h
index f817fbdc8e24bead9d9e8fc516131e4fa4f71c72..e1b1cb36fe29b3ea4dd4149d28fd1d517b13d0bf 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2014 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #pragma once
 
-#include <fcntl.h>
-#include <unistd.h>
+#include <sys/types.h>
 
-#include <folly/io/PortableSpinLock.h>
+#include <algorithm>
+#include <deque>
+#include <iterator>
+#include <memory>
+#include <stdexcept>
+#include <utility>
+
+#include <folly/Exception.h>
+#include <folly/FileUtil.h>
+#include <folly/Likely.h>
+#include <folly/ScopeGuard.h>
+#include <folly/SpinLock.h>
+#include <folly/io/async/DelayedDestruction.h>
 #include <folly/io/async/EventBase.h>
 #include <folly/io/async/EventHandler.h>
 #include <folly/io/async/Request.h>
-#include <folly/Likely.h>
-#include <folly/ScopeGuard.h>
+#include <folly/portability/Fcntl.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
 
 #include <glog/logging.h>
-#include <deque>
 
 #if __linux__ && !__ANDROID__
 #define FOLLY_HAVE_EVENTFD
@@ -54,13 +65,13 @@ namespace folly {
  * spinning trying to move a message off the queue and failing, and then
  * retrying.
  */
-template<typename MessageT>
+template <typename MessageT>
 class NotificationQueue {
  public:
   /**
    * A callback interface for consuming messages from the queue as they arrive.
    */
-  class Consumer : private EventHandler {
+  class Consumer : public DelayedDestruction, private EventHandler {
    public:
     enum : uint16_t { kDefaultMaxReadAtOnce = 10 };
 
@@ -69,13 +80,16 @@ class NotificationQueue {
         destroyedFlagPtr_(nullptr),
         maxReadAtOnce_(kDefaultMaxReadAtOnce) {}
 
-    virtual ~Consumer();
+    // create a consumer in-place, without the need to build new class
+    template <typename TCallback>
+    static std::unique_ptr<Consumer, DelayedDestruction::Destructor> make(
+        TCallback&& callback);
 
     /**
      * messageAvailable() will be invoked whenever a new
      * message is available from the pipe.
      */
-    virtual void messageAvailable(MessageT&& message) = 0;
+    virtual void messageAvailable(MessageT&& message) noexcept = 0;
 
     /**
      * Begin consuming messages from the specified queue.
@@ -111,6 +125,17 @@ class NotificationQueue {
      */
     void stopConsuming();
 
+    /**
+     * Consume messages off the queue until it is empty. No messages may be
+     * added to the queue while it is draining, so that the process is bounded.
+     * To that end, putMessage/tryPutMessage will throw an std::runtime_error,
+     * and tryPutMessageNoThrow will return false.
+     *
+     * @returns true if the queue was drained, false otherwise. In practice,
+     * this will only fail if someone else is already draining the queue.
+     */
+    bool consumeUntilDrained(size_t* numConsumed = nullptr) noexcept;
+
     /**
      * Get the NotificationQueue that this consumer is currently consuming
      * messages from.  Returns nullptr if the consumer is not currently
@@ -141,9 +166,26 @@ class NotificationQueue {
       return base_;
     }
 
-    virtual void handlerReady(uint16_t events) noexcept;
+    void handlerReady(uint16_t events) noexcept override;
+
+   protected:
+
+    void destroy() override;
+
+    ~Consumer() override {}
 
    private:
+    /**
+     * Consume messages off the the queue until
+     *   - the queue is empty (1), or
+     *   - until the consumer is destroyed, or
+     *   - until the consumer is uninstalled, or
+     *   - an exception is thrown in the course of dequeueing, or
+     *   - unless isDrain is true, until the maxReadAtOnce_ limit is hit
+     *
+     * (1) Well, maybe. See logic/comments around "wasEmpty" in implementation.
+     */
+    void consumeMessages(bool isDrain, size_t* numConsumed = nullptr) noexcept;
 
     void setActive(bool active, bool shouldLock = false) {
       if (!queue_) {
@@ -172,6 +214,24 @@ class NotificationQueue {
     bool active_{false};
   };
 
+  class SimpleConsumer {
+   public:
+    explicit SimpleConsumer(NotificationQueue& queue) : queue_(queue) {
+      ++queue_.numConsumers_;
+    }
+
+    ~SimpleConsumer() {
+      --queue_.numConsumers_;
+    }
+
+    int getFd() const {
+      return queue_.eventfd_ >= 0 ? queue_.eventfd_ : queue_.pipeFds_[0];
+    }
+
+   private:
+    NotificationQueue& queue_;
+  };
+
   enum class FdType {
     PIPE,
 #ifdef FOLLY_HAVE_EVENTFD
@@ -200,17 +260,17 @@ class NotificationQueue {
 #else
                              FdType fdType = FdType::PIPE)
 #endif
-    : eventfd_(-1),
-      pipeFds_{-1, -1},
-      advisoryMaxQueueSize_(maxSize),
-      pid_(getpid()),
-      queue_() {
+      : eventfd_(-1),
+        pipeFds_{-1, -1},
+        advisoryMaxQueueSize_(maxSize),
+        pid_(pid_t(getpid())),
+        queue_() {
 
-    RequestContext::getStaticContext();
+    RequestContext::saveContext();
 
 #ifdef FOLLY_HAVE_EVENTFD
     if (fdType == FdType::EVENTFD) {
-      eventfd_ = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK | EFD_SEMAPHORE);
+      eventfd_ = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK);
       if (eventfd_ == -1) {
         if (errno == ENOSYS || errno == EINVAL) {
           // eventfd not availalble
@@ -282,6 +342,8 @@ class NotificationQueue {
    * If the queue is full, a std::overflow_error will be thrown.  The
    * setMaxQueueSize() function controls the maximum queue size.
    *
+   * If the queue is currently draining, an std::runtime_error will be thrown.
+   *
    * This method may contend briefly on a spinlock if many threads are
    * concurrently accessing the queue, but for all intents and purposes it will
    * immediately place the message on the queue and return.
@@ -290,25 +352,23 @@ class NotificationQueue {
    * may throw any other exception thrown by the MessageT move/copy
    * constructor.
    */
-  void tryPutMessage(MessageT&& message) {
-    putMessageImpl(std::move(message), advisoryMaxQueueSize_);
-  }
-  void tryPutMessage(const MessageT& message) {
-    putMessageImpl(message, advisoryMaxQueueSize_);
+  template <typename MessageTT>
+  void tryPutMessage(MessageTT&& message) {
+    putMessageImpl(std::forward<MessageTT>(message), advisoryMaxQueueSize_);
   }
 
   /**
    * No-throw versions of the above.  Instead returns true on success, false on
    * failure.
    *
-   * Only std::overflow_error is prevented from being thrown (since this is the
-   * common exception case), user code must still catch std::bad_alloc errors.
+   * Only std::overflow_error (the common exception case) and std::runtime_error
+   * (which indicates that the queue is being drained) are prevented from being
+   * thrown. User code must still catch std::bad_alloc errors.
    */
-  bool tryPutMessageNoThrow(MessageT&& message) {
-    return putMessageImpl(std::move(message), advisoryMaxQueueSize_, false);
-  }
-  bool tryPutMessageNoThrow(const MessageT& message) {
-    return putMessageImpl(message, advisoryMaxQueueSize_, false);
+  template <typename MessageTT>
+  bool tryPutMessageNoThrow(MessageTT&& message) {
+    return putMessageImpl(
+        std::forward<MessageTT>(message), advisoryMaxQueueSize_, false);
   }
 
   /**
@@ -318,20 +378,20 @@ class NotificationQueue {
    * and always puts the message on the queue, even if the maximum queue size
    * would be exceeded.
    *
-   * putMessage() may throw std::bad_alloc if memory allocation fails, and may
-   * throw any other exception thrown by the MessageT move/copy constructor.
+   * putMessage() may throw
+   *   - std::bad_alloc if memory allocation fails, and may
+   *   - std::runtime_error if the queue is currently draining
+   *   - any other exception thrown by the MessageT move/copy constructor.
    */
-  void putMessage(MessageT&& message) {
-    putMessageImpl(std::move(message), 0);
-  }
-  void putMessage(const MessageT& message) {
-    putMessageImpl(message, 0);
+  template <typename MessageTT>
+  void putMessage(MessageTT&& message) {
+    putMessageImpl(std::forward<MessageTT>(message), 0);
   }
 
   /**
    * Put several messages on the queue.
    */
-  template<typename InputIteratorT>
+  template <typename InputIteratorT>
   void putMessages(InputIteratorT first, InputIteratorT last) {
     typedef typename std::iterator_traits<InputIteratorT>::iterator_category
       IterCategory;
@@ -348,35 +408,27 @@ class NotificationQueue {
    * unmodified.
    */
   bool tryConsume(MessageT& result) {
+    SCOPE_EXIT { syncSignalAndQueue(); };
+
     checkPid();
 
-    try {
+    folly::SpinLockGuard g(spinlock_);
 
-      folly::io::PortableSpinLockGuard g(spinlock_);
+    if (UNLIKELY(queue_.empty())) {
+      return false;
+    }
 
-      if (UNLIKELY(queue_.empty())) {
-        return false;
-      }
+    auto& data = queue_.front();
+    result = std::move(data.first);
+    RequestContext::setContext(std::move(data.second));
 
-      auto data = std::move(queue_.front());
-      result = data.first;
-      RequestContext::setContext(data.second);
-
-      queue_.pop_front();
-    } catch (...) {
-      // Handle an exception if the assignment operator happens to throw.
-      // We consumed an event but weren't able to pop the message off the
-      // queue.  Signal the event again since the message is still in the
-      // queue.
-      signalEvent(1);
-      throw;
-    }
+    queue_.pop_front();
 
     return true;
   }
 
-  int size() {
-    folly::io::PortableSpinLockGuard g(spinlock_);
+  size_t size() const {
+    folly::SpinLockGuard g(spinlock_);
     return queue_.size();
   }
 
@@ -393,9 +445,7 @@ class NotificationQueue {
    * check ensures that we catch the problem in the misbehaving child process
    * code, and crash before signalling the parent process.
    */
-  void checkPid() const {
-    CHECK_EQ(pid_, getpid());
-  }
+  void checkPid() const { CHECK_EQ(pid_, pid_t(getpid())); }
 
  private:
   // Forbidden copy constructor and assignment operator
@@ -403,7 +453,7 @@ class NotificationQueue {
   NotificationQueue& operator=(NotificationQueue const &) = delete;
 
   inline bool checkQueueSize(size_t maxSize, bool throws=true) const {
-    DCHECK(0 == spinlock_.trylock());
+    DCHECK(0 == spinlock_.try_lock());
     if (maxSize > 0 && queue_.size() >= maxSize) {
       if (throws) {
         throw std::overflow_error("unable to add message to NotificationQueue: "
@@ -414,125 +464,155 @@ class NotificationQueue {
     return true;
   }
 
-  inline void signalEvent(size_t numAdded = 1) const {
-    static const uint8_t kPipeMessage[] = {
-      1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
-    };
+  inline bool checkDraining(bool throws=true) {
+    if (UNLIKELY(draining_ && throws)) {
+      throw std::runtime_error("queue is draining, cannot add message");
+    }
+    return draining_;
+  }
+
+#ifdef __ANDROID__
+  // TODO 10860938 Remove after figuring out crash
+  mutable std::atomic<int> eventBytes_{0};
+  mutable std::atomic<int> maxEventBytes_{0};
+#endif
+
+  void ensureSignalLocked() const {
+    // semantics: empty fd == empty queue <=> !signal_
+    if (signal_) {
+      return;
+    }
 
     ssize_t bytes_written = 0;
-    ssize_t bytes_expected = 0;
-    if (eventfd_ >= 0) {
-      // eventfd(2) dictates that we must write a 64-bit integer
-      uint64_t numAdded64(numAdded);
-      bytes_expected = static_cast<ssize_t>(sizeof(numAdded64));
-      bytes_written = ::write(eventfd_, &numAdded64, sizeof(numAdded64));
-    } else {
-      // pipe semantics, add one message for each numAdded
-      bytes_expected = numAdded;
-      do {
-        size_t messageSize = std::min(numAdded, sizeof(kPipeMessage));
-        ssize_t rc = ::write(pipeFds_[1], kPipeMessage, messageSize);
-        if (rc < 0) {
-          // TODO: if the pipe is full, write will fail with EAGAIN.
-          // See task #1044651 for how this could be handled
-          break;
-        }
-        numAdded -= rc;
-        bytes_written += rc;
-      } while (numAdded > 0);
+    size_t bytes_expected = 0;
+
+    do {
+      if (eventfd_ >= 0) {
+        // eventfd(2) dictates that we must write a 64-bit integer
+        uint64_t signal = 1;
+        bytes_expected = sizeof(signal);
+        bytes_written = ::write(eventfd_, &signal, bytes_expected);
+      } else {
+        uint8_t signal = 1;
+        bytes_expected = sizeof(signal);
+        bytes_written = ::write(pipeFds_[1], &signal, bytes_expected);
+      }
+    } while (bytes_written == -1 && errno == EINTR);
+
+#ifdef __ANDROID__
+    if (bytes_written > 0) {
+      eventBytes_ += bytes_written;
+      maxEventBytes_ = std::max((int)maxEventBytes_, (int)eventBytes_);
     }
-    if (bytes_written != bytes_expected) {
+#endif
+
+    if (bytes_written == ssize_t(bytes_expected)) {
+      signal_ = true;
+    } else {
+#ifdef __ANDROID__
+      LOG(ERROR) << "NotificationQueue Write Error=" << errno
+                 << " bytesInPipe=" << eventBytes_
+                 << " maxInPipe=" << maxEventBytes_ << " queue=" << size();
+#endif
       folly::throwSystemError("failed to signal NotificationQueue after "
                               "write", errno);
     }
   }
 
-  bool tryConsumeEvent() {
-    uint64_t value = 0;
-    ssize_t rc = -1;
-    if (eventfd_ >= 0) {
-      rc = ::read(eventfd_, &value, sizeof(value));
+  void drainSignalsLocked() {
+    ssize_t bytes_read = 0;
+    if (eventfd_ > 0) {
+      uint64_t message;
+      bytes_read = readNoInt(eventfd_, &message, sizeof(message));
+      CHECK(bytes_read != -1 || errno == EAGAIN);
     } else {
-      uint8_t value8;
-      rc = ::read(pipeFds_[0], &value8, sizeof(value8));
-      value = value8;
+      // There should only be one byte in the pipe. To avoid potential leaks we still drain.
+      uint8_t message[32];
+      ssize_t result;
+      while ((result = readNoInt(pipeFds_[0], &message, sizeof(message))) != -1) {
+        bytes_read += result;
+      }
+      CHECK(result == -1 && errno == EAGAIN);
+      LOG_IF(ERROR, bytes_read > 1)
+        << "[NotificationQueue] Unexpected state while draining pipe: bytes_read="
+        << bytes_read << " bytes, expected <= 1";
     }
-    if (rc < 0) {
-      // EAGAIN should pretty much be the only error we can ever get.
-      // This means someone else already processed the only available message.
-      assert(errno == EAGAIN);
-      return false;
+    LOG_IF(ERROR, (signal_ && bytes_read == 0) || (!signal_ && bytes_read > 0))
+      << "[NotificationQueue] Unexpected state while draining signals: signal_="
+      << signal_ << " bytes_read=" << bytes_read;
+
+    signal_ = false;
+
+#ifdef __ANDROID__
+    if (bytes_read > 0) {
+      eventBytes_ -= bytes_read;
     }
-    assert(value == 1);
-    return true;
+#endif
   }
 
-  bool putMessageImpl(MessageT&& message, size_t maxSize, bool throws=true) {
-    checkPid();
-    bool signal = false;
-    {
-      folly::io::PortableSpinLockGuard g(spinlock_);
-      if (!checkQueueSize(maxSize, throws)) {
-        return false;
-      }
-      // We only need to signal an event if not all consumers are
-      // awake.
-      if (numActiveConsumers_ < numConsumers_) {
-        signal = true;
-      }
-      queue_.push_back(
-        std::make_pair(std::move(message),
-                       RequestContext::saveContext()));
-    }
-    if (signal) {
-      signalEvent();
+  void ensureSignal() const {
+    folly::SpinLockGuard g(spinlock_);
+    ensureSignalLocked();
+  }
+
+  void syncSignalAndQueue() {
+    folly::SpinLockGuard g(spinlock_);
+
+    if (queue_.empty()) {
+      drainSignalsLocked();
+    } else {
+      ensureSignalLocked();
     }
-    return true;
   }
 
-  bool putMessageImpl(
-    const MessageT& message, size_t maxSize, bool throws=true) {
+  template <typename MessageTT>
+  bool putMessageImpl(MessageTT&& message, size_t maxSize, bool throws = true) {
     checkPid();
     bool signal = false;
     {
-      folly::io::PortableSpinLockGuard g(spinlock_);
-      if (!checkQueueSize(maxSize, throws)) {
+      folly::SpinLockGuard g(spinlock_);
+      if (checkDraining(throws) || !checkQueueSize(maxSize, throws)) {
         return false;
       }
+      // We only need to signal an event if not all consumers are
+      // awake.
       if (numActiveConsumers_ < numConsumers_) {
         signal = true;
       }
-      queue_.push_back(std::make_pair(message, RequestContext::saveContext()));
-    }
-    if (signal) {
-      signalEvent();
+      queue_.emplace_back(
+          std::forward<MessageTT>(message), RequestContext::saveContext());
+      if (signal) {
+        ensureSignalLocked();
+      }
     }
     return true;
   }
 
-  template<typename InputIteratorT>
+  template <typename InputIteratorT>
   void putMessagesImpl(InputIteratorT first, InputIteratorT last,
                        std::input_iterator_tag) {
     checkPid();
     bool signal = false;
     size_t numAdded = 0;
     {
-      folly::io::PortableSpinLockGuard g(spinlock_);
+      folly::SpinLockGuard g(spinlock_);
+      checkDraining();
       while (first != last) {
-        queue_.push_back(std::make_pair(*first, RequestContext::saveContext()));
+        queue_.emplace_back(*first, RequestContext::saveContext());
         ++first;
         ++numAdded;
       }
       if (numActiveConsumers_ < numConsumers_) {
         signal = true;
       }
-    }
-    if (signal) {
-      signalEvent();
+      if (signal) {
+        ensureSignalLocked();
+      }
     }
   }
 
-  mutable folly::io::PortableSpinLock spinlock_;
+  mutable folly::SpinLock spinlock_;
+  mutable bool signal_{false};
   int eventfd_;
   int pipeFds_[2]; // to fallback to on older/non-linux systems
   uint32_t advisoryMaxQueueSize_;
@@ -540,10 +620,11 @@ class NotificationQueue {
   std::deque<std::pair<MessageT, std::shared_ptr<RequestContext>>> queue_;
   int numConsumers_{0};
   std::atomic<int> numActiveConsumers_{0};
+  bool draining_{false};
 };
 
-template<typename MessageT>
-NotificationQueue<MessageT>::Consumer::~Consumer() {
+template <typename MessageT>
+void NotificationQueue<MessageT>::Consumer::destroy() {
   // If we are in the middle of a call to handlerReady(), destroyedFlagPtr_
   // will be non-nullptr.  Mark the value that it points to, so that
   // handlerReady() will know the callback is destroyed, and that it cannot
@@ -551,27 +632,34 @@ NotificationQueue<MessageT>::Consumer::~Consumer() {
   if (destroyedFlagPtr_) {
     *destroyedFlagPtr_ = true;
   }
+  stopConsuming();
+  DelayedDestruction::destroy();
 }
 
-template<typename MessageT>
-void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
+template <typename MessageT>
+void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t /*events*/)
     noexcept {
+  consumeMessages(false);
+}
+
+template <typename MessageT>
+void NotificationQueue<MessageT>::Consumer::consumeMessages(
+    bool isDrain, size_t* numConsumed) noexcept {
+  DestructorGuard dg(this);
   uint32_t numProcessed = 0;
-  bool firstRun = true;
   setActive(true);
+  SCOPE_EXIT {
+    if (queue_) {
+      queue_->syncSignalAndQueue();
+    }
+  };
   SCOPE_EXIT { setActive(false, /* shouldLock = */ true); };
-  while (true) {
-    // Try to decrement the eventfd.
-    //
-    // The eventfd is only used to wake up the consumer - there may or
-    // may not actually be an event available (another consumer may
-    // have read it).  We don't really care, we only care about
-    // emptying the queue.
-    if (firstRun) {
-      queue_->tryConsumeEvent();
-      firstRun = false;
+  SCOPE_EXIT {
+    if (numConsumed != nullptr) {
+      *numConsumed = numProcessed;
     }
-
+  };
+  while (true) {
     // Now pop the message off of the queue.
     //
     // We have to manually acquire and release the spinlock here, rather than
@@ -595,8 +683,7 @@ void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
       auto& data = queue_->queue_.front();
 
       MessageT msg(std::move(data.first));
-      auto old_ctx =
-        RequestContext::setContext(data.second);
+      RequestContextScopeGuard rctx(std::move(data.second));
       queue_->queue_.pop_front();
 
       // Check to see if the queue is empty now.
@@ -616,14 +703,12 @@ void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
       CHECK(destroyedFlagPtr_ == nullptr);
       destroyedFlagPtr_ = &callbackDestroyed;
       messageAvailable(std::move(msg));
-
-      RequestContext::setContext(old_ctx);
+      destroyedFlagPtr_ = nullptr;
 
       // If the callback was destroyed before it returned, we are done
       if (callbackDestroyed) {
         return;
       }
-      destroyedFlagPtr_ = nullptr;
 
       // If the callback is no longer installed, we are done.
       if (queue_ == nullptr) {
@@ -632,8 +717,8 @@ void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
 
       // If we have hit maxReadAtOnce_, we are done.
       ++numProcessed;
-      if (maxReadAtOnce_ > 0 && numProcessed >= maxReadAtOnce_) {
-        queue_->signalEvent(1);
+      if (!isDrain && maxReadAtOnce_ > 0 &&
+          numProcessed >= maxReadAtOnce_) {
         return;
       }
 
@@ -646,7 +731,7 @@ void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
       if (wasEmpty) {
         return;
       }
-    } catch (const std::exception& ex) {
+    } catch (const std::exception&) {
       // This catch block is really just to handle the case where the MessageT
       // constructor throws.  The messageAvailable() callback itself is
       // declared as noexcept and should never throw.
@@ -660,10 +745,6 @@ void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
       if (locked) {
         // Unlock the spinlock.
         queue_->spinlock_.unlock();
-
-        // Push a notification back on the eventfd since we didn't actually
-        // read the message off of the queue.
-        queue_->signalEvent(1);
       }
 
       return;
@@ -671,11 +752,11 @@ void NotificationQueue<MessageT>::Consumer::handlerReady(uint16_t events)
   }
 }
 
-template<typename MessageT>
+template <typename MessageT>
 void NotificationQueue<MessageT>::Consumer::init(
     EventBase* eventBase,
     NotificationQueue* queue) {
-  assert(eventBase->isInEventBaseThread());
+  eventBase->dcheckIsInEventBaseThread();
   assert(queue_ == nullptr);
   assert(!isHandlerRegistered());
   queue->checkPid();
@@ -685,10 +766,10 @@ void NotificationQueue<MessageT>::Consumer::init(
   queue_ = queue;
 
   {
-    folly::io::PortableSpinLockGuard g(queue_->spinlock_);
+    folly::SpinLockGuard g(queue_->spinlock_);
     queue_->numConsumers_++;
   }
-  queue_->signalEvent();
+  queue_->ensureSignal();
 
   if (queue_->eventfd_ >= 0) {
     initHandler(eventBase, queue_->eventfd_);
@@ -697,7 +778,7 @@ void NotificationQueue<MessageT>::Consumer::init(
   }
 }
 
-template<typename MessageT>
+template <typename MessageT>
 void NotificationQueue<MessageT>::Consumer::stopConsuming() {
   if (queue_ == nullptr) {
     assert(!isHandlerRegistered());
@@ -705,7 +786,7 @@ void NotificationQueue<MessageT>::Consumer::stopConsuming() {
   }
 
   {
-    folly::io::PortableSpinLockGuard g(queue_->spinlock_);
+    folly::SpinLockGuard g(queue_->spinlock_);
     queue_->numConsumers_--;
     setActive(false);
   }
@@ -716,4 +797,68 @@ void NotificationQueue<MessageT>::Consumer::stopConsuming() {
   queue_ = nullptr;
 }
 
-} // folly
+template <typename MessageT>
+bool NotificationQueue<MessageT>::Consumer::consumeUntilDrained(
+    size_t* numConsumed) noexcept {
+  DestructorGuard dg(this);
+  {
+    folly::SpinLockGuard g(queue_->spinlock_);
+    if (queue_->draining_) {
+      return false;
+    }
+    queue_->draining_ = true;
+  }
+  consumeMessages(true, numConsumed);
+  {
+    folly::SpinLockGuard g(queue_->spinlock_);
+    queue_->draining_ = false;
+  }
+  return true;
+}
+
+/**
+ * Creates a NotificationQueue::Consumer wrapping a function object
+ * Modeled after AsyncTimeout::make
+ *
+ */
+
+namespace detail {
+
+template <typename MessageT, typename TCallback>
+struct notification_queue_consumer_wrapper
+    : public NotificationQueue<MessageT>::Consumer {
+
+  template <typename UCallback>
+  explicit notification_queue_consumer_wrapper(UCallback&& callback)
+      : callback_(std::forward<UCallback>(callback)) {}
+
+  // we are being stricter here and requiring noexcept for callback
+  void messageAvailable(MessageT&& message) noexcept override {
+    static_assert(
+      noexcept(std::declval<TCallback>()(std::forward<MessageT>(message))),
+      "callback must be declared noexcept, e.g.: `[]() noexcept {}`"
+    );
+
+    callback_(std::forward<MessageT>(message));
+  }
+
+ private:
+  TCallback callback_;
+};
+
+} // namespace detail
+
+template <typename MessageT>
+template <typename TCallback>
+std::unique_ptr<typename NotificationQueue<MessageT>::Consumer,
+                DelayedDestruction::Destructor>
+NotificationQueue<MessageT>::Consumer::make(TCallback&& callback) {
+  return std::unique_ptr<NotificationQueue<MessageT>::Consumer,
+                         DelayedDestruction::Destructor>(
+      new detail::notification_queue_consumer_wrapper<
+          MessageT,
+          typename std::decay<TCallback>::type>(
+          std::forward<TCallback>(callback)));
+}
+
+} // namespace folly