Return if we handle any error messages to avoid unnecessarily calling recv/send
[folly.git] / folly / io / async / AsyncSocket.cpp
index dc550a310e61ca9f8fdecaa7b69cdf72b61f0a6d..3065c1cd55f129c3508a9c21ba41349b8b44e586 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2017-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include <folly/io/async/AsyncSocket.h>
 
 #include <folly/ExceptionWrapper.h>
+#include <folly/Format.h>
 #include <folly/Portability.h>
 #include <folly/SocketAddress.h>
 #include <folly/io/Cursor.h>
@@ -41,11 +41,11 @@ namespace fsp = folly::portability::sockets;
 namespace folly {
 
 static constexpr bool msgErrQueueSupported =
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
     true;
 #else
     false;
-#endif // MSG_ERRQUEUE
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
 
 // static members initializers
 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
@@ -104,9 +104,28 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     if (getNext() != nullptr) {
       writeFlags |= WriteFlags::CORK;
     }
+
+    socket_->adjustZeroCopyFlags(writeFlags);
+
     auto writeResult = socket_->performWrite(
         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
     bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+    if (bytesWritten_) {
+      if (socket_->isZeroCopyRequest(writeFlags)) {
+        if (isComplete()) {
+          socket_->addZeroCopyBuf(std::move(ioBuf_));
+        } else {
+          socket_->addZeroCopyBuf(ioBuf_.get());
+        }
+      } else {
+        // this happens if at least one of the prev requests were sent
+        // with zero copy but not the last one
+        if (isComplete() && socket_->getZeroCopy() &&
+            socket_->containsZeroCopyBuf(ioBuf_.get())) {
+          socket_->setZeroCopyBuf(std::move(ioBuf_));
+        }
+      }
+    }
     return writeResult;
   }
 
@@ -119,11 +138,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     opIndex_ += opsWritten_;
     assert(opIndex_ < opCount_);
 
-    // If we've finished writing any IOBufs, release them
-    if (ioBuf_) {
-      for (uint32_t i = opsWritten_; i != 0; --i) {
-        assert(ioBuf_);
-        ioBuf_ = ioBuf_->pop();
+    if (!socket_->isZeroCopyRequest(flags_)) {
+      // If we've finished writing any IOBufs, release them
+      if (ioBuf_) {
+        for (uint32_t i = opsWritten_; i != 0; --i) {
+          assert(ioBuf_);
+          ioBuf_ = ioBuf_->pop();
+        }
       }
     }
 
@@ -185,8 +206,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
   struct iovec writeOps_[];     ///< write operation(s) list
 };
 
-int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
-                                                                      noexcept {
+int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
+    folly::WriteFlags flags,
+    bool zeroCopyEnabled) noexcept {
   int msg_flags = MSG_DONTWAIT;
 
 #ifdef MSG_NOSIGNAL // Linux-only
@@ -205,12 +227,16 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
     msg_flags |= MSG_EOR;
   }
 
+  if (zeroCopyEnabled && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)) {
+    msg_flags |= MSG_ZEROCOPY;
+  }
+
   return msg_flags;
 }
 
 namespace {
 static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
-}
+} // namespace
 
 AsyncSocket::AsyncSocket()
     : eventBase_(nullptr),
@@ -245,13 +271,14 @@ AsyncSocket::AsyncSocket(EventBase* evb,
   connect(nullptr, ip, port, connectTimeout);
 }
 
-AsyncSocket::AsyncSocket(EventBase* evb, int fd)
-    : eventBase_(evb),
+AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId)
+    : zeroCopyBufId_(zeroCopyBufId),
+      eventBase_(evb),
       writeTimeout_(this, evb),
       ioHandler_(this, evb, fd),
       immediateReadHandler_(this) {
-  VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
-          << fd << ")";
+  VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" << fd
+          << ", zeroCopyBufId=" << zeroCopyBufId << ")";
   init();
   fd_ = fd;
   setCloseOnExec();
@@ -259,14 +286,19 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd)
 }
 
 AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
-    : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
+    : AsyncSocket(
+          oldAsyncSocket->getEventBase(),
+          oldAsyncSocket->detachFd(),
+          oldAsyncSocket->getZeroCopyBufId()) {
   preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
 }
 
 // init() method, since constructor forwarding isn't supported in most
 // compilers yet.
 void AsyncSocket::init() {
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
   shutdownFlags_ = 0;
   state_ = StateEnum::UNINIT;
   eventFlags_ = EventHandler::NONE;
@@ -278,7 +310,7 @@ void AsyncSocket::init() {
   readCallback_ = nullptr;
   writeReqHead_ = nullptr;
   writeReqTail_ = nullptr;
-  shutdownSocketSet_ = nullptr;
+  wShutdownSocketSet_.reset();
   appBytesWritten_ = 0;
   appBytesReceived_ = 0;
   sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
@@ -307,8 +339,8 @@ int AsyncSocket::detachFd() {
           << ", events=" << std::hex << eventFlags_ << ")";
   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
   // actually close the descriptor.
-  if (shutdownSocketSet_) {
-    shutdownSocketSet_->remove(fd_);
+  if (const auto socketSet = wShutdownSocketSet_.lock()) {
+    socketSet->remove(fd_);
   }
   int fd = fd_;
   fd_ = -1;
@@ -326,17 +358,24 @@ const folly::SocketAddress& AsyncSocket::anyAddress() {
   return anyAddress;
 }
 
-void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
-  if (shutdownSocketSet_ == newSS) {
+void AsyncSocket::setShutdownSocketSet(
+    const std::weak_ptr<ShutdownSocketSet>& wNewSS) {
+  const auto newSS = wNewSS.lock();
+  const auto shutdownSocketSet = wShutdownSocketSet_.lock();
+
+  if (newSS == shutdownSocketSet) {
     return;
   }
-  if (shutdownSocketSet_ && fd_ != -1) {
-    shutdownSocketSet_->remove(fd_);
+
+  if (shutdownSocketSet && fd_ != -1) {
+    shutdownSocketSet->remove(fd_);
   }
-  shutdownSocketSet_ = newSS;
-  if (shutdownSocketSet_ && fd_ != -1) {
-    shutdownSocketSet_->add(fd_);
+
+  if (newSS && fd_ != -1) {
+    newSS->add(fd_);
   }
+
+  wShutdownSocketSet_ = wNewSS;
 }
 
 void AsyncSocket::setCloseOnExec() {
@@ -356,7 +395,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
                            const OptionMap &options,
                            const folly::SocketAddress& bindAddr) noexcept {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   addr_ = address;
 
@@ -391,8 +430,8 @@ void AsyncSocket::connect(ConnectCallback* callback,
           withAddr("failed to create socket"),
           errnoCopy);
     }
-    if (shutdownSocketSet_) {
-      shutdownSocketSet_->add(fd_);
+    if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
+      shutdownSocketSet->add(fd_);
     }
     ioHandler_.changeHandlerFD(fd_);
 
@@ -431,8 +470,11 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // By default, turn on TCP_NODELAY
     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
     // setNoDelay() will log an error message if it fails.
+    // Also set the cached zeroCopyVal_ since it cannot be set earlier if the fd
+    // is not created
     if (address.getFamily() != AF_UNIX) {
       (void)setNoDelay(true);
+      setZeroCopy(zeroCopyVal_);
     }
 
     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
@@ -520,6 +562,11 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
     // Ignore return value, errors are ok
     setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
   }
+  if (noTSocks_) {
+    VLOG(4) << "Disabling TSOCKS for fd " << fd_;
+    // Ignore return value, errors are ok
+    setsockopt(fd_, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
+  }
 #endif
   int rv = fsp::connect(fd_, saddr, len);
   if (rv < 0) {
@@ -587,7 +634,9 @@ void AsyncSocket::cancelConnect() {
 
 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
   sendTimeout_ = milliseconds;
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
 
   // If we are currently pending on write requests, immediately update
   // writeTimeout_ with the new value.
@@ -612,6 +661,22 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
           << ", fd=" << fd_ << ", callback=" << callback
           << ", state=" << state_;
 
+  // In the latest stable kernel 4.14.3 as of 2017-12-04, unix domain
+  // socket does not support MSG_ERRQUEUE. So recvmsg(MSG_ERRQUEUE)
+  // will read application data from unix doamin socket as error
+  // message, which breaks the message flow in application.  Feel free
+  // to remove the next code block if MSG_ERRQUEUE is added for unix
+  // domain socket in the future.
+  if (callback != nullptr) {
+    cacheLocalAddress();
+    if (localAddr_.getFamily() == AF_UNIX) {
+      LOG(ERROR) << "Failed to set ErrMessageCallback=" << callback
+                 << " for Unix Doamin Socket where MSG_ERRQUEUE is unsupported,"
+                 << " fd=" << fd_;
+      return;
+    }
+  }
+
   // Short circuit if callback is the same as the existing errMessageCallback_.
   if (callback == errMessageCallback_) {
     return;
@@ -623,7 +688,7 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
   }
 
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   if (callback == nullptr) {
     // We should be able to reset the callback regardless of the
@@ -709,7 +774,7 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
   }
 
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   switch ((StateEnum)state_) {
     case StateEnum::CONNECTING:
@@ -763,6 +828,134 @@ AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
   return readCallback_;
 }
 
+bool AsyncSocket::setZeroCopy(bool enable) {
+  if (msgErrQueueSupported) {
+    zeroCopyVal_ = enable;
+
+    if (fd_ < 0) {
+      return false;
+    }
+
+    int val = enable ? 1 : 0;
+    int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
+
+    // if enable == false, set zeroCopyEnabled_ = false regardless
+    // if SO_ZEROCOPY is set or not
+    if (!enable) {
+      zeroCopyEnabled_ = enable;
+      return true;
+    }
+
+    /* if the setsockopt failed, try to see if the socket inherited the flag
+     * since we cannot set SO_ZEROCOPY on a socket s = accept
+     */
+    if (ret) {
+      val = 0;
+      socklen_t optlen = sizeof(val);
+      ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
+
+      if (!ret) {
+        enable = val ? true : false;
+      }
+    }
+
+    if (!ret) {
+      zeroCopyEnabled_ = enable;
+
+      return true;
+    }
+  }
+
+  return false;
+}
+
+bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
+  return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
+}
+
+void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) {
+  if (!zeroCopyEnabled_) {
+    flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
+  }
+}
+
+void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
+  uint32_t id = getNextZeroCopyBufId();
+  folly::IOBuf* ptr = buf.get();
+
+  idZeroCopyBufPtrMap_[id] = ptr;
+  auto& p = idZeroCopyBufInfoMap_[ptr];
+  p.count_++;
+  CHECK(p.buf_.get() == nullptr);
+  p.buf_ = std::move(buf);
+}
+
+void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) {
+  uint32_t id = getNextZeroCopyBufId();
+  idZeroCopyBufPtrMap_[id] = ptr;
+
+  idZeroCopyBufInfoMap_[ptr].count_++;
+}
+
+void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
+  auto iter = idZeroCopyBufPtrMap_.find(id);
+  CHECK(iter != idZeroCopyBufPtrMap_.end());
+  auto ptr = iter->second;
+  auto iter1 = idZeroCopyBufInfoMap_.find(ptr);
+  CHECK(iter1 != idZeroCopyBufInfoMap_.end());
+  if (0 == --iter1->second.count_) {
+    idZeroCopyBufInfoMap_.erase(iter1);
+  }
+
+  idZeroCopyBufPtrMap_.erase(iter);
+}
+
+void AsyncSocket::setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
+  folly::IOBuf* ptr = buf.get();
+  auto& p = idZeroCopyBufInfoMap_[ptr];
+  CHECK(p.buf_.get() == nullptr);
+
+  p.buf_ = std::move(buf);
+}
+
+bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) {
+  return (idZeroCopyBufInfoMap_.find(ptr) != idZeroCopyBufInfoMap_.end());
+}
+
+bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
+  if (zeroCopyEnabled_ &&
+      ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
+       (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) {
+    const struct sock_extended_err* serr =
+        reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+    return (
+        (serr->ee_errno == 0) && (serr->ee_origin == SO_EE_ORIGIN_ZEROCOPY));
+  }
+#endif
+  return false;
+}
+
+void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
+  const struct sock_extended_err* serr =
+      reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
+  uint32_t hi = serr->ee_data;
+  uint32_t lo = serr->ee_info;
+  // disable zero copy if the buffer was actually copied
+  if ((serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED) && zeroCopyEnabled_) {
+    VLOG(2) << "AsyncSocket::processZeroCopyMsg(): setting "
+            << "zeroCopyEnabled_ = false due to SO_EE_CODE_ZEROCOPY_COPIED "
+            << "on " << fd_;
+    zeroCopyEnabled_ = false;
+  }
+
+  for (uint32_t i = lo; i <= hi; i++) {
+    releaseZeroCopyBuf(i);
+  }
+#endif
+}
+
 void AsyncSocket::write(WriteCallback* callback,
                          const void* buf, size_t bytes, WriteFlags flags) {
   iovec op;
@@ -780,6 +973,8 @@ void AsyncSocket::writev(WriteCallback* callback,
 
 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
                               WriteFlags flags) {
+  adjustZeroCopyFlags(flags);
+
   constexpr size_t kSmallSizeMax = 64;
   size_t count = buf->countChainElements();
   if (count <= kSmallSizeMax) {
@@ -811,7 +1006,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
           << ", state=" << state_;
   DestructorGuard dg(this);
   unique_ptr<IOBuf>ioBuf(std::move(buf));
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
     // No new writes may be performed after the write side of the socket has
@@ -851,6 +1046,10 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
             errnoCopy);
         return failWrite(__func__, callback, 0, ex);
       } else if (countWritten == count) {
+        // done, add the whole buffer
+        if (countWritten && isZeroCopyRequest(flags)) {
+          addZeroCopyBuf(std::move(ioBuf));
+        }
         // We successfully wrote everything.
         // Invoke the callback and return.
         if (callback) {
@@ -858,6 +1057,10 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         }
         return;
       } else { // continue writing the next writeReq
+        // add just the ptr
+        if (bytesWritten && isZeroCopyRequest(flags)) {
+          addZeroCopyBuf(ioBuf.get());
+        }
         if (bufferCallback_) {
           bufferCallback_->onEgressBuffered();
         }
@@ -961,7 +1164,7 @@ void AsyncSocket::close() {
   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
   // destroyed until close() returns.
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // Since there are write requests pending, we have to set the
   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
@@ -993,7 +1196,9 @@ void AsyncSocket::closeNow() {
           << ", state=" << state_ << ", shutdownFlags="
           << std::hex << (int) shutdownFlags_;
   DestructorGuard dg(this);
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
 
   switch (state_) {
     case StateEnum::ESTABLISHED:
@@ -1085,7 +1290,7 @@ void AsyncSocket::shutdownWrite() {
     return;
   }
 
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
   // shutdown will be performed once all writes complete.
@@ -1112,7 +1317,9 @@ void AsyncSocket::shutdownWriteNow() {
   }
 
   DestructorGuard dg(this);
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
 
   switch (static_cast<StateEnum>(state_)) {
     case StateEnum::ESTABLISHED:
@@ -1189,6 +1396,18 @@ bool AsyncSocket::readable() const {
   return rc == 1;
 }
 
+bool AsyncSocket::writable() const {
+  if (fd_ == -1) {
+    return false;
+  }
+  struct pollfd fds[1];
+  fds[0].fd = fd_;
+  fds[0].events = POLLOUT;
+  fds[0].revents = 0;
+  int rc = poll(fds, 1, 0);
+  return rc == 1;
+}
+
 bool AsyncSocket::isPending() const {
   return ioHandler_.isPending();
 }
@@ -1228,10 +1447,13 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
           << ", state=" << state_ << ", events="
           << std::hex << eventFlags_ << ")";
   assert(eventBase_ == nullptr);
-  assert(eventBase->isInEventBaseThread());
+  eventBase->dcheckIsInEventBaseThread();
 
   eventBase_ = eventBase;
   ioHandler_.attachEventBase(eventBase);
+
+  updateEventRegistration();
+
   writeTimeout_.attachEventBase(eventBase);
   if (evbChangeCb_) {
     evbChangeCb_->evbAttached(this);
@@ -1243,9 +1465,12 @@ void AsyncSocket::detachEventBase() {
           << ", old evb=" << eventBase_ << ", state=" << state_
           << ", events=" << std::hex << eventFlags_ << ")";
   assert(eventBase_ != nullptr);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   eventBase_ = nullptr;
+
+  ioHandler_.unregisterHandler();
+
   ioHandler_.detachEventBase();
   writeTimeout_.detachEventBase();
   if (evbChangeCb_) {
@@ -1255,22 +1480,49 @@ void AsyncSocket::detachEventBase() {
 
 bool AsyncSocket::isDetachable() const {
   DCHECK(eventBase_ != nullptr);
-  DCHECK(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
-  return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
+  return !writeTimeout_.isScheduled();
 }
 
-void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+void AsyncSocket::cacheAddresses() {
+  if (fd_ >= 0) {
+    try {
+      cacheLocalAddress();
+      cachePeerAddress();
+    } catch (const std::system_error& e) {
+      if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
+        VLOG(1) << "Error caching addresses: " << e.code().value() << ", "
+                << e.code().message();
+      }
+    }
+  }
+}
+
+void AsyncSocket::cacheLocalAddress() const {
   if (!localAddr_.isInitialized()) {
     localAddr_.setFromLocalAddress(fd_);
   }
-  *address = localAddr_;
 }
 
-void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+void AsyncSocket::cachePeerAddress() const {
   if (!addr_.isInitialized()) {
     addr_.setFromPeerAddress(fd_);
   }
+}
+
+bool AsyncSocket::isZeroCopyWriteInProgress() const noexcept {
+  eventBase_->dcheckIsInEventBaseThread();
+  return (!idZeroCopyBufPtrMap_.empty());
+}
+
+void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+  cacheLocalAddress();
+  *address = localAddr_;
+}
+
+void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+  cachePeerAddress();
   *address = addr_;
 }
 
@@ -1411,7 +1663,7 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
           << ", events=" << std::hex << events << ", state=" << state_;
   DestructorGuard dg(this);
   assert(events & EventHandler::READ_WRITE);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
   EventBase* originalEventBase = eventBase_;
@@ -1419,7 +1671,11 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
   // EventHandler::WRITE is set. Any of these flags can
   // indicate that there are messages available in the socket
   // error message queue.
-  handleErrMessages();
+  // Return if we handle any error messages - this is to avoid
+  // unnecessary read/write calls
+  if (handleErrMessages()) {
+    return;
+  }
 
   // Return now if handleErrMessages() detached us from our EventBase
   if (eventBase_ != originalEventBase) {
@@ -1493,18 +1749,18 @@ void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
   readCallback_->getReadBuffer(buf, buflen);
 }
 
-void AsyncSocket::handleErrMessages() noexcept {
+size_t AsyncSocket::handleErrMessages() noexcept {
   // This method has non-empty implementation only for platforms
   // supporting per-socket error queues.
   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
           << ", state=" << state_;
-  if (errMessageCallback_ == nullptr) {
+  if (errMessageCallback_ == nullptr && idZeroCopyBufPtrMap_.empty()) {
     VLOG(7) << "AsyncSocket::handleErrMessages(): "
             << "no callback installed - exiting.";
-    return;
+    return 0;
   }
 
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
   uint8_t ctrl[1024];
   unsigned char data;
   struct msghdr msg;
@@ -1521,6 +1777,7 @@ void AsyncSocket::handleErrMessages() noexcept {
   msg.msg_flags = 0;
 
   int ret;
+  size_t num = 0;
   while (true) {
     ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
     VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
@@ -1536,16 +1793,37 @@ void AsyncSocket::handleErrMessages() noexcept {
           errnoCopy);
         failErrMessageRead(__func__, ex);
       }
-      return;
+
+      return num;
     }
 
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
          cmsg != nullptr && cmsg->cmsg_len != 0;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
-      errMessageCallback_->errMessage(*cmsg);
+      ++num;
+      if (isZeroCopyMsg(*cmsg)) {
+        processZeroCopyMsg(*cmsg);
+      } else {
+        if (errMessageCallback_) {
+          errMessageCallback_->errMessage(*cmsg);
+        }
+      }
     }
   }
-#endif //MSG_ERRQUEUE
+#else
+  return 0;
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
+}
+
+bool AsyncSocket::processZeroCopyWriteInProgress() noexcept {
+  eventBase_->dcheckIsInEventBaseThread();
+  if (idZeroCopyBufPtrMap_.empty()) {
+    return true;
+  }
+
+  handleErrMessages();
+
+  return idZeroCopyBufPtrMap_.empty();
 }
 
 void AsyncSocket::handleRead() noexcept {
@@ -1954,7 +2232,7 @@ void AsyncSocket::timeoutExpired() noexcept {
   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   if (state_ == StateEnum::CONNECTING) {
     // connect() timed out
@@ -2078,7 +2356,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
   } else {
     msg.msg_control = nullptr;
   }
-  int msg_flags = sendMsgParamCallback_->getFlags(flags);
+  int msg_flags = sendMsgParamCallback_->getFlags(flags, zeroCopyEnabled_);
 
   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
   auto totalWritten = writeResult.writeReturn;
@@ -2091,6 +2369,14 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // this bug is fixed.
     tryAgain |= (errno == ENOTCONN);
 #endif
+
+    // workaround for running with zerocopy enabled but without a proper
+    // memlock value - see ulimit -l
+    if (zeroCopyEnabled_ && (errno == ENOBUFS)) {
+      tryAgain = true;
+      zeroCopyEnabled_ = false;
+    }
+
     if (!writeResult.exception && tryAgain) {
       // TCP buffer is full; we can't write any more data right now.
       *countWritten = 0;
@@ -2132,13 +2418,13 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
  * and call all currently installed callbacks.  After an error, the
  * AsyncSocket is completely unregistered.
  *
- * @return Returns true on succcess, or false on error.
+ * @return Returns true on success, or false on error.
  */
 bool AsyncSocket::updateEventRegistration() {
   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
           << ", events=" << std::hex << eventFlags_;
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
   if (eventFlags_ == EventHandler::NONE) {
     ioHandler_.unregisterHandler();
     return true;
@@ -2440,9 +2726,11 @@ void AsyncSocket::invalidState(WriteCallback* callback) {
 }
 
 void AsyncSocket::doClose() {
-  if (fd_ == -1) return;
-  if (shutdownSocketSet_) {
-    shutdownSocketSet_->close(fd_);
+  if (fd_ == -1) {
+    return;
+  }
+  if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
+    shutdownSocketSet->close(fd_);
   } else {
     ::close(fd_);
   }
@@ -2474,4 +2762,4 @@ void AsyncSocket::setBufferCallback(BufferCallback* cb) {
   bufferCallback_ = cb;
 }
 
-} // folly
+} // namespace folly