Return if we handle any error messages to avoid unnecessarily calling recv/send
[folly.git] / folly / io / async / AsyncSocket.cpp
index 1dbaa220ceca5987833c8b114235eec30cce21a1..3065c1cd55f129c3508a9c21ba41349b8b44e586 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2017-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include <folly/io/async/AsyncSocket.h>
 
 #include <folly/ExceptionWrapper.h>
@@ -662,6 +661,22 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
           << ", fd=" << fd_ << ", callback=" << callback
           << ", state=" << state_;
 
+  // In the latest stable kernel 4.14.3 as of 2017-12-04, unix domain
+  // socket does not support MSG_ERRQUEUE. So recvmsg(MSG_ERRQUEUE)
+  // will read application data from unix doamin socket as error
+  // message, which breaks the message flow in application.  Feel free
+  // to remove the next code block if MSG_ERRQUEUE is added for unix
+  // domain socket in the future.
+  if (callback != nullptr) {
+    cacheLocalAddress();
+    if (localAddr_.getFamily() == AF_UNIX) {
+      LOG(ERROR) << "Failed to set ErrMessageCallback=" << callback
+                 << " for Unix Doamin Socket where MSG_ERRQUEUE is unsupported,"
+                 << " fd=" << fd_;
+      return;
+    }
+  }
+
   // Short circuit if callback is the same as the existing errMessageCallback_.
   if (callback == errMessageCallback_) {
     return;
@@ -891,6 +906,8 @@ void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
   if (0 == --iter1->second.count_) {
     idZeroCopyBufInfoMap_.erase(iter1);
   }
+
+  idZeroCopyBufPtrMap_.erase(iter);
 }
 
 void AsyncSocket::setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
@@ -1030,7 +1047,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         return failWrite(__func__, callback, 0, ex);
       } else if (countWritten == count) {
         // done, add the whole buffer
-        if (isZeroCopyRequest(flags)) {
+        if (countWritten && isZeroCopyRequest(flags)) {
           addZeroCopyBuf(std::move(ioBuf));
         }
         // We successfully wrote everything.
@@ -1041,7 +1058,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
         return;
       } else { // continue writing the next writeReq
         // add just the ptr
-        if (isZeroCopyRequest(flags)) {
+        if (bytesWritten && isZeroCopyRequest(flags)) {
           addZeroCopyBuf(ioBuf.get());
         }
         if (bufferCallback_) {
@@ -1654,7 +1671,11 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
   // EventHandler::WRITE is set. Any of these flags can
   // indicate that there are messages available in the socket
   // error message queue.
-  handleErrMessages();
+  // Return if we handle any error messages - this is to avoid
+  // unnecessary read/write calls
+  if (handleErrMessages()) {
+    return;
+  }
 
   // Return now if handleErrMessages() detached us from our EventBase
   if (eventBase_ != originalEventBase) {
@@ -1728,7 +1749,7 @@ void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
   readCallback_->getReadBuffer(buf, buflen);
 }
 
-void AsyncSocket::handleErrMessages() noexcept {
+size_t AsyncSocket::handleErrMessages() noexcept {
   // This method has non-empty implementation only for platforms
   // supporting per-socket error queues.
   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
@@ -1736,7 +1757,7 @@ void AsyncSocket::handleErrMessages() noexcept {
   if (errMessageCallback_ == nullptr && idZeroCopyBufPtrMap_.empty()) {
     VLOG(7) << "AsyncSocket::handleErrMessages(): "
             << "no callback installed - exiting.";
-    return;
+    return 0;
   }
 
 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
@@ -1756,6 +1777,7 @@ void AsyncSocket::handleErrMessages() noexcept {
   msg.msg_flags = 0;
 
   int ret;
+  size_t num = 0;
   while (true) {
     ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
     VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
@@ -1771,12 +1793,14 @@ void AsyncSocket::handleErrMessages() noexcept {
           errnoCopy);
         failErrMessageRead(__func__, ex);
       }
-      return;
+
+      return num;
     }
 
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
          cmsg != nullptr && cmsg->cmsg_len != 0;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+      ++num;
       if (isZeroCopyMsg(*cmsg)) {
         processZeroCopyMsg(*cmsg);
       } else {
@@ -1786,9 +1810,22 @@ void AsyncSocket::handleErrMessages() noexcept {
       }
     }
   }
+#else
+  return 0;
 #endif // FOLLY_HAVE_MSG_ERRQUEUE
 }
 
+bool AsyncSocket::processZeroCopyWriteInProgress() noexcept {
+  eventBase_->dcheckIsInEventBaseThread();
+  if (idZeroCopyBufPtrMap_.empty()) {
+    return true;
+  }
+
+  handleErrMessages();
+
+  return idZeroCopyBufPtrMap_.empty();
+}
+
 void AsyncSocket::handleRead() noexcept {
   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
           << ", state=" << state_;