AsyncSocket::writeRequest() and its first user wangle::FileRegion
authorJames Sedgwick <jsedgwick@fb.com>
Mon, 8 Jun 2015 15:41:33 +0000 (08:41 -0700)
committerSara Golemon <sgolemon@fb.com>
Tue, 9 Jun 2015 20:20:04 +0000 (13:20 -0700)
Summary: similar to D2050808, but move the functionality into AsyncSocket itself so that you have a consistent interface and contiguous writes for a single file

Test Plan: added unit, will hook this up to a file server example next

Reviewed By: davejwatson@fb.com

Subscribers: fugalh, net-systems@, folly-diffs@, jsedgwick, yfeldblum, chalfant

FB internal diff: D2084452

Signature: t1:2084452:1433181933:175158618966706db00bf6620cc86ae145d04ecf

folly/Makefile.am
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSocketTest.h [new file with mode: 0644]
folly/io/async/test/AsyncSocketTest2.cpp
folly/wangle/channel/FileRegion.cpp [new file with mode: 0644]
folly/wangle/channel/FileRegion.h [new file with mode: 0644]
folly/wangle/channel/test/FileRegionTest.cpp [new file with mode: 0644]

index d2377fa5e237652f650314dd15095d01ccef8df7..b937ae93750d66d44d31a51b0eacffd3dc622e1b 100644 (file)
@@ -282,6 +282,7 @@ nobase_follyinclude_HEADERS = \
        wangle/bootstrap/ClientBootstrap.h \
        wangle/channel/AsyncSocketHandler.h \
        wangle/channel/EventBaseHandler.h \
+       wangle/channel/FileRegion.h \
        wangle/channel/Handler.h \
        wangle/channel/HandlerContext.h \
        wangle/channel/HandlerContext-inl.h \
index 939788d7d22e0bde00ce8d96ea0f5e3d58c54df4..35d29c39a4e15fb7928a09a37e16bfb8a0636784 100644 (file)
@@ -17,6 +17,8 @@
 #include <folly/io/async/AsyncSocket.h>
 
 #include <folly/io/async/EventBase.h>
+#include <folly/io/async/EventHandler.h>
+#include <folly/Singleton.h>
 #include <folly/SocketAddress.h>
 #include <folly/io/IOBuf.h>
 
@@ -24,6 +26,7 @@
 #include <errno.h>
 #include <limits.h>
 #include <unistd.h>
+#include <thread>
 #include <fcntl.h>
 #include <sys/types.h>
 #include <sys/socket.h>
@@ -43,7 +46,7 @@ const AsyncSocketException socketClosedLocallyEx(
 const AsyncSocketException socketShutdownForWritesEx(
     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
 
-// TODO: It might help performance to provide a version of WriteRequest that
+// TODO: It might help performance to provide a version of BytesWriteRequest that
 // users could derive from, so we can avoid the extra allocation for each call
 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
 // protocols are currently templatized for transports.
@@ -52,53 +55,6 @@ const AsyncSocketException socketShutdownForWritesEx(
 // storage space, and only our internal version would allocate it at the end of
 // the WriteRequest.
 
-/**
- * A WriteRequest object tracks information about a pending write operation.
- */
-class AsyncSocket::WriteRequest {
- public:
-  WriteRequest(AsyncSocket* socket,
-               WriteRequest* next,
-               WriteCallback* callback,
-               uint32_t totalBytesWritten) :
-    socket_(socket), next_(next), callback_(callback),
-    totalBytesWritten_(totalBytesWritten) {}
-
-  virtual void destroy() = 0;
-
-  virtual bool performWrite() = 0;
-
-  virtual void consume() = 0;
-
-  virtual bool isComplete() = 0;
-
-  WriteRequest* getNext() const {
-    return next_;
-  }
-
-  WriteCallback* getCallback() const {
-    return callback_;
-  }
-
-  uint32_t getTotalBytesWritten() const {
-    return totalBytesWritten_;
-  }
-
-  void append(WriteRequest* next) {
-    assert(next_ == nullptr);
-    next_ = next;
-  }
-
- protected:
-  // protected destructor, to ensure callers use destroy()
-  virtual ~WriteRequest() {}
-
-  AsyncSocket* socket_;         ///< parent socket
-  WriteRequest* next_;          ///< pointer to next WriteRequest
-  WriteCallback* callback_;     ///< completion callback
-  uint32_t totalBytesWritten_;  ///< total bytes written
-};
-
 /* The default WriteRequest implementation, used for write(), writev() and
  * writeChain()
  *
@@ -181,7 +137,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
                     uint32_t bytesWritten,
                     unique_ptr<IOBuf>&& ioBuf,
                     WriteFlags flags)
-    : AsyncSocket::WriteRequest(socket, nullptr, callback, 0)
+    : AsyncSocket::WriteRequest(socket, callback)
     , opCount_(opCount)
     , opIndex_(0)
     , flags_(flags)
@@ -773,6 +729,17 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
   }
 }
 
+void AsyncSocket::writeRequest(WriteRequest* req) {
+  if (writeReqTail_ == nullptr) {
+    assert(writeReqHead_ == nullptr);
+    writeReqHead_ = writeReqTail_ = req;
+    req->start();
+  } else {
+    writeReqTail_->append(req);
+    writeReqTail_ = req;
+  }
+}
+
 void AsyncSocket::close() {
   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
           << ", state=" << state_ << ", shutdownFlags="
index 9e3a224b55d0f5678d838f0913482da18bd7e423..523093be675e931a5eb0c1e5fa10b2e3bec59e7c 100644 (file)
@@ -334,6 +334,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
                   std::unique_ptr<folly::IOBuf>&& buf,
                   WriteFlags flags = WriteFlags::NONE) override;
 
+  class WriteRequest;
+  virtual void writeRequest(WriteRequest* req);
+  void writeRequestReady() {
+    handleWrite();
+  }
+
   // Methods inherited from AsyncTransport
   void close() override;
   void closeNow() override;
@@ -477,6 +483,60 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     ERROR
   };
 
+  /**
+   * A WriteRequest object tracks information about a pending write operation.
+   */
+  class WriteRequest {
+   public:
+    WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
+      socket_(socket), callback_(callback) {}
+
+    virtual void start() {};
+
+    virtual void destroy() = 0;
+
+    virtual bool performWrite() = 0;
+
+    virtual void consume() = 0;
+
+    virtual bool isComplete() = 0;
+
+    WriteRequest* getNext() const {
+      return next_;
+    }
+
+    WriteCallback* getCallback() const {
+      return callback_;
+    }
+
+    uint32_t getTotalBytesWritten() const {
+      return totalBytesWritten_;
+    }
+
+    void append(WriteRequest* next) {
+      assert(next_ == nullptr);
+      next_ = next;
+    }
+
+    void fail(const char* fn, const AsyncSocketException& ex) {
+      socket_->failWrite(fn, ex);
+    }
+
+    void bytesWritten(size_t count) {
+      totalBytesWritten_ += count;
+      socket_->appBytesWritten_ += count;
+    }
+
+   protected:
+    // protected destructor, to ensure callers use destroy()
+    virtual ~WriteRequest() {}
+
+    AsyncSocket* socket_;         ///< parent socket
+    WriteRequest* next_{nullptr};          ///< pointer to next WriteRequest
+    WriteCallback* callback_;     ///< completion callback
+    uint32_t totalBytesWritten_{0};  ///< total bytes written
+  };
+
  protected:
   enum ReadResultEnum {
     READ_EOF = 0,
@@ -516,7 +576,6 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     SHUT_READ = 0x04,
   };
 
-  class WriteRequest;
   class BytesWriteRequest;
 
   class WriteTimeout : public AsyncTimeout {
diff --git a/folly/io/async/test/AsyncSocketTest.h b/folly/io/async/test/AsyncSocketTest.h
new file mode 100644 (file)
index 0000000..2c25d0e
--- /dev/null
@@ -0,0 +1,265 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/test/BlockingSocket.h>
+
+#include <boost/scoped_array.hpp>
+#include <poll.h>
+
+// This is a test-only header
+/* using override */
+using namespace folly;
+
+enum StateEnum {
+  STATE_WAITING,
+  STATE_SUCCEEDED,
+  STATE_FAILED
+};
+
+typedef std::function<void()> VoidCallback;
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+  ConnCallback()
+    : state(STATE_WAITING)
+    , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+  void connectSuccess() noexcept override {
+    state = STATE_SUCCEEDED;
+    if (successCallback) {
+      successCallback();
+    }
+  }
+
+  void connectErr(const AsyncSocketException& ex) noexcept override {
+    state = STATE_FAILED;
+    exception = ex;
+    if (errorCallback) {
+      errorCallback();
+    }
+  }
+
+  StateEnum state;
+  AsyncSocketException exception;
+  VoidCallback successCallback;
+  VoidCallback errorCallback;
+};
+
+class WriteCallback : public AsyncTransportWrapper::WriteCallback {
+ public:
+  WriteCallback()
+    : state(STATE_WAITING)
+    , bytesWritten(0)
+    , exception(AsyncSocketException::UNKNOWN, "none") {}
+
+  void writeSuccess() noexcept override {
+    state = STATE_SUCCEEDED;
+    if (successCallback) {
+      successCallback();
+    }
+  }
+
+  void writeErr(size_t bytesWritten,
+                const AsyncSocketException& ex) noexcept override {
+    state = STATE_FAILED;
+    this->bytesWritten = bytesWritten;
+    exception = ex;
+    if (errorCallback) {
+      errorCallback();
+    }
+  }
+
+  StateEnum state;
+  size_t bytesWritten;
+  AsyncSocketException exception;
+  VoidCallback successCallback;
+  VoidCallback errorCallback;
+};
+
+class ReadCallback : public AsyncTransportWrapper::ReadCallback {
+ public:
+  ReadCallback()
+    : state(STATE_WAITING)
+    , exception(AsyncSocketException::UNKNOWN, "none")
+    , buffers() {}
+
+  ~ReadCallback() {
+    for (std::vector<Buffer>::iterator it = buffers.begin();
+         it != buffers.end();
+         ++it) {
+      it->free();
+    }
+    currentBuffer.free();
+  }
+
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    if (!currentBuffer.buffer) {
+      currentBuffer.allocate(4096);
+    }
+    *bufReturn = currentBuffer.buffer;
+    *lenReturn = currentBuffer.length;
+  }
+
+  void readDataAvailable(size_t len) noexcept override {
+    currentBuffer.length = len;
+    buffers.push_back(currentBuffer);
+    currentBuffer.reset();
+    if (dataAvailableCallback) {
+      dataAvailableCallback();
+    }
+  }
+
+  void readEOF() noexcept override {
+    state = STATE_SUCCEEDED;
+  }
+
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    state = STATE_FAILED;
+    exception = ex;
+  }
+
+  void verifyData(const char* expected, size_t expectedLen) const {
+    size_t offset = 0;
+    for (size_t idx = 0; idx < buffers.size(); ++idx) {
+      const auto& buf = buffers[idx];
+      size_t cmpLen = std::min(buf.length, expectedLen - offset);
+      CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
+      CHECK_EQ(cmpLen, buf.length);
+      offset += cmpLen;
+    }
+    CHECK_EQ(offset, expectedLen);
+  }
+
+  class Buffer {
+   public:
+    Buffer() : buffer(nullptr), length(0) {}
+    Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
+
+    void reset() {
+      buffer = nullptr;
+      length = 0;
+    }
+    void allocate(size_t length) {
+      assert(buffer == nullptr);
+      this->buffer = static_cast<char*>(malloc(length));
+      this->length = length;
+    }
+    void free() {
+      ::free(buffer);
+      reset();
+    }
+
+    char* buffer;
+    size_t length;
+  };
+
+  StateEnum state;
+  AsyncSocketException exception;
+  std::vector<Buffer> buffers;
+  Buffer currentBuffer;
+  VoidCallback dataAvailableCallback;
+};
+
+class ReadVerifier {
+};
+
+class TestServer {
+ public:
+  // Create a TestServer.
+  // This immediately starts listening on an ephemeral port.
+  TestServer()
+    : fd_(-1) {
+    fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+    if (fd_ < 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "failed to create test server socket", errno);
+    }
+    if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "failed to put test server socket in "
+                                "non-blocking mode", errno);
+    }
+    if (listen(fd_, 10) != 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "failed to listen on test server socket",
+                                errno);
+    }
+
+    address_.setFromLocalAddress(fd_);
+    // The local address will contain 0.0.0.0.
+    // Change it to 127.0.0.1, so it can be used to connect to the server
+    address_.setFromIpPort("127.0.0.1", address_.getPort());
+  }
+
+  // Get the address for connecting to the server
+  const folly::SocketAddress& getAddress() const {
+    return address_;
+  }
+
+  int acceptFD(int timeout=50) {
+    struct pollfd pfd;
+    pfd.fd = fd_;
+    pfd.events = POLLIN;
+    int ret = poll(&pfd, 1, timeout);
+    if (ret == 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "test server accept() timed out");
+    } else if (ret < 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "test server accept() poll failed", errno);
+    }
+
+    int acceptedFd = ::accept(fd_, nullptr, nullptr);
+    if (acceptedFd < 0) {
+      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
+                                "test server accept() failed", errno);
+    }
+
+    return acceptedFd;
+  }
+
+  std::shared_ptr<BlockingSocket> accept(int timeout=50) {
+    int fd = acceptFD(timeout);
+    return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+  }
+
+  std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
+    int fd = acceptFD(timeout);
+    return AsyncSocket::newSocket(evb, fd);
+  }
+
+  /**
+   * Accept a connection, read data from it, and verify that it matches the
+   * data in the specified buffer.
+   */
+  void verifyConnection(const char* buf, size_t len) {
+    // accept a connection
+    std::shared_ptr<BlockingSocket> acceptedSocket = accept();
+    // read the data and compare it to the specified buffer
+    boost::scoped_array<uint8_t> readbuf(new uint8_t[len]);
+    acceptedSocket->readAll(readbuf.get(), len);
+    CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
+    // make sure we get EOF next
+    uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
+    CHECK_EQ(bytesRead, 0);
+  }
+
+ private:
+  int fd_;
+  folly::SocketAddress address_;
+};
index 147bec94c609b35596a4a4dba1bcd5da09463a41..ca075164bd69ab48b6dc8c8ac93b927c64057c2b 100644 (file)
@@ -20,7 +20,7 @@
 #include <folly/SocketAddress.h>
 
 #include <folly/io/IOBuf.h>
-#include <folly/io/async/test/BlockingSocket.h>
+#include <folly/io/async/test/AsyncSocketTest.h>
 #include <folly/io/async/test/Util.h>
 
 #include <gtest/gtest.h>
@@ -47,246 +47,6 @@ using boost::scoped_array;
 
 using namespace folly;
 
-enum StateEnum {
-  STATE_WAITING,
-  STATE_SUCCEEDED,
-  STATE_FAILED
-};
-
-typedef std::function<void()> VoidCallback;
-
-
-class ConnCallback : public AsyncSocket::ConnectCallback {
- public:
-  ConnCallback()
-    : state(STATE_WAITING)
-    , exception(AsyncSocketException::UNKNOWN, "none") {}
-
-  void connectSuccess() noexcept override {
-    state = STATE_SUCCEEDED;
-    if (successCallback) {
-      successCallback();
-    }
-  }
-
-  void connectErr(const AsyncSocketException& ex) noexcept override {
-    state = STATE_FAILED;
-    exception = ex;
-    if (errorCallback) {
-      errorCallback();
-    }
-  }
-
-  StateEnum state;
-  AsyncSocketException exception;
-  VoidCallback successCallback;
-  VoidCallback errorCallback;
-};
-
-class WriteCallback : public AsyncTransportWrapper::WriteCallback {
- public:
-  WriteCallback()
-    : state(STATE_WAITING)
-    , bytesWritten(0)
-    , exception(AsyncSocketException::UNKNOWN, "none") {}
-
-  void writeSuccess() noexcept override {
-    state = STATE_SUCCEEDED;
-    if (successCallback) {
-      successCallback();
-    }
-  }
-
-  void writeErr(size_t bytesWritten,
-                const AsyncSocketException& ex) noexcept override {
-    state = STATE_FAILED;
-    this->bytesWritten = bytesWritten;
-    exception = ex;
-    if (errorCallback) {
-      errorCallback();
-    }
-  }
-
-  StateEnum state;
-  size_t bytesWritten;
-  AsyncSocketException exception;
-  VoidCallback successCallback;
-  VoidCallback errorCallback;
-};
-
-class ReadCallback : public AsyncTransportWrapper::ReadCallback {
- public:
-  ReadCallback()
-    : state(STATE_WAITING)
-    , exception(AsyncSocketException::UNKNOWN, "none")
-    , buffers() {}
-
-  ~ReadCallback() {
-    for (vector<Buffer>::iterator it = buffers.begin();
-         it != buffers.end();
-         ++it) {
-      it->free();
-    }
-    currentBuffer.free();
-  }
-
-  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
-    if (!currentBuffer.buffer) {
-      currentBuffer.allocate(4096);
-    }
-    *bufReturn = currentBuffer.buffer;
-    *lenReturn = currentBuffer.length;
-  }
-
-  void readDataAvailable(size_t len) noexcept override {
-    currentBuffer.length = len;
-    buffers.push_back(currentBuffer);
-    currentBuffer.reset();
-    if (dataAvailableCallback) {
-      dataAvailableCallback();
-    }
-  }
-
-  void readEOF() noexcept override {
-    state = STATE_SUCCEEDED;
-  }
-
-  void readErr(const AsyncSocketException& ex) noexcept override {
-    state = STATE_FAILED;
-    exception = ex;
-  }
-
-  void verifyData(const char* expected, size_t expectedLen) const {
-    size_t offset = 0;
-    for (size_t idx = 0; idx < buffers.size(); ++idx) {
-      const auto& buf = buffers[idx];
-      size_t cmpLen = std::min(buf.length, expectedLen - offset);
-      CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
-      CHECK_EQ(cmpLen, buf.length);
-      offset += cmpLen;
-    }
-    CHECK_EQ(offset, expectedLen);
-  }
-
-  class Buffer {
-   public:
-    Buffer() : buffer(nullptr), length(0) {}
-    Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
-
-    void reset() {
-      buffer = nullptr;
-      length = 0;
-    }
-    void allocate(size_t length) {
-      assert(buffer == nullptr);
-      this->buffer = static_cast<char*>(malloc(length));
-      this->length = length;
-    }
-    void free() {
-      ::free(buffer);
-      reset();
-    }
-
-    char* buffer;
-    size_t length;
-  };
-
-  StateEnum state;
-  AsyncSocketException exception;
-  vector<Buffer> buffers;
-  Buffer currentBuffer;
-  VoidCallback dataAvailableCallback;
-};
-
-class ReadVerifier {
-};
-
-class TestServer {
- public:
-  // Create a TestServer.
-  // This immediately starts listening on an ephemeral port.
-  TestServer()
-    : fd_(-1) {
-    fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
-    if (fd_ < 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "failed to create test server socket", errno);
-    }
-    if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "failed to put test server socket in "
-                                "non-blocking mode", errno);
-    }
-    if (listen(fd_, 10) != 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "failed to listen on test server socket",
-                                errno);
-    }
-
-    address_.setFromLocalAddress(fd_);
-    // The local address will contain 0.0.0.0.
-    // Change it to 127.0.0.1, so it can be used to connect to the server
-    address_.setFromIpPort("127.0.0.1", address_.getPort());
-  }
-
-  // Get the address for connecting to the server
-  const folly::SocketAddress& getAddress() const {
-    return address_;
-  }
-
-  int acceptFD(int timeout=50) {
-    struct pollfd pfd;
-    pfd.fd = fd_;
-    pfd.events = POLLIN;
-    int ret = poll(&pfd, 1, timeout);
-    if (ret == 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "test server accept() timed out");
-    } else if (ret < 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "test server accept() poll failed", errno);
-    }
-
-    int acceptedFd = ::accept(fd_, nullptr, nullptr);
-    if (acceptedFd < 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "test server accept() failed", errno);
-    }
-
-    return acceptedFd;
-  }
-
-  std::shared_ptr<BlockingSocket> accept(int timeout=50) {
-    int fd = acceptFD(timeout);
-    return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
-  }
-
-  std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
-    int fd = acceptFD(timeout);
-    return AsyncSocket::newSocket(evb, fd);
-  }
-
-  /**
-   * Accept a connection, read data from it, and verify that it matches the
-   * data in the specified buffer.
-   */
-  void verifyConnection(const char* buf, size_t len) {
-    // accept a connection
-    std::shared_ptr<BlockingSocket> acceptedSocket = accept();
-    // read the data and compare it to the specified buffer
-    scoped_array<uint8_t> readbuf(new uint8_t[len]);
-    acceptedSocket->readAll(readbuf.get(), len);
-    CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
-    // make sure we get EOF next
-    uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
-    CHECK_EQ(bytesRead, 0);
-  }
-
- private:
-  int fd_;
-  folly::SocketAddress address_;
-};
-
 class DelayedWrite: public AsyncTimeout {
  public:
   DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
diff --git a/folly/wangle/channel/FileRegion.cpp b/folly/wangle/channel/FileRegion.cpp
new file mode 100644 (file)
index 0000000..7d14a4a
--- /dev/null
@@ -0,0 +1,214 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/wangle/channel/FileRegion.h>
+
+using namespace folly;
+using namespace folly::wangle;
+
+namespace {
+
+struct FileRegionReadPool {};
+
+Singleton<IOThreadPoolExecutor, FileRegionReadPool> readPool(
+  []{
+    return new IOThreadPoolExecutor(
+        sysconf(_SC_NPROCESSORS_ONLN),
+        std::make_shared<NamedThreadFactory>("FileRegionReadPool"));
+  });
+
+}
+
+namespace folly { namespace wangle {
+
+FileRegion::FileWriteRequest::FileWriteRequest(AsyncSocket* socket,
+    WriteCallback* callback, int fd, off_t offset, size_t count)
+  : WriteRequest(socket, callback),
+    readFd_(fd), offset_(offset), count_(count) {
+}
+
+void FileRegion::FileWriteRequest::destroy() {
+  readBase_->runInEventBaseThread([this]{
+    delete this;
+  });
+}
+
+bool FileRegion::FileWriteRequest::performWrite() {
+  if (!started_) {
+    start();
+    return true;
+  }
+
+  int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE;
+  ssize_t spliced = ::splice(pipe_out_, nullptr,
+                             socket_->getFd(), nullptr,
+                             bytesInPipe_, flags);
+  if (spliced == -1) {
+    if (errno == EAGAIN) {
+      return true;
+    }
+    return false;
+  }
+
+  bytesInPipe_ -= spliced;
+  bytesWritten(spliced);
+  return true;
+}
+
+void FileRegion::FileWriteRequest::consume() {
+  // do nothing
+}
+
+bool FileRegion::FileWriteRequest::isComplete() {
+  return totalBytesWritten_ == count_;
+}
+
+void FileRegion::FileWriteRequest::messageAvailable(size_t&& count) {
+  bool shouldWrite = bytesInPipe_ == 0;
+  bytesInPipe_ += count;
+  if (shouldWrite) {
+    socket_->writeRequestReady();
+  }
+}
+
+#ifdef __GLIBC__
+# if (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 9))
+#   define GLIBC_AT_LEAST_2_9 1
+#  endif
+#endif
+
+void FileRegion::FileWriteRequest::start() {
+  started_ = true;
+  readBase_ = readPool.get()->getEventBase();
+  readBase_->runInEventBaseThread([this]{
+    auto flags = fcntl(readFd_, F_GETFL);
+    if (flags == -1) {
+      fail(__func__, AsyncSocketException(
+          AsyncSocketException::INTERNAL_ERROR,
+          "fcntl F_GETFL failed", errno));
+      return;
+    }
+
+    flags &= O_ACCMODE;
+    if (flags == O_WRONLY) {
+      fail(__func__, AsyncSocketException(
+          AsyncSocketException::BAD_ARGS, "file not open for reading"));
+      return;
+    }
+
+#ifndef GLIBC_AT_LEAST_2_9
+    fail(__func__, AsyncSocketException(
+        AsyncSocketException::NOT_SUPPORTED,
+        "writeFile unsupported on glibc < 2.9"));
+    return;
+#else
+    int pipeFds[2];
+    if (::pipe2(pipeFds, O_NONBLOCK) == -1) {
+      fail(__func__, AsyncSocketException(
+          AsyncSocketException::INTERNAL_ERROR,
+          "pipe2 failed", errno));
+      return;
+    }
+
+    // Max size for unprevileged processes as set in /proc/sys/fs/pipe-max-size
+    // Ignore failures and just roll with it
+    // TODO maybe read max size from /proc?
+    fcntl(pipeFds[0], F_SETPIPE_SZ, 1048576);
+    fcntl(pipeFds[1], F_SETPIPE_SZ, 1048576);
+
+    pipe_out_ = pipeFds[0];
+
+    socket_->getEventBase()->runInEventBaseThreadAndWait([&]{
+      startConsuming(socket_->getEventBase(), &queue_);
+    });
+    readHandler_ = folly::make_unique<FileReadHandler>(
+        this, pipeFds[1], count_);
+#endif
+  });
+}
+
+FileRegion::FileWriteRequest::~FileWriteRequest() {
+  CHECK(readBase_->isInEventBaseThread());
+  socket_->getEventBase()->runInEventBaseThreadAndWait([&]{
+    stopConsuming();
+    if (pipe_out_ > -1) {
+      ::close(pipe_out_);
+    }
+  });
+
+}
+
+void FileRegion::FileWriteRequest::fail(
+    const char* fn,
+    const AsyncSocketException& ex) {
+  socket_->getEventBase()->runInEventBaseThread([=]{
+    WriteRequest::fail(fn, ex);
+  });
+}
+
+FileRegion::FileWriteRequest::FileReadHandler::FileReadHandler(
+    FileWriteRequest* req, int pipe_in, size_t bytesToRead)
+  : req_(req), pipe_in_(pipe_in), bytesToRead_(bytesToRead) {
+  CHECK(req_->readBase_->isInEventBaseThread());
+  initHandler(req_->readBase_, pipe_in);
+  if (!registerHandler(EventFlags::WRITE | EventFlags::PERSIST)) {
+    req_->fail(__func__, AsyncSocketException(
+        AsyncSocketException::INTERNAL_ERROR,
+        "registerHandler failed"));
+  }
+}
+
+FileRegion::FileWriteRequest::FileReadHandler::~FileReadHandler() {
+  CHECK(req_->readBase_->isInEventBaseThread());
+  unregisterHandler();
+  ::close(pipe_in_);
+}
+
+void FileRegion::FileWriteRequest::FileReadHandler::handlerReady(
+    uint16_t events) noexcept {
+  CHECK(events & EventHandler::WRITE);
+  if (bytesToRead_ == 0) {
+    unregisterHandler();
+    return;
+  }
+
+  int flags = SPLICE_F_NONBLOCK | SPLICE_F_MORE;
+  ssize_t spliced = ::splice(req_->readFd_, &req_->offset_,
+                             pipe_in_, nullptr,
+                             bytesToRead_, flags);
+  if (spliced == -1) {
+    if (errno == EAGAIN) {
+      return;
+    } else {
+      req_->fail(__func__, AsyncSocketException(
+          AsyncSocketException::INTERNAL_ERROR,
+          "splice failed", errno));
+      return;
+    }
+  }
+
+  if (spliced > 0) {
+    bytesToRead_ -= spliced;
+    try {
+      req_->queue_.putMessage(static_cast<size_t>(spliced));
+    } catch (...) {
+      req_->fail(__func__, AsyncSocketException(
+          AsyncSocketException::INTERNAL_ERROR,
+          "putMessage failed"));
+      return;
+    }
+  }
+}
+}} // folly::wangle
diff --git a/folly/wangle/channel/FileRegion.h b/folly/wangle/channel/FileRegion.h
new file mode 100644 (file)
index 0000000..6360ae3
--- /dev/null
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Singleton.h>
+#include <folly/io/async/AsyncTransport.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/NotificationQueue.h>
+#include <folly/futures/Future.h>
+#include <folly/futures/Promise.h>
+#include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
+
+namespace folly { namespace wangle {
+
+class FileRegion {
+ public:
+  FileRegion(int fd, off_t offset, size_t count)
+    : fd_(fd), offset_(offset), count_(count) {}
+
+  Future<void> transferTo(std::shared_ptr<AsyncTransport> transport) {
+    auto socket = std::dynamic_pointer_cast<AsyncSocket>(
+        transport);
+    CHECK(socket);
+    auto cb = new WriteCallback();
+    auto f = cb->promise_.getFuture();
+    auto req = new FileWriteRequest(socket.get(), cb, fd_, offset_, count_);
+    socket->writeRequest(req);
+    return f;
+  }
+
+ private:
+  class WriteCallback : private AsyncSocket::WriteCallback {
+    void writeSuccess() noexcept override {
+      promise_.setValue();
+      delete this;
+    }
+
+    void writeErr(size_t bytesWritten,
+                  const AsyncSocketException& ex)
+      noexcept override {
+      promise_.setException(ex);
+      delete this;
+    }
+
+    friend class FileRegion;
+    folly::Promise<void> promise_;
+  };
+
+  const int fd_;
+  const off_t offset_;
+  const size_t count_;
+
+  class FileWriteRequest : public AsyncSocket::WriteRequest,
+                           public NotificationQueue<size_t>::Consumer {
+   public:
+    FileWriteRequest(AsyncSocket* socket, WriteCallback* callback,
+                     int fd, off_t offset, size_t count);
+
+    void destroy() override;
+
+    bool performWrite() override;
+
+    void consume() override;
+
+    bool isComplete() override;
+
+    void messageAvailable(size_t&& count) override;
+
+    void start() override;
+
+    class FileReadHandler : public folly::EventHandler {
+     public:
+      FileReadHandler(FileWriteRequest* req, int pipe_in, size_t bytesToRead);
+
+      ~FileReadHandler();
+
+      void handlerReady(uint16_t events) noexcept override;
+
+     private:
+      FileWriteRequest* req_;
+      int pipe_in_;
+      size_t bytesToRead_;
+    };
+
+   private:
+    ~FileWriteRequest();
+
+    void fail(const char* fn, const AsyncSocketException& ex);
+
+    const int readFd_;
+    off_t offset_;
+    const size_t count_;
+    bool started_{false};
+    int pipe_out_{-1};
+
+    size_t bytesInPipe_{0};
+    folly::EventBase* readBase_;
+    folly::NotificationQueue<size_t> queue_;
+    std::unique_ptr<FileReadHandler> readHandler_;
+  };
+};
+
+}} // folly::wangle
diff --git a/folly/wangle/channel/test/FileRegionTest.cpp b/folly/wangle/channel/test/FileRegionTest.cpp
new file mode 100644 (file)
index 0000000..ff12fc2
--- /dev/null
@@ -0,0 +1,110 @@
+/*
+ * Copyright 2015 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/wangle/channel/FileRegion.h>
+#include <folly/io/async/test/AsyncSocketTest.h>
+#include <gtest/gtest.h>
+
+using namespace folly;
+using namespace folly::wangle;
+using namespace testing;
+
+struct FileRegionTest : public Test {
+  FileRegionTest() {
+    // Connect
+    socket = AsyncSocket::newSocket(&evb);
+    socket->connect(&ccb, server.getAddress(), 30);
+
+    // Accept the connection
+    acceptedSocket = server.acceptAsync(&evb);
+    acceptedSocket->setReadCB(&rcb);
+
+    // Create temp file
+    char path[] = "/tmp/AsyncSocketTest.WriteFile.XXXXXX";
+    fd = mkostemp(path, O_RDWR);
+    EXPECT_TRUE(fd > 0);
+    EXPECT_EQ(0, unlink(path));
+  }
+
+  ~FileRegionTest() {
+    // Close up shop
+    close(fd);
+    acceptedSocket->close();
+    socket->close();
+  }
+
+  TestServer server;
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket;
+  std::shared_ptr<AsyncSocket> acceptedSocket;
+  ConnCallback ccb;
+  ReadCallback rcb;
+  int fd;
+};
+
+TEST_F(FileRegionTest, Basic) {
+  size_t count = 1000000000; // 1 GB
+  void* zeroBuf = calloc(1, count);
+  write(fd, zeroBuf, count);
+
+  FileRegion fileRegion(fd, 0, count);
+  auto f = fileRegion.transferTo(socket);
+  try {
+    f.getVia(&evb);
+  } catch (std::exception& e) {
+    LOG(FATAL) << exceptionStr(e);
+  }
+
+  // Let the reads run to completion
+  socket->shutdownWrite();
+  evb.loop();
+
+  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
+
+  size_t receivedBytes = 0;
+  for (auto& buf : rcb.buffers) {
+    receivedBytes += buf.length;
+    ASSERT_EQ(memcmp(buf.buffer, zeroBuf, buf.length), 0);
+  }
+  ASSERT_EQ(receivedBytes, count);
+}
+
+TEST_F(FileRegionTest, Repeated) {
+  size_t count = 1000000;
+  void* zeroBuf = calloc(1, count);
+  write(fd, zeroBuf, count);
+
+  int sendCount = 1000;
+
+  FileRegion fileRegion(fd, 0, count);
+  std::vector<Future<void>> fs;
+  for (int i = 0; i < sendCount; i++) {
+    fs.push_back(fileRegion.transferTo(socket));
+  }
+  auto f = collect(fs);
+  ASSERT_NO_THROW(f.getVia(&evb));
+
+  // Let the reads run to completion
+  socket->shutdownWrite();
+  evb.loop();
+
+  ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
+
+  size_t receivedBytes = 0;
+  for (auto& buf : rcb.buffers) {
+    receivedBytes += buf.length;
+  }
+  ASSERT_EQ(receivedBytes, sendCount*count);
+}