Add SO_ZEROCOPY support v2017.10.09.00
authorDan Melnic <dmm@fb.com>
Mon, 9 Oct 2017 01:27:54 +0000 (18:27 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 9 Oct 2017 01:35:09 +0000 (18:35 -0700)
Summary: Add SO_ZEROCOPY support

Reviewed By: djwatson

Differential Revision: D5851637

fbshipit-source-id: 5378b7e44ce9d888ae08527506218998974d4309

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncServerSocket.cpp
folly/io/async/AsyncServerSocket.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/AsyncTransport.h
folly/io/async/test/AsyncSSLSocketTest.h
folly/io/async/test/ZeroCopyBenchmark.cpp [new file with mode: 0644]
folly/portability/Sockets.h

index 2b1aa88e17db321901dc7f4cd825a691611b29d6..a30cc7cabf895b1bc0fca9339d1337a825115c50 100644 (file)
@@ -1646,7 +1646,8 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
     flags |= WriteFlags::CORK;
   }
 
-  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags);
+  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(
+      flags, false /*zeroCopyEnabled*/);
   msg.msg_controllen =
       tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
   CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
index 8a70d77cc188701d60875ee0bde36386f8721fe0..c9f391f08640612a7a01972d4dcb4a066c12ea2d 100644 (file)
@@ -39,6 +39,13 @@ namespace fsp = folly::portability::sockets;
 
 namespace folly {
 
+static constexpr bool msgErrQueueSupported =
+#ifdef MSG_ERRQUEUE
+    true;
+#else
+    false;
+#endif // MSG_ERRQUEUE
+
 const uint32_t AsyncServerSocket::kDefaultMaxAcceptAtOnce;
 const uint32_t AsyncServerSocket::kDefaultCallbackAcceptAtOnce;
 const uint32_t AsyncServerSocket::kDefaultMaxMessagesInQueue;
@@ -331,6 +338,18 @@ void AsyncServerSocket::bindSocket(
   }
 }
 
+bool AsyncServerSocket::setZeroCopy(bool enable) {
+  if (msgErrQueueSupported) {
+    int fd = getSocket();
+    int val = enable ? 1 : 0;
+    int ret = setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
+
+    return (0 == ret);
+  }
+
+  return false;
+}
+
 void AsyncServerSocket::bind(const SocketAddress& address) {
   if (eventBase_) {
     eventBase_->dcheckIsInEventBaseThread();
index fad20a01700ac3fe3a14aae94977d854a9b90867..6589c667e91b6a047d225cb17ea2784b857f86d7 100644 (file)
@@ -319,6 +319,11 @@ class AsyncServerSocket : public DelayedDestruction
     }
   }
 
+  /* enable zerocopy support for the server sockets - the s = accept sockets
+   * inherit it
+   */
+  bool setZeroCopy(bool enable);
+
   /**
    * Bind to the specified address.
    *
index 7f8c5f13614a73f76c8a57058486a26a5bbefb69..abbec9b995ff8854789b7927f768832d1d9e2b18 100644 (file)
@@ -104,9 +104,28 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     if (getNext() != nullptr) {
       writeFlags |= WriteFlags::CORK;
     }
+
+    socket_->adjustZeroCopyFlags(getOps(), getOpCount(), writeFlags);
+
     auto writeResult = socket_->performWrite(
         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
     bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+    if (bytesWritten_) {
+      if (socket_->isZeroCopyRequest(writeFlags)) {
+        if (isComplete()) {
+          socket_->addZeroCopyBuff(std::move(ioBuf_));
+        } else {
+          socket_->addZeroCopyBuff(ioBuf_.get());
+        }
+      } else {
+        // this happens if at least one of the prev requests were sent
+        // with zero copy but not the last one
+        if (isComplete() && socket_->getZeroCopy() &&
+            socket_->containsZeroCopyBuff(ioBuf_.get())) {
+          socket_->setZeroCopyBuff(std::move(ioBuf_));
+        }
+      }
+    }
     return writeResult;
   }
 
@@ -119,11 +138,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     opIndex_ += opsWritten_;
     assert(opIndex_ < opCount_);
 
-    // If we've finished writing any IOBufs, release them
-    if (ioBuf_) {
-      for (uint32_t i = opsWritten_; i != 0; --i) {
-        assert(ioBuf_);
-        ioBuf_ = ioBuf_->pop();
+    if (!socket_->isZeroCopyRequest(flags_)) {
+      // If we've finished writing any IOBufs, release them
+      if (ioBuf_) {
+        for (uint32_t i = opsWritten_; i != 0; --i) {
+          assert(ioBuf_);
+          ioBuf_ = ioBuf_->pop();
+        }
       }
     }
 
@@ -185,8 +206,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
   struct iovec writeOps_[];     ///< write operation(s) list
 };
 
-int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
-                                                                      noexcept {
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
+    folly::WriteFlags flags,
+    bool zeroCopyEnabled) noexcept {
   int msg_flags = MSG_DONTWAIT;
 
 #ifdef MSG_NOSIGNAL // Linux-only
@@ -205,6 +227,10 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
     msg_flags |= MSG_EOR;
   }
 
+  if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
+    msg_flags |= MSG_ZEROCOPY;
+  }
+
   return msg_flags;
 }
 
@@ -433,8 +459,11 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // By default, turn on TCP_NODELAY
     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
     // setNoDelay() will log an error message if it fails.
+    // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
+    // is not created
     if (address.getFamily() != AF_UNIX) {
       (void)setNoDelay(true);
+      setZeroCopy(zeroCopyVal_);
     }
 
     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
@@ -772,6 +801,156 @@ AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
   return readCallback_;
 }
 
+bool AsyncSocket::setZeroCopy(bool enable) {
+  if (msgErrQueueSupported) {
+    zeroCopyVal_ = enable;
+
+    if (fd_ < 0) {
+      return false;
+    }
+
+    int val = enable ? 1 : 0;
+    int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
+
+    // if enable == false, set zeroCopyEnabled_ = false regardless
+    // if SO_ZEROCOPY is set or not
+    if (!enable) {
+      zeroCopyEnabled_ = enable;
+      return true;
+    }
+
+    /* if the setsockopt failed, try to see if the socket inherited the flag
+     * since we cannot set SO_ZEROCOPY on a socket s = accept
+     */
+    if (ret) {
+      val = 0;
+      socklen_t optlen = sizeof(val);
+      ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
+
+      if (!ret) {
+        enable = val ? true : false;
+      }
+    }
+
+    if (!ret) {
+      zeroCopyEnabled_ = enable;
+
+      return true;
+    }
+  }
+
+  return false;
+}
+
+void AsyncSocket::setZeroCopyWriteChainThreshold(size_t threshold) {
+  zeroCopyWriteChainThreshold_ = threshold;
+}
+
+bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
+  return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
+}
+
+void AsyncSocket::adjustZeroCopyFlags(
+    folly::IOBuf* buf,
+    folly::WriteFlags& flags) {
+  if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_ && buf) {
+    if (buf->computeChainDataLength() >= zeroCopyWriteChainThreshold_) {
+      flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+    } else {
+      flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+    }
+  }
+}
+
+void AsyncSocket::adjustZeroCopyFlags(
+    const iovec* vec,
+    uint32_t count,
+    folly::WriteFlags& flags) {
+  if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_) {
+    count = std::min<uint32_t>(count, kIovMax);
+    size_t sum = 0;
+    for (uint32_t i = 0; i < count; ++i) {
+      const iovec* v = vec + i;
+      sum += v->iov_len;
+    }
+
+    if (sum >= zeroCopyWriteChainThreshold_) {
+      flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
+    } else {
+      flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+    }
+  }
+}
+
+void AsyncSocket::addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+  uint32_t id = getNextZeroCopyBuffId();
+  folly::IOBuf* ptr = buf.get();
+
+  idZeroCopyBufPtrMap_[id] = ptr;
+  auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+  p.first++;
+  CHECK(p.second.get() == nullptr);
+  p.second = std::move(buf);
+}
+
+void AsyncSocket::addZeroCopyBuff(folly::IOBuf* ptr) {
+  uint32_t id = getNextZeroCopyBuffId();
+  idZeroCopyBufPtrMap_[id] = ptr;
+
+  idZeroCopyBufPtrToBufMap_[ptr].first++;
+}
+
+void AsyncSocket::releaseZeroCopyBuff(uint32_t id) {
+  auto iter = idZeroCopyBufPtrMap_.find(id);
+  CHECK(iter != idZeroCopyBufPtrMap_.end());
+  auto ptr = iter->second;
+  auto iter1 = idZeroCopyBufPtrToBufMap_.find(ptr);
+  CHECK(iter1 != idZeroCopyBufPtrToBufMap_.end());
+  if (0 == --iter1->second.first) {
+    idZeroCopyBufPtrToBufMap_.erase(iter1);
+  }
+}
+
+void AsyncSocket::setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+  folly::IOBuf* ptr = buf.get();
+  auto& p = idZeroCopyBufPtrToBufMap_[ptr];
+  CHECK(p.second.get() == nullptr);
+
+  p.second = std::move(buf);
+}
+
+bool AsyncSocket::containsZeroCopyBuff(folly::IOBuf* ptr) {
+  return (
+      idZeroCopyBufPtrToBufMap_.find(ptr) != idZeroCopyBufPtrToBufMap_.end());
+}
+
+bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
+#ifdef MSG_ERRQUEUE
+  if (zeroCopyEnabled_ &&
+      ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+       (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) {
+    const struct sock_extended_err* serr =
+        reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+    return (
+        (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
+  }
+#endif
+  return false;
+}
+
+void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
+#ifdef MSG_ERRQUEUE
+  const struct sock_extended_err* serr =
+      reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+  uint32_t hi = serr->ee_data;
+  uint32_t lo = serr->ee_info;
+
+  for (uint32_t i = lo; i <= hi; i++) {
+    releaseZeroCopyBuff(i);
+  }
+#endif
+}
+
 void AsyncSocket::write(WriteCallback* callback,
                          const void* buf, size_t bytes, WriteFlags flags) {
   iovec op;
@@ -789,6 +968,8 @@ void AsyncSocket::writev(WriteCallback* callback,
 
 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
                               WriteFlags flags) {
+  adjustZeroCopyFlags(buf.get(), flags);
+
   constexpr size_t kSmallSizeMax = 64;
   size_t count = buf->countChainElements();
   if (count <= kSmallSizeMax) {
@@ -860,6 +1041,10 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
             errnoCopy);
         return failWrite(__func__, callback, 0, ex);
       } else if (countWritten == count) {
+        // done, add the whole buffer
+        if (isZeroCopyRequest(flags)) {
+          addZeroCopyBuff(std::move(ioBuf));
+        }
         // We successfully wrote everything.
         // Invoke the callback and return.
         if (callback) {
@@ -867,6 +1052,10 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         }
         return;
       } else { // continue writing the next writeReq
+        // add just the ptr
+        if (isZeroCopyRequest(flags)) {
+          addZeroCopyBuff(ioBuf.get());
+        }
         if (bufferCallback_) {
           bufferCallback_->onEgressBuffered();
         }
@@ -1545,7 +1734,8 @@ void AsyncSocket::handleErrMessages() noexcept {
   // supporting per-socket error queues.
   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
           << ", state=" << state_;
-  if (errMessageCallback_ == nullptr) {
+  if (errMessageCallback_ == nullptr &&
+      (!zeroCopyEnabled_ || idZeroCopyBufPtrMap_.empty())) {
     VLOG(7) << "AsyncSocket::handleErrMessages(): "
             << "no callback installed - exiting.";
     return;
@@ -1587,11 +1777,15 @@ void AsyncSocket::handleErrMessages() noexcept {
     }
 
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
-         cmsg != nullptr &&
-           cmsg->cmsg_len != 0 &&
-           errMessageCallback_ != nullptr;
+         cmsg != nullptr && cmsg->cmsg_len != 0;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
-      errMessageCallback_->errMessage(*cmsg);
+      if (isZeroCopyMsg(*cmsg)) {
+        processZeroCopyMsg(*cmsg);
+      } else {
+        if (errMessageCallback_) {
+          errMessageCallback_->errMessage(*cmsg);
+        }
+      }
     }
   }
 #endif //MSG_ERRQUEUE
@@ -2127,7 +2321,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
   } else {
     msg.msg_control = nullptr;
   }
-  int msg_flags = sendMsgParamCallback_->getFlags(flags);
+  int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
 
   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
   auto totalWritten = writeResult.writeReturn;
index beff1a07285e0bc9b93fe130f30ad87c00648790..e99300fb238a6ae491f453bdbc886b405037adda 100644 (file)
@@ -156,8 +156,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
      *
      * @param flags     Write flags requested for the given write operation
      */
-    int getFlags(folly::WriteFlags flags) noexcept {
-      return getFlagsImpl(flags, getDefaultFlags(flags));
+    int getFlags(folly::WriteFlags flags, bool zeroCopyEnabled) noexcept {
+      return getFlagsImpl(flags, getDefaultFlags(flags, zeroCopyEnabled));
     }
 
     /**
@@ -211,7 +211,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
      *
      * @param flags     Write flags requested for the given write operation
      */
-    int getDefaultFlags(folly::WriteFlags flags) noexcept;
+    int getDefaultFlags(folly::WriteFlags flags, bool zeroCopyEnabled) noexcept;
   };
 
   explicit AsyncSocket();
@@ -504,6 +504,21 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void setReadCB(ReadCallback* callback) override;
   ReadCallback* getReadCallback() const override;
 
+  static const size_t kDefaultZeroCopyThreshold = 32768; // 32KB
+
+  bool setZeroCopy(bool enable);
+  bool getZeroCopy() const {
+    return zeroCopyEnabled_;
+  }
+
+  void setZeroCopyWriteChainThreshold(size_t threshold);
+  size_t getZeroCopyWriteChainThreshold() const {
+    return zeroCopyWriteChainThreshold_;
+  }
+
+  bool isZeroCopyMsg(const cmsghdr& cmsg) const;
+  void processZeroCopyMsg(const cmsghdr& cmsg);
+
   void write(WriteCallback* callback, const void* buf, size_t bytes,
              WriteFlags flags = WriteFlags::NONE) override;
   void writev(WriteCallback* callback, const iovec* vec, size_t count,
@@ -1133,6 +1148,32 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void cacheLocalAddress() const;
   void cachePeerAddress() const;
 
+  bool isZeroCopyRequest(WriteFlags flags);
+  uint32_t getNextZeroCopyBuffId() {
+    return zeroCopyBuffId_++;
+  }
+  void adjustZeroCopyFlags(folly::IOBuf* buf, folly::WriteFlags& flags);
+  void adjustZeroCopyFlags(
+      const iovec* vec,
+      uint32_t count,
+      folly::WriteFlags& flags);
+  void addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf);
+  void addZeroCopyBuff(folly::IOBuf* ptr);
+  void setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf);
+  bool containsZeroCopyBuff(folly::IOBuf* ptr);
+  void releaseZeroCopyBuff(uint32_t id);
+
+  // a folly::IOBuf can be used in multiple partial requests
+  // so we keep a map that maps a buffer id to a raw folly::IOBuf ptr
+  // and one more map that adds a ref count for a folly::IOBuf that is either
+  // the original ptr or nullptr
+  uint32_t zeroCopyBuffId_{0};
+  std::unordered_map<uint32_t, folly::IOBuf*> idZeroCopyBufPtrMap_;
+  std::unordered_map<
+      folly::IOBuf*,
+      std::pair<uint32_t, std::unique_ptr<folly::IOBuf>>>
+      idZeroCopyBufPtrToBufMap_;
+
   StateEnum state_;                      ///< StateEnum describing current state
   uint8_t shutdownFlags_;                ///< Shutdown state (ShutdownFlags)
   uint16_t eventFlags_;                  ///< EventBase::HandlerFlags settings
@@ -1149,8 +1190,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   ConnectCallback* connectCallback_;     ///< ConnectCallback
   ErrMessageCallback* errMessageCallback_; ///< TimestampCallback
-  SendMsgParamsCallback*                 ///< Callback for retreaving
-      sendMsgParamCallback_;             ///< ::sendmsg() parameters
+  SendMsgParamsCallback* ///< Callback for retrieving
+      sendMsgParamCallback_; ///< ::sendmsg() parameters
   ReadCallback* readCallback_;           ///< ReadCallback
   WriteRequest* writeReqHead_;           ///< Chain of WriteRequests
   WriteRequest* writeReqTail_;           ///< End of WriteRequest chain
@@ -1178,6 +1219,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   bool noTSocks_{false};
   // Whether to track EOR or not.
   bool trackEor_{false};
+  bool zeroCopyEnabled_{false};
+  bool zeroCopyVal_{false};
+  size_t zeroCopyWriteChainThreshold_{kDefaultZeroCopyThreshold};
 
   std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
 };
index ca5cc4d2b4b2ae9b1eb1be56e1c1c49cc038102d..c42888da211144eff5bbb1042a3b40e5cbc93074 100644 (file)
@@ -60,6 +60,10 @@ enum class WriteFlags : uint32_t {
    * this indicates that only the write side of socket should be shutdown
    */
   WRITE_SHUTDOWN = 0x04,
+  /*
+   * use msg zerocopy if allowed
+   */
+  WRITE_MSG_ZEROCOPY = 0x08,
 };
 
 /*
index b916c933591bc4f97125ba62adbcc8392ae1af00..122bb8f70ed92c5c9892577bfc61b0716e353870 100644 (file)
@@ -60,7 +60,7 @@ class SendMsgParamsCallbackBase :
 
   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
                                                                      override {
-    return oldCallback_->getFlags(flags);
+    return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
   }
 
   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
@@ -88,7 +88,7 @@ class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
     if (flags_) {
       return flags_;
     } else {
-      return oldCallback_->getFlags(flags);
+      return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
     }
   }
 
diff --git a/folly/io/async/test/ZeroCopyBenchmark.cpp b/folly/io/async/test/ZeroCopyBenchmark.cpp
new file mode 100644 (file)
index 0000000..7f397e0
--- /dev/null
@@ -0,0 +1,379 @@
+/*
+ * Copyright 2017 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/Benchmark.h>
+
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
+#include <folly/io/IOBufQueue.h>
+#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/EventBase.h>
+
+#include <folly/portability/GFlags.h>
+
+using namespace folly;
+
+class TestAsyncSocket {
+ public:
+  explicit TestAsyncSocket(
+      folly::EventBase* evb,
+      int numLoops,
+      size_t bufferSize,
+      bool zeroCopy)
+      : evb_(evb),
+        numLoops_(numLoops),
+        sock_(new folly::AsyncSocket(evb)),
+        callback_(this),
+        client_(true) {
+    setBufferSize(bufferSize);
+    setZeroCopy(zeroCopy);
+  }
+
+  explicit TestAsyncSocket(
+      folly::EventBase* evb,
+      int fd,
+      int numLoops,
+      size_t bufferSize,
+      bool zeroCopy)
+      : evb_(evb),
+        numLoops_(numLoops),
+        sock_(new folly::AsyncSocket(evb, fd)),
+        callback_(this),
+        client_(false) {
+    setBufferSize(bufferSize);
+    setZeroCopy(zeroCopy);
+    // enable reads
+    if (sock_) {
+      sock_->setReadCB(&callback_);
+    }
+  }
+
+  ~TestAsyncSocket() {
+    clearBuffers();
+  }
+
+  void connect(const folly::SocketAddress& remote) {
+    if (sock_) {
+      sock_->connect(&callback_, remote);
+    }
+  }
+
+ private:
+  void setZeroCopy(bool enable) {
+    zeroCopy_ = enable;
+    if (sock_) {
+      sock_->setZeroCopy(zeroCopy_);
+    }
+  }
+
+  void setBufferSize(size_t bufferSize) {
+    clearBuffers();
+    bufferSize_ = bufferSize;
+
+    readBuffer_ = new char[bufferSize_];
+  }
+
+  class Callback : public folly::AsyncSocket::ReadCallback,
+                   public folly::AsyncSocket::ConnectCallback {
+   public:
+    explicit Callback(TestAsyncSocket* parent) : parent_(parent) {}
+
+    void connectSuccess() noexcept override {
+      parent_->sock_->setReadCB(this);
+      parent_->onConnected();
+    }
+
+    void connectErr(const folly::AsyncSocketException& ex) noexcept override {
+      LOG(ERROR) << "Connect error: " << ex.what();
+      parent_->onDataFinish(folly::exception_wrapper(ex));
+    }
+
+    void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+      parent_->getReadBuffer(bufReturn, lenReturn);
+    }
+
+    void readDataAvailable(size_t len) noexcept override {
+      parent_->readDataAvailable(len);
+    }
+
+    void readEOF() noexcept override {
+      parent_->onDataFinish(folly::exception_wrapper());
+    }
+
+    void readErr(const folly::AsyncSocketException& ex) noexcept override {
+      parent_->onDataFinish(folly::exception_wrapper(ex));
+    }
+
+   private:
+    TestAsyncSocket* parent_{nullptr};
+  };
+
+  void clearBuffers() {
+    if (readBuffer_) {
+      delete[] readBuffer_;
+    }
+  }
+
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) {
+    *bufReturn = readBuffer_ + readOffset_;
+    *lenReturn = bufferSize_ - readOffset_;
+  }
+
+  void readDataAvailable(size_t len) noexcept {
+    readOffset_ += len;
+    if (readOffset_ == bufferSize_) {
+      readOffset_ = 0;
+      onDataReady();
+    }
+  }
+
+  void onConnected() {
+    setZeroCopy(zeroCopy_);
+    writeBuffer();
+  }
+
+  void onDataReady() {
+    currLoop_++;
+    if (client_ && currLoop_ >= numLoops_) {
+      evb_->terminateLoopSoon();
+      return;
+    }
+    writeBuffer();
+  }
+
+  void onDataFinish(folly::exception_wrapper) {
+    if (client_) {
+      evb_->terminateLoopSoon();
+    }
+  }
+
+  bool writeBuffer() {
+    writeBuffer_ =
+        folly::IOBuf::takeOwnership(::malloc(bufferSize_), bufferSize_);
+
+    if (sock_ && writeBuffer_) {
+      sock_->writeChain(
+          nullptr,
+          std::move(writeBuffer_),
+          zeroCopy_ ? WriteFlags::WRITE_MSG_ZEROCOPY : WriteFlags::NONE);
+    }
+
+    return true;
+  }
+
+  folly::EventBase* evb_;
+  int numLoops_{0};
+  int currLoop_{0};
+  bool zeroCopy_{false};
+
+  folly::AsyncSocket::UniquePtr sock_;
+  Callback callback_;
+
+  size_t bufferSize_{0};
+  size_t readOffset_{0};
+  char* readBuffer_{nullptr};
+  std::unique_ptr<folly::IOBuf> writeBuffer_;
+
+  bool client_;
+};
+
+class TestServer : public folly::AsyncServerSocket::AcceptCallback {
+ public:
+  explicit TestServer(
+      folly::EventBase* evb,
+      int numLoops,
+      size_t bufferSize,
+      bool zeroCopy)
+      : evb_(evb),
+        numLoops_(numLoops),
+        bufferSize_(bufferSize),
+        zeroCopy_(zeroCopy) {}
+
+  void addCallbackToServerSocket(folly::AsyncServerSocket& sock) {
+    sock.addAcceptCallback(this, evb_);
+  }
+
+  void connectionAccepted(
+      int fd,
+      const folly::SocketAddress& /* unused */) noexcept override {
+    auto client = std::make_shared<TestAsyncSocket>(
+        evb_, fd, numLoops_, bufferSize_, zeroCopy_);
+    clients_[client.get()] = client;
+  }
+
+  void acceptError(const std::exception&) noexcept override {}
+
+ private:
+  folly::EventBase* evb_;
+  int numLoops_;
+  size_t bufferSize_;
+  bool zeroCopy_;
+  std::unique_ptr<TestAsyncSocket> client_;
+  std::unordered_map<TestAsyncSocket*, std::shared_ptr<TestAsyncSocket>>
+      clients_;
+};
+
+class Test {
+ public:
+  explicit Test(int numLoops, bool zeroCopy, size_t bufferSize)
+      : numLoops_(numLoops),
+        zeroCopy_(zeroCopy),
+        bufferSize_(bufferSize),
+        client_(new TestAsyncSocket(&evb_, numLoops_, bufferSize_, zeroCopy)),
+        listenSock_(new folly::AsyncServerSocket(&evb_)),
+        server_(&evb_, numLoops_, bufferSize_, zeroCopy) {
+    if (listenSock_) {
+      server_.addCallbackToServerSocket(*listenSock_);
+    }
+  }
+
+  void run() {
+    evb_.runInEventBaseThread([this]() {
+
+      if (listenSock_) {
+        listenSock_->bind(0);
+        listenSock_->setZeroCopy(zeroCopy_);
+        listenSock_->listen(10);
+        listenSock_->startAccepting();
+
+        connectOne();
+      }
+    });
+
+    evb_.loopForever();
+  }
+
+ private:
+  void connectOne() {
+    SocketAddress addr = listenSock_->getAddress();
+    client_->connect(addr);
+  }
+
+  int numLoops_;
+  bool zeroCopy_;
+  size_t bufferSize_;
+
+  EventBase evb_;
+  std::unique_ptr<TestAsyncSocket> client_;
+  folly::AsyncServerSocket::UniquePtr listenSock_;
+  TestServer server_;
+};
+
+void runClient(
+    const std::string& host,
+    uint16_t port,
+    int numLoops,
+    bool zeroCopy,
+    size_t bufferSize) {
+  LOG(INFO) << "Running client. host = " << host << " port = " << port
+            << " numLoops = " << numLoops << " zeroCopy = " << zeroCopy
+            << " bufferSize = " << bufferSize;
+
+  EventBase evb;
+  std::unique_ptr<TestAsyncSocket> client(
+      new TestAsyncSocket(&evb, numLoops, bufferSize, zeroCopy));
+  SocketAddress addr(host, port);
+  evb.runInEventBaseThread([&]() { client->connect(addr); });
+
+  evb.loopForever();
+}
+
+void runServer(uint16_t port, int numLoops, bool zeroCopy, size_t bufferSize) {
+  LOG(INFO) << "Running server. port = " << port << " numLoops = " << numLoops
+            << " zeroCopy = " << zeroCopy << " bufferSize = " << bufferSize;
+
+  EventBase evb;
+  folly::AsyncServerSocket::UniquePtr listenSock(
+      new folly::AsyncServerSocket(&evb));
+  TestServer server(&evb, numLoops, bufferSize, zeroCopy);
+
+  server.addCallbackToServerSocket(*listenSock);
+
+  evb.runInEventBaseThread([&]() {
+    listenSock->bind(port);
+    listenSock->setZeroCopy(zeroCopy);
+    listenSock->listen(10);
+    listenSock->startAccepting();
+  });
+
+  evb.loopForever();
+}
+
+static auto constexpr kMaxLoops = 200000;
+
+void zeroCopyOn(unsigned /* unused */, size_t bufferSize) {
+  Test test(kMaxLoops, true, bufferSize);
+  test.run();
+}
+
+void zeroCopyOff(unsigned /* unused */, size_t bufferSize) {
+  Test test(kMaxLoops, false, bufferSize);
+  test.run();
+}
+
+BENCHMARK_PARAM(zeroCopyOn, 4096);
+BENCHMARK_PARAM(zeroCopyOff, 4096);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 8192);
+BENCHMARK_PARAM(zeroCopyOff, 8192);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 16384);
+BENCHMARK_PARAM(zeroCopyOff, 16384);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 32768);
+BENCHMARK_PARAM(zeroCopyOff, 32768);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 65536);
+BENCHMARK_PARAM(zeroCopyOff, 65536);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 131072);
+BENCHMARK_PARAM(zeroCopyOff, 131072);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 262144);
+BENCHMARK_PARAM(zeroCopyOff, 262144);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 524288);
+BENCHMARK_PARAM(zeroCopyOff, 524288);
+BENCHMARK_DRAW_LINE()
+BENCHMARK_PARAM(zeroCopyOn, 1048576);
+BENCHMARK_PARAM(zeroCopyOff, 1048576);
+BENCHMARK_DRAW_LINE()
+
+DEFINE_bool(client, false, "client mode");
+DEFINE_bool(server, false, "server mode");
+DEFINE_bool(zeroCopy, false, "use zerocopy");
+DEFINE_int32(numLoops, kMaxLoops, "number of loops");
+DEFINE_int32(bufferSize, 524288, "buffer size");
+DEFINE_int32(port, 33130, "port");
+DEFINE_string(host, "::1", "host");
+
+int main(int argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  if (FLAGS_client) {
+    runClient(
+        FLAGS_host,
+        FLAGS_port,
+        FLAGS_numLoops,
+        FLAGS_zeroCopy,
+        FLAGS_bufferSize);
+  } else if (FLAGS_server) {
+    runServer(FLAGS_port, FLAGS_numLoops, FLAGS_zeroCopy, FLAGS_bufferSize);
+  } else {
+    runBenchmarks();
+  }
+}
index f1617fc923dcc496f6834d84fcbbd685a454bdcb..99f1424dd7dc26b2569262edccb2429e9db91e27 100755 (executable)
 #include <netinet/tcp.h>
 #include <sys/socket.h>
 #include <sys/un.h>
+
+#ifdef MSG_ERRQUEUE
+/* for struct sock_extended_err*/
+#include <linux/errqueue.h>
+#endif
+
+#ifndef SO_EE_ORIGIN_ZEROCOPY
+#define SO_EE_ORIGIN_ZEROCOPY 5
+#endif
+
+#ifndef SO_ZEROCOPY
+#define SO_ZEROCOPY 60
+#endif
+
+#ifndef MSG_ZEROCOPY
+#define MSG_ZEROCOPY 0x4000000
+#endif
+
 #else
 #include <folly/portability/IOVec.h>
 #include <folly/portability/SysTypes.h>
 using nfds_t = int;
 using sa_family_t = ADDRESS_FAMILY;
 
+// these are not supported
+#define SO_EE_ORIGIN_ZEROCOPY 0
+#define SO_ZEROCOPY 0
+#define MSG_ZEROCOPY 0x0
+
 // We don't actually support either of these flags
 // currently.
 #define MSG_DONTWAIT 0x1000
@@ -198,5 +221,6 @@ int setsockopt(
 
 #ifdef _WIN32
 // Add our helpers to the overload set.
-/* using override */ using namespace folly::portability::sockets;
+/* using override */
+using namespace folly::portability::sockets;
 #endif