Change kDefaultZeroCopyThreshold to 0 to avoid a regression and avoid a failure while...
[folly.git] / folly / io / async / AsyncSocket.cpp
index a6fbe6fe28f73ae4e60d71b1d01785c9b07c170c..a4c41969e49a9278d95094bc4101abd9cdda935d 100644 (file)
@@ -42,11 +42,11 @@ namespace fsp = folly::portability::sockets;
 namespace folly {
 
 static constexpr bool msgErrQueueSupported =
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
     true;
 #else
     false;
-#endif // MSG_ERRQUEUE
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
 
 // static members initializers
 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
@@ -114,16 +114,16 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     if (bytesWritten_) {
       if (socket_->isZeroCopyRequest(writeFlags)) {
         if (isComplete()) {
-          socket_->addZeroCopyBuff(std::move(ioBuf_));
+          socket_->addZeroCopyBuf(std::move(ioBuf_));
         } else {
-          socket_->addZeroCopyBuff(ioBuf_.get());
+          socket_->addZeroCopyBuf(ioBuf_.get());
         }
       } else {
         // this happens if at least one of the prev requests were sent
         // with zero copy but not the last one
         if (isComplete() && socket_->getZeroCopy() &&
-            socket_->containsZeroCopyBuff(ioBuf_.get())) {
-          socket_->setZeroCopyBuff(std::move(ioBuf_));
+            socket_->containsZeroCopyBuf(ioBuf_.get())) {
+          socket_->setZeroCopyBuf(std::move(ioBuf_));
         }
       }
     }
@@ -237,7 +237,7 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
 
 namespace {
 static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
-}
+} // namespace
 
 AsyncSocket::AsyncSocket()
     : eventBase_(nullptr),
@@ -891,50 +891,49 @@ void AsyncSocket::adjustZeroCopyFlags(
   }
 }
 
-void AsyncSocket::addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
   uint32_t id = getNextZeroCopyBuffId();
   folly::IOBuf* ptr = buf.get();
 
   idZeroCopyBufPtrMap_[id] = ptr;
-  auto& p = idZeroCopyBufPtrToBufMap_[ptr];
-  p.first++;
-  CHECK(p.second.get() == nullptr);
-  p.second = std::move(buf);
+  auto& p = idZeroCopyBufInfoMap_[ptr];
+  p.count_++;
+  CHECK(p.buf_.get() == nullptr);
+  p.buf_ = std::move(buf);
 }
 
-void AsyncSocket::addZeroCopyBuff(folly::IOBuf* ptr) {
+void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) {
   uint32_t id = getNextZeroCopyBuffId();
   idZeroCopyBufPtrMap_[id] = ptr;
 
-  idZeroCopyBufPtrToBufMap_[ptr].first++;
+  idZeroCopyBufInfoMap_[ptr].count_++;
 }
 
-void AsyncSocket::releaseZeroCopyBuff(uint32_t id) {
+void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
   auto iter = idZeroCopyBufPtrMap_.find(id);
   CHECK(iter != idZeroCopyBufPtrMap_.end());
   auto ptr = iter->second;
-  auto iter1 = idZeroCopyBufPtrToBufMap_.find(ptr);
-  CHECK(iter1 != idZeroCopyBufPtrToBufMap_.end());
-  if (0 == --iter1->second.first) {
-    idZeroCopyBufPtrToBufMap_.erase(iter1);
+  auto iter1 = idZeroCopyBufInfoMap_.find(ptr);
+  CHECK(iter1 != idZeroCopyBufInfoMap_.end());
+  if (0 == --iter1->second.count_) {
+    idZeroCopyBufInfoMap_.erase(iter1);
   }
 }
 
-void AsyncSocket::setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf) {
+void AsyncSocket::setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
   folly::IOBuf* ptr = buf.get();
-  auto& p = idZeroCopyBufPtrToBufMap_[ptr];
-  CHECK(p.second.get() == nullptr);
+  auto& p = idZeroCopyBufInfoMap_[ptr];
+  CHECK(p.buf_.get() == nullptr);
 
-  p.second = std::move(buf);
+  p.buf_ = std::move(buf);
 }
 
-bool AsyncSocket::containsZeroCopyBuff(folly::IOBuf* ptr) {
-  return (
-      idZeroCopyBufPtrToBufMap_.find(ptr) != idZeroCopyBufPtrToBufMap_.end());
+bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) {
+  return (idZeroCopyBufInfoMap_.find(ptr) != idZeroCopyBufInfoMap_.end());
 }
 
 bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
   if (zeroCopyEnabled_ &&
       ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
        (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) {
@@ -948,14 +947,21 @@ bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const {
 }
 
 void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) {
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
   const struct sock_extended_err* serr =
       reinterpret_cast<const struct sock_extended_err*>(CMSG_DATA(&cmsg));
   uint32_t hi = serr->ee_data;
   uint32_t lo = serr->ee_info;
+  // disable zero copy if the buffer was actually copied
+  if ((serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED) && zeroCopyEnabled_) {
+    VLOG(2) << "AsyncSocket::processZeroCopyMsg(): setting "
+            << "zeroCopyEnabled_ = false due to SO_EE_CODE_ZEROCOPY_COPIED "
+            << "on " << fd_;
+    zeroCopyEnabled_ = false;
+  }
 
   for (uint32_t i = lo; i <= hi; i++) {
-    releaseZeroCopyBuff(i);
+    releaseZeroCopyBuf(i);
   }
 #endif
 }
@@ -1052,7 +1058,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
       } else if (countWritten == count) {
         // done, add the whole buffer
         if (isZeroCopyRequest(flags)) {
-          addZeroCopyBuff(std::move(ioBuf));
+          addZeroCopyBuf(std::move(ioBuf));
         }
         // We successfully wrote everything.
         // Invoke the callback and return.
@@ -1063,7 +1069,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
       } else { // continue writing the next writeReq
         // add just the ptr
         if (isZeroCopyRequest(flags)) {
-          addZeroCopyBuff(ioBuf.get());
+          addZeroCopyBuf(ioBuf.get());
         }
         if (bufferCallback_) {
           bufferCallback_->onEgressBuffered();
@@ -1509,6 +1515,11 @@ void AsyncSocket::cachePeerAddress() const {
   }
 }
 
+bool AsyncSocket::isZeroCopyWriteInProgress() const noexcept {
+  eventBase_->dcheckIsInEventBaseThread();
+  return (!idZeroCopyBufPtrMap_.empty());
+}
+
 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
   cacheLocalAddress();
   *address = localAddr_;
@@ -1750,7 +1761,7 @@ void AsyncSocket::handleErrMessages() noexcept {
     return;
   }
 
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
   uint8_t ctrl[1024];
   unsigned char data;
   struct msghdr msg;
@@ -1797,7 +1808,7 @@ void AsyncSocket::handleErrMessages() noexcept {
       }
     }
   }
-#endif //MSG_ERRQUEUE
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
 }
 
 void AsyncSocket::handleRead() noexcept {