Return if we handle any error messages to avoid unnecessarily calling recv/send
[folly.git] / folly / io / async / AsyncSocket.cpp
index a4c41969e49a9278d95094bc4101abd9cdda935d..3065c1cd55f129c3508a9c21ba41349b8b44e586 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2017-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include <folly/io/async/AsyncSocket.h>
 
 #include <folly/ExceptionWrapper.h>
@@ -106,7 +105,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
       writeFlags |= WriteFlags::CORK;
     }
 
-    socket_->adjustZeroCopyFlags(getOps(), getOpCount(), writeFlags);
+    socket_->adjustZeroCopyFlags(writeFlags);
 
     auto writeResult = socket_->performWrite(
         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
@@ -272,13 +271,14 @@ AsyncSocket::AsyncSocket(EventBase* evb,
   connect(nullptr, ip, port, connectTimeout);
 }
 
-AsyncSocket::AsyncSocket(EventBase* evb, int fd)
-    : eventBase_(evb),
+AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId)
+    : zeroCopyBufId_(zeroCopyBufId),
+      eventBase_(evb),
       writeTimeout_(this, evb),
       ioHandler_(this, evb, fd),
       immediateReadHandler_(this) {
-  VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
-          << fd << ")";
+  VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" << fd
+          << ", zeroCopyBufId=" << zeroCopyBufId << ")";
   init();
   fd_ = fd;
   setCloseOnExec();
@@ -286,7 +286,10 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd)
 }
 
 AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
-    : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
+    : AsyncSocket(
+          oldAsyncSocket->getEventBase(),
+          oldAsyncSocket->detachFd(),
+          oldAsyncSocket->getZeroCopyBufId()) {
   preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
 }
 
@@ -658,6 +661,22 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
           << ", fd=" << fd_ << ", callback=" << callback
           << ", state=" << state_;
 
+  // In the latest stable kernel 4.14.3 as of 2017-12-04, unix domain
+  // socket does not support MSG_ERRQUEUE. So recvmsg(MSG_ERRQUEUE)
+  // will read application data from unix doamin socket as error
+  // message, which breaks the message flow in application.  Feel free
+  // to remove the next code block if MSG_ERRQUEUE is added for unix
+  // domain socket in the future.
+  if (callback != nullptr) {
+    cacheLocalAddress();
+    if (localAddr_.getFamily() == AF_UNIX) {
+      LOG(ERROR) << "Failed to set ErrMessageCallback=" << callback
+                 << " for Unix Doamin Socket where MSG_ERRQUEUE is unsupported,"
+                 << " fd=" << fd_;
+      return;
+    }
+  }
+
   // Short circuit if callback is the same as the existing errMessageCallback_.
   if (callback == errMessageCallback_) {
     return;
@@ -850,49 +869,18 @@ bool AsyncSocket::setZeroCopy(bool enable) {
   return false;
 }
 
-void AsyncSocket::setZeroCopyWriteChainThreshold(size_t threshold) {
-  zeroCopyWriteChainThreshold_ = threshold;
-}
-
 bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) {
   return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY));
 }
 
-void AsyncSocket::adjustZeroCopyFlags(
-    folly::IOBuf* buf,
-    folly::WriteFlags& flags) {
-  if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_ && buf &&
-      buf->isManaged()) {
-    if (buf->computeChainDataLength() >= zeroCopyWriteChainThreshold_) {
-      flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
-    } else {
-      flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
-    }
-  }
-}
-
-void AsyncSocket::adjustZeroCopyFlags(
-    const iovec* vec,
-    uint32_t count,
-    folly::WriteFlags& flags) {
-  if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_) {
-    count = std::min<uint32_t>(count, kIovMax);
-    size_t sum = 0;
-    for (uint32_t i = 0; i < count; ++i) {
-      const iovec* v = vec + i;
-      sum += v->iov_len;
-    }
-
-    if (sum >= zeroCopyWriteChainThreshold_) {
-      flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY;
-    } else {
-      flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
-    }
+void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) {
+  if (!zeroCopyEnabled_) {
+    flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY);
   }
 }
 
 void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
-  uint32_t id = getNextZeroCopyBuffId();
+  uint32_t id = getNextZeroCopyBufId();
   folly::IOBuf* ptr = buf.get();
 
   idZeroCopyBufPtrMap_[id] = ptr;
@@ -903,7 +891,7 @@ void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
 }
 
 void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) {
-  uint32_t id = getNextZeroCopyBuffId();
+  uint32_t id = getNextZeroCopyBufId();
   idZeroCopyBufPtrMap_[id] = ptr;
 
   idZeroCopyBufInfoMap_[ptr].count_++;
@@ -918,6 +906,8 @@ void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
   if (0 == --iter1->second.count_) {
     idZeroCopyBufInfoMap_.erase(iter1);
   }
+
+  idZeroCopyBufPtrMap_.erase(iter);
 }
 
 void AsyncSocket::setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
@@ -983,7 +973,7 @@ void AsyncSocket::writev(WriteCallback* callback,
 
 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
                               WriteFlags flags) {
-  adjustZeroCopyFlags(buf.get(), flags);
+  adjustZeroCopyFlags(flags);
 
   constexpr size_t kSmallSizeMax = 64;
   size_t count = buf->countChainElements();
@@ -1057,7 +1047,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         return failWrite(__func__, callback, 0, ex);
       } else if (countWritten == count) {
         // done, add the whole buffer
-        if (isZeroCopyRequest(flags)) {
+        if (countWritten && isZeroCopyRequest(flags)) {
           addZeroCopyBuf(std::move(ioBuf));
         }
         // We successfully wrote everything.
@@ -1068,7 +1058,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         return;
       } else { // continue writing the next writeReq
         // add just the ptr
-        if (isZeroCopyRequest(flags)) {
+        if (bytesWritten && isZeroCopyRequest(flags)) {
           addZeroCopyBuf(ioBuf.get());
         }
         if (bufferCallback_) {
@@ -1461,6 +1451,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
 
   eventBase_ = eventBase;
   ioHandler_.attachEventBase(eventBase);
+
+  updateEventRegistration();
+
   writeTimeout_.attachEventBase(eventBase);
   if (evbChangeCb_) {
     evbChangeCb_->evbAttached(this);
@@ -1475,6 +1468,9 @@ void AsyncSocket::detachEventBase() {
   eventBase_->dcheckIsInEventBaseThread();
 
   eventBase_ = nullptr;
+
+  ioHandler_.unregisterHandler();
+
   ioHandler_.detachEventBase();
   writeTimeout_.detachEventBase();
   if (evbChangeCb_) {
@@ -1486,7 +1482,7 @@ bool AsyncSocket::isDetachable() const {
   DCHECK(eventBase_ != nullptr);
   eventBase_->dcheckIsInEventBaseThread();
 
-  return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
+  return !writeTimeout_.isScheduled();
 }
 
 void AsyncSocket::cacheAddresses() {
@@ -1675,7 +1671,11 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
   // EventHandler::WRITE is set. Any of these flags can
   // indicate that there are messages available in the socket
   // error message queue.
-  handleErrMessages();
+  // Return if we handle any error messages - this is to avoid
+  // unnecessary read/write calls
+  if (handleErrMessages()) {
+    return;
+  }
 
   // Return now if handleErrMessages() detached us from our EventBase
   if (eventBase_ != originalEventBase) {
@@ -1749,16 +1749,15 @@ void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
   readCallback_->getReadBuffer(buf, buflen);
 }
 
-void AsyncSocket::handleErrMessages() noexcept {
+size_t AsyncSocket::handleErrMessages() noexcept {
   // This method has non-empty implementation only for platforms
   // supporting per-socket error queues.
   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
           << ", state=" << state_;
-  if (errMessageCallback_ == nullptr &&
-      (!zeroCopyEnabled_ || idZeroCopyBufPtrMap_.empty())) {
+  if (errMessageCallback_ == nullptr && idZeroCopyBufPtrMap_.empty()) {
     VLOG(7) << "AsyncSocket::handleErrMessages(): "
             << "no callback installed - exiting.";
-    return;
+    return 0;
   }
 
 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
@@ -1778,6 +1777,7 @@ void AsyncSocket::handleErrMessages() noexcept {
   msg.msg_flags = 0;
 
   int ret;
+  size_t num = 0;
   while (true) {
     ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
     VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
@@ -1793,12 +1793,14 @@ void AsyncSocket::handleErrMessages() noexcept {
           errnoCopy);
         failErrMessageRead(__func__, ex);
       }
-      return;
+
+      return num;
     }
 
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
          cmsg != nullptr && cmsg->cmsg_len != 0;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+      ++num;
       if (isZeroCopyMsg(*cmsg)) {
         processZeroCopyMsg(*cmsg);
       } else {
@@ -1808,9 +1810,22 @@ void AsyncSocket::handleErrMessages() noexcept {
       }
     }
   }
+#else
+  return 0;
 #endif // FOLLY_HAVE_MSG_ERRQUEUE
 }
 
+bool AsyncSocket::processZeroCopyWriteInProgress() noexcept {
+  eventBase_->dcheckIsInEventBaseThread();
+  if (idZeroCopyBufPtrMap_.empty()) {
+    return true;
+  }
+
+  handleErrMessages();
+
+  return idZeroCopyBufPtrMap_.empty();
+}
+
 void AsyncSocket::handleRead() noexcept {
   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
           << ", state=" << state_;
@@ -2354,6 +2369,14 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // this bug is fixed.
     tryAgain |= (errno == ENOTCONN);
 #endif
+
+    // workaround for running with zerocopy enabled but without a proper
+    // memlock value - see ulimit -l
+    if (zeroCopyEnabled_ && (errno == ENOBUFS)) {
+      tryAgain = true;
+      zeroCopyEnabled_ = false;
+    }
+
     if (!writeResult.exception && tryAgain) {
       // TCP buffer is full; we can't write any more data right now.
       *countWritten = 0;