Make folly::AsyncIO thread safe
authorTudor Bosman <tudorb@fb.com>
Wed, 8 May 2013 21:40:10 +0000 (14:40 -0700)
committerSara Golemon <sgolemon@fb.com>
Mon, 20 May 2013 18:01:27 +0000 (11:01 -0700)
Summary:
You can now submit to the same AsyncIO object from different threads, but you
must still reap from only one thread at a time.

Test Plan: async_io_test, added MT test

Reviewed By: philipp@fb.com

FB internal diff: D804914

folly/experimental/io/AsyncIO.cpp
folly/experimental/io/AsyncIO.h
folly/experimental/io/test/AsyncIOTest.cpp

index 83a552867d8afe370ed0b2c4a1eec90b64051548..83558e83b1b6f6ef506ee9a7e363116b2998d57c 100644 (file)
@@ -19,6 +19,7 @@
 #include <sys/eventfd.h>
 #include <unistd.h>
 #include <cerrno>
+#include <stdexcept>
 #include <string>
 
 #include <boost/intrusive/parent_from_member.hpp>
@@ -104,6 +105,7 @@ void AsyncIOOp::init() {
 
 AsyncIO::AsyncIO(size_t capacity, PollMode pollMode)
   : ctx_(0),
+    ctxSet_(false),
     pending_(0),
     capacity_(capacity),
     pollFd_(-1) {
@@ -126,36 +128,54 @@ AsyncIO::~AsyncIO() {
   }
 }
 
+void AsyncIO::decrementPending() {
+  ssize_t p = pending_.fetch_add(-1, std::memory_order_acq_rel);
+  DCHECK_GE(p, 1);
+}
+
 void AsyncIO::initializeContext() {
-  if (!ctx_) {
-    int rc = io_queue_init(capacity_, &ctx_);
-    // returns negative errno
-    checkKernelError(rc, "AsyncIO: io_queue_init failed");
-    DCHECK(ctx_);
+  if (!ctxSet_.load(std::memory_order_acquire)) {
+    std::lock_guard<std::mutex> lock(initMutex_);
+    if (!ctxSet_.load(std::memory_order_relaxed)) {
+      int rc = io_queue_init(capacity_, &ctx_);
+      // returns negative errno
+      checkKernelError(rc, "AsyncIO: io_queue_init failed");
+      DCHECK(ctx_);
+      ctxSet_.store(true, std::memory_order_release);
+    }
   }
 }
 
 void AsyncIO::submit(Op* op) {
   CHECK_EQ(op->state(), Op::State::INITIALIZED);
-  CHECK_LT(pending_, capacity_) << "too many pending requests";
   initializeContext();  // on demand
+
+  // We can increment past capacity, but we'll clean up after ourselves.
+  ssize_t p = pending_.fetch_add(1, std::memory_order_acq_rel);
+  if (p >= capacity_) {
+    decrementPending();
+    throw std::range_error("AsyncIO: too many pending requests");
+  }
   iocb* cb = &op->iocb_;
   cb->data = nullptr;  // unused
   if (pollFd_ != -1) {
     io_set_eventfd(cb, pollFd_);
   }
   int rc = io_submit(ctx_, 1, &cb);
-  checkKernelError(rc, "AsyncIO: io_submit failed");
+  if (rc < 0) {
+    decrementPending();
+    throwSystemErrorExplicit(-rc, "AsyncIO: io_submit failed");
+  }
   DCHECK_EQ(rc, 1);
   op->start();
-  ++pending_;
 }
 
 Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) {
   CHECK(ctx_);
   CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object";
-  CHECK_LE(minRequests, pending_);
-  return doWait(minRequests, pending_);
+  ssize_t p = pending_.load(std::memory_order_acquire);
+  CHECK_LE(minRequests, p);
+  return doWait(minRequests, p);
 }
 
 Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
@@ -182,7 +202,7 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
 }
 
 Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
-  io_event events[pending_];
+  io_event events[maxRequests];
   int count;
   do {
     // Wait forever
@@ -190,7 +210,7 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
   } while (count == -EINTR);
   checkKernelError(count, "AsyncIO: io_getevents failed");
   DCHECK_GE(count, minRequests);  // the man page says so
-  DCHECK_LE(count, pending_);
+  DCHECK_LE(count, maxRequests);
 
   completed_.clear();
   if (count == 0) {
@@ -201,7 +221,7 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
     DCHECK(events[i].obj);
     Op* op = boost::intrusive::get_parent_from_member(
         events[i].obj, &AsyncIOOp::iocb_);
-    --pending_;
+    decrementPending();
     op->complete(events[i].res);
     completed_.push_back(op);
   }
index 0421796995dce193aa28e4082a3b294c179d894c..83c37f167ae2838905762b70fa0f78371d96ac02 100644 (file)
 #include <sys/uio.h>
 #include <libaio.h>
 
+#include <atomic>
 #include <cstdint>
 #include <deque>
 #include <functional>
+#include <mutex>
 #include <ostream>
 #include <utility>
 #include <vector>
@@ -138,6 +140,12 @@ class AsyncIO : private boost::noncopyable {
    * any IOs on this AsyncIO have completed.  If you do this, you must use
    * pollCompleted() instead of wait() -- do not read from the pollFd()
    * file descriptor directly.
+   *
+   * You may use the same AsyncIO object from multiple threads, as long as
+   * there is only one concurrent caller of wait() / pollCompleted() (perhaps
+   * by always calling it from the same thread, or by providing appropriate
+   * mutual exclusion)  In this case, pending() returns a snapshot
+   * of the current number of pending requests.
    */
   explicit AsyncIO(size_t capacity, PollMode pollMode=NOT_POLLABLE);
   ~AsyncIO();
@@ -180,12 +188,17 @@ class AsyncIO : private boost::noncopyable {
   void submit(Op* op);
 
  private:
+  void decrementPending();
   void initializeContext();
+
   Range<Op**> doWait(size_t minRequests, size_t maxRequests);
 
   io_context_t ctx_;
-  size_t pending_;
-  const size_t capacity_;
+  std::atomic<bool> ctxSet_;
+  std::mutex initMutex_;
+
+  std::atomic<ssize_t> pending_;
+  const ssize_t capacity_;
   int pollFd_;
   std::vector<Op*> completed_;
 };
index cdd9d2837ef367992ab4f03a38aacfa039405e31..f4ee54b6b91b464d1ad70c8993b2c8a24bacd94a 100644 (file)
@@ -25,6 +25,7 @@
 #include <cstdio>
 #include <memory>
 #include <random>
+#include <thread>
 #include <vector>
 
 #include <glog/logging.h>
@@ -152,20 +153,39 @@ void testReadsSerially(const std::vector<TestSpec>& specs,
 }
 
 void testReadsParallel(const std::vector<TestSpec>& specs,
-                       AsyncIO::PollMode pollMode) {
+                       AsyncIO::PollMode pollMode,
+                       bool multithreaded) {
   AsyncIO aioReader(specs.size(), pollMode);
   std::unique_ptr<AsyncIO::Op[]> ops(new AsyncIO::Op[specs.size()]);
   std::vector<ManagedBuffer> bufs;
+  bufs.reserve(specs.size());
 
   int fd = ::open(tempFile.path().c_str(), O_DIRECT | O_RDONLY);
   PCHECK(fd != -1);
   SCOPE_EXIT {
     ::close(fd);
   };
+
+  std::vector<std::thread> threads;
+  if (multithreaded) {
+    threads.reserve(specs.size());
+  }
   for (int i = 0; i < specs.size(); i++) {
     bufs.push_back(allocateAligned(specs[i].size));
+  }
+  auto submit = [&] (int i) {
     ops[i].pread(fd, bufs[i].get(), specs[i].size, specs[i].start);
     aioReader.submit(&ops[i]);
+  };
+  for (int i = 0; i < specs.size(); i++) {
+    if (multithreaded) {
+      threads.emplace_back([&submit, i] { submit(i); });
+    } else {
+      submit(i);
+    }
+  }
+  for (auto& t : threads) {
+    t.join();
   }
   std::vector<bool> pending(specs.size(), true);
 
@@ -249,7 +269,8 @@ void testReadsQueued(const std::vector<TestSpec>& specs,
 void testReads(const std::vector<TestSpec>& specs,
                AsyncIO::PollMode pollMode) {
   testReadsSerially(specs, pollMode);
-  testReadsParallel(specs, pollMode);
+  testReadsParallel(specs, pollMode, false);
+  testReadsParallel(specs, pollMode, true);
   testReadsQueued(specs, pollMode);
 }