Make folly::AsyncIO thread safe
[folly.git] / folly / experimental / io / AsyncIO.cpp
index 83a552867d8afe370ed0b2c4a1eec90b64051548..83558e83b1b6f6ef506ee9a7e363116b2998d57c 100644 (file)
@@ -19,6 +19,7 @@
 #include <sys/eventfd.h>
 #include <unistd.h>
 #include <cerrno>
+#include <stdexcept>
 #include <string>
 
 #include <boost/intrusive/parent_from_member.hpp>
@@ -104,6 +105,7 @@ void AsyncIOOp::init() {
 
 AsyncIO::AsyncIO(size_t capacity, PollMode pollMode)
   : ctx_(0),
+    ctxSet_(false),
     pending_(0),
     capacity_(capacity),
     pollFd_(-1) {
@@ -126,36 +128,54 @@ AsyncIO::~AsyncIO() {
   }
 }
 
+void AsyncIO::decrementPending() {
+  ssize_t p = pending_.fetch_add(-1, std::memory_order_acq_rel);
+  DCHECK_GE(p, 1);
+}
+
 void AsyncIO::initializeContext() {
-  if (!ctx_) {
-    int rc = io_queue_init(capacity_, &ctx_);
-    // returns negative errno
-    checkKernelError(rc, "AsyncIO: io_queue_init failed");
-    DCHECK(ctx_);
+  if (!ctxSet_.load(std::memory_order_acquire)) {
+    std::lock_guard<std::mutex> lock(initMutex_);
+    if (!ctxSet_.load(std::memory_order_relaxed)) {
+      int rc = io_queue_init(capacity_, &ctx_);
+      // returns negative errno
+      checkKernelError(rc, "AsyncIO: io_queue_init failed");
+      DCHECK(ctx_);
+      ctxSet_.store(true, std::memory_order_release);
+    }
   }
 }
 
 void AsyncIO::submit(Op* op) {
   CHECK_EQ(op->state(), Op::State::INITIALIZED);
-  CHECK_LT(pending_, capacity_) << "too many pending requests";
   initializeContext();  // on demand
+
+  // We can increment past capacity, but we'll clean up after ourselves.
+  ssize_t p = pending_.fetch_add(1, std::memory_order_acq_rel);
+  if (p >= capacity_) {
+    decrementPending();
+    throw std::range_error("AsyncIO: too many pending requests");
+  }
   iocb* cb = &op->iocb_;
   cb->data = nullptr;  // unused
   if (pollFd_ != -1) {
     io_set_eventfd(cb, pollFd_);
   }
   int rc = io_submit(ctx_, 1, &cb);
-  checkKernelError(rc, "AsyncIO: io_submit failed");
+  if (rc < 0) {
+    decrementPending();
+    throwSystemErrorExplicit(-rc, "AsyncIO: io_submit failed");
+  }
   DCHECK_EQ(rc, 1);
   op->start();
-  ++pending_;
 }
 
 Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) {
   CHECK(ctx_);
   CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object";
-  CHECK_LE(minRequests, pending_);
-  return doWait(minRequests, pending_);
+  ssize_t p = pending_.load(std::memory_order_acquire);
+  CHECK_LE(minRequests, p);
+  return doWait(minRequests, p);
 }
 
 Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
@@ -182,7 +202,7 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
 }
 
 Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
-  io_event events[pending_];
+  io_event events[maxRequests];
   int count;
   do {
     // Wait forever
@@ -190,7 +210,7 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
   } while (count == -EINTR);
   checkKernelError(count, "AsyncIO: io_getevents failed");
   DCHECK_GE(count, minRequests);  // the man page says so
-  DCHECK_LE(count, pending_);
+  DCHECK_LE(count, maxRequests);
 
   completed_.clear();
   if (count == 0) {
@@ -201,7 +221,7 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
     DCHECK(events[i].obj);
     Op* op = boost::intrusive::get_parent_from_member(
         events[i].obj, &AsyncIOOp::iocb_);
-    --pending_;
+    decrementPending();
     op->complete(events[i].res);
     completed_.push_back(op);
   }