AsyncIO::cancel
[folly.git] / folly / experimental / io / AsyncIO.cpp
index 9128d8b3d88b876b4f8190762f5ed1c16322b98a..e5674bff9e1915fd2eaf01126fb10646e4cff786 100644 (file)
@@ -66,6 +66,11 @@ void AsyncIOOp::complete(ssize_t result) {
   }
 }
 
+void AsyncIOOp::cancel() {
+  DCHECK_EQ(state_, State::PENDING);
+  state_ = State::CANCELED;
+}
+
 ssize_t AsyncIOOp::result() const {
   CHECK_EQ(state_, State::COMPLETED);
   return result_;
@@ -104,13 +109,7 @@ void AsyncIOOp::init() {
   state_ = State::INITIALIZED;
 }
 
-AsyncIO::AsyncIO(size_t capacity, PollMode pollMode)
-  : ctx_(0),
-    ctxSet_(false),
-    pending_(0),
-    submitted_(0),
-    capacity_(capacity),
-    pollFd_(-1) {
+AsyncIO::AsyncIO(size_t capacity, PollMode pollMode) : capacity_(capacity) {
   CHECK_GT(capacity_, 0);
   completed_.reserve(capacity_);
   if (pollMode == POLLABLE) {
@@ -194,7 +193,15 @@ Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) {
   CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object";
   auto p = pending_.load(std::memory_order_acquire);
   CHECK_LE(minRequests, p);
-  return doWait(minRequests, p);
+  doWait(WaitType::COMPLETE, minRequests, p, &completed_);
+  return Range<Op**>(completed_.data(), completed_.size());
+}
+
+size_t AsyncIO::cancel() {
+  CHECK(ctx_);
+  auto p = pending_.load(std::memory_order_acquire);
+  doWait(WaitType::CANCEL, p, p, nullptr);
+  return p;
 }
 
 Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
@@ -217,12 +224,19 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
   DCHECK_LE(numEvents, pending_);
 
   // Don't reap more than numEvents, as we've just reset the counter to 0.
-  return doWait(numEvents, numEvents);
+  doWait(WaitType::COMPLETE, numEvents, numEvents, &completed_);
+  return Range<Op**>(completed_.data(), completed_.size());
 }
 
-Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
+void AsyncIO::doWait(
+    WaitType type,
+    size_t minRequests,
+    size_t maxRequests,
+    std::vector<Op*>* result) {
   io_event events[maxRequests];
 
+  // Unfortunately, Linux AIO doesn't implement io_cancel, so even for
+  // WaitType::CANCEL we have to wait for IO completion.
   size_t count = 0;
   do {
     int ret;
@@ -237,27 +251,32 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
                          /* timeout */ nullptr);  // wait forever
     } while (ret == -EINTR);
     // Check as may not be able to recover without leaking events.
-    CHECK_GE(ret, 0)
-      << "AsyncIO: io_getevents failed with error " << errnoStr(-ret);
+    CHECK_GE(ret, 0) << "AsyncIO: io_getevents failed with error "
+                     << errnoStr(-ret);
     count += ret;
   } while (count < minRequests);
   DCHECK_LE(count, maxRequests);
 
-  completed_.clear();
-  if (count == 0) {
-    return folly::Range<Op**>();
+  if (result != nullptr) {
+    result->clear();
   }
-
   for (size_t i = 0; i < count; ++i) {
     DCHECK(events[i].obj);
     Op* op = boost::intrusive::get_parent_from_member(
         events[i].obj, &AsyncIOOp::iocb_);
     decrementPending();
-    op->complete(events[i].res);
-    completed_.push_back(op);
+    switch (type) {
+      case WaitType::COMPLETE:
+        op->complete(events[i].res);
+        break;
+      case WaitType::CANCEL:
+        op->cancel();
+        break;
+    }
+    if (result != nullptr) {
+      result->push_back(op);
+    }
   }
-
-  return folly::Range<Op**>(&completed_.front(), count);
 }
 
 AsyncIOQueue::AsyncIOQueue(AsyncIO* asyncIO)
@@ -308,6 +327,7 @@ const char* asyncIoOpStateToString(AsyncIOOp::State state) {
     X(AsyncIOOp::State::INITIALIZED);
     X(AsyncIOOp::State::PENDING);
     X(AsyncIOOp::State::COMPLETED);
+    X(AsyncIOOp::State::CANCELED);
   }
   return "<INVALID AsyncIOOp::State>";
 }