X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSocket.cpp;h=3065c1cd55f129c3508a9c21ba41349b8b44e586;hp=6d32ca61063d04dc456a2b296c1a410ddc3e2b5c;hb=2a4ad2c8ddc1eb1be8b7ffb607de954ccc2b666e;hpb=69d97159209c5a77fdf7805155738604233d0b8a diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 6d32ca61..3065c1cd 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2017 Facebook, Inc. + * Copyright 2017-present Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include #include @@ -42,11 +41,11 @@ namespace fsp = folly::portability::sockets; namespace folly { static constexpr bool msgErrQueueSupported = -#ifdef MSG_ERRQUEUE +#ifdef FOLLY_HAVE_MSG_ERRQUEUE true; #else false; -#endif // MSG_ERRQUEUE +#endif // FOLLY_HAVE_MSG_ERRQUEUE // static members initializers const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap; @@ -106,7 +105,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { writeFlags |= WriteFlags::CORK; } - socket_->adjustZeroCopyFlags(getOps(), getOpCount(), writeFlags); + socket_->adjustZeroCopyFlags(writeFlags); auto writeResult = socket_->performWrite( getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_); @@ -114,16 +113,16 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { if (bytesWritten_) { if (socket_->isZeroCopyRequest(writeFlags)) { if (isComplete()) { - socket_->addZeroCopyBuff(std::move(ioBuf_)); + socket_->addZeroCopyBuf(std::move(ioBuf_)); } else { - socket_->addZeroCopyBuff(ioBuf_.get()); + socket_->addZeroCopyBuf(ioBuf_.get()); } } else { // this happens if at least one of the prev requests were sent // with zero copy but not the last one if (isComplete() && socket_->getZeroCopy() && - socket_->containsZeroCopyBuff(ioBuf_.get())) { - socket_->setZeroCopyBuff(std::move(ioBuf_)); + socket_->containsZeroCopyBuf(ioBuf_.get())) { + socket_->setZeroCopyBuf(std::move(ioBuf_)); } } } @@ -237,7 +236,7 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags( namespace { static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback; -} +} // namespace AsyncSocket::AsyncSocket() : eventBase_(nullptr), @@ -272,13 +271,14 @@ AsyncSocket::AsyncSocket(EventBase* evb, connect(nullptr, ip, port, connectTimeout); } -AsyncSocket::AsyncSocket(EventBase* evb, int fd) - : eventBase_(evb), +AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId) + : zeroCopyBufId_(zeroCopyBufId), + eventBase_(evb), writeTimeout_(this, evb), ioHandler_(this, evb, fd), immediateReadHandler_(this) { - VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" - << fd << ")"; + VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd=" << fd + << ", zeroCopyBufId=" << zeroCopyBufId << ")"; init(); fd_ = fd; setCloseOnExec(); @@ -286,7 +286,10 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd) } AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket) - : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) { + : AsyncSocket( + oldAsyncSocket->getEventBase(), + oldAsyncSocket->detachFd(), + oldAsyncSocket->getZeroCopyBufId()) { preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_); } @@ -307,7 +310,7 @@ void AsyncSocket::init() { readCallback_ = nullptr; writeReqHead_ = nullptr; writeReqTail_ = nullptr; - shutdownSocketSet_ = nullptr; + wShutdownSocketSet_.reset(); appBytesWritten_ = 0; appBytesReceived_ = 0; sendMsgParamCallback_ = &defaultSendMsgParamsCallback; @@ -336,8 +339,8 @@ int AsyncSocket::detachFd() { << ", events=" << std::hex << eventFlags_ << ")"; // Extract the fd, and set fd_ to -1 first, so closeNow() won't // actually close the descriptor. - if (shutdownSocketSet_) { - shutdownSocketSet_->remove(fd_); + if (const auto socketSet = wShutdownSocketSet_.lock()) { + socketSet->remove(fd_); } int fd = fd_; fd_ = -1; @@ -355,17 +358,24 @@ const folly::SocketAddress& AsyncSocket::anyAddress() { return anyAddress; } -void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) { - if (shutdownSocketSet_ == newSS) { +void AsyncSocket::setShutdownSocketSet( + const std::weak_ptr& wNewSS) { + const auto newSS = wNewSS.lock(); + const auto shutdownSocketSet = wShutdownSocketSet_.lock(); + + if (newSS == shutdownSocketSet) { return; } - if (shutdownSocketSet_ && fd_ != -1) { - shutdownSocketSet_->remove(fd_); + + if (shutdownSocketSet && fd_ != -1) { + shutdownSocketSet->remove(fd_); } - shutdownSocketSet_ = newSS; - if (shutdownSocketSet_ && fd_ != -1) { - shutdownSocketSet_->add(fd_); + + if (newSS && fd_ != -1) { + newSS->add(fd_); } + + wShutdownSocketSet_ = wNewSS; } void AsyncSocket::setCloseOnExec() { @@ -420,8 +430,8 @@ void AsyncSocket::connect(ConnectCallback* callback, withAddr("failed to create socket"), errnoCopy); } - if (shutdownSocketSet_) { - shutdownSocketSet_->add(fd_); + if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) { + shutdownSocketSet->add(fd_); } ioHandler_.changeHandlerFD(fd_); @@ -651,6 +661,22 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) { << ", fd=" << fd_ << ", callback=" << callback << ", state=" << state_; + // In the latest stable kernel 4.14.3 as of 2017-12-04, unix domain + // socket does not support MSG_ERRQUEUE. So recvmsg(MSG_ERRQUEUE) + // will read application data from unix doamin socket as error + // message, which breaks the message flow in application. Feel free + // to remove the next code block if MSG_ERRQUEUE is added for unix + // domain socket in the future. + if (callback != nullptr) { + cacheLocalAddress(); + if (localAddr_.getFamily() == AF_UNIX) { + LOG(ERROR) << "Failed to set ErrMessageCallback=" << callback + << " for Unix Doamin Socket where MSG_ERRQUEUE is unsupported," + << " fd=" << fd_; + return; + } + } + // Short circuit if callback is the same as the existing errMessageCallback_. if (callback == errMessageCallback_) { return; @@ -843,90 +869,61 @@ bool AsyncSocket::setZeroCopy(bool enable) { return false; } -void AsyncSocket::setZeroCopyWriteChainThreshold(size_t threshold) { - zeroCopyWriteChainThreshold_ = threshold; -} - bool AsyncSocket::isZeroCopyRequest(WriteFlags flags) { return (zeroCopyEnabled_ && isSet(flags, WriteFlags::WRITE_MSG_ZEROCOPY)); } -void AsyncSocket::adjustZeroCopyFlags( - folly::IOBuf* buf, - folly::WriteFlags& flags) { - if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_ && buf) { - if (buf->computeChainDataLength() >= zeroCopyWriteChainThreshold_) { - flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY; - } else { - flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY); - } - } -} - -void AsyncSocket::adjustZeroCopyFlags( - const iovec* vec, - uint32_t count, - folly::WriteFlags& flags) { - if (zeroCopyEnabled_ && zeroCopyWriteChainThreshold_) { - count = std::min(count, kIovMax); - size_t sum = 0; - for (uint32_t i = 0; i < count; ++i) { - const iovec* v = vec + i; - sum += v->iov_len; - } - - if (sum >= zeroCopyWriteChainThreshold_) { - flags |= folly::WriteFlags::WRITE_MSG_ZEROCOPY; - } else { - flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY); - } +void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) { + if (!zeroCopyEnabled_) { + flags = unSet(flags, folly::WriteFlags::WRITE_MSG_ZEROCOPY); } } -void AsyncSocket::addZeroCopyBuff(std::unique_ptr&& buf) { - uint32_t id = getNextZeroCopyBuffId(); +void AsyncSocket::addZeroCopyBuf(std::unique_ptr&& buf) { + uint32_t id = getNextZeroCopyBufId(); folly::IOBuf* ptr = buf.get(); idZeroCopyBufPtrMap_[id] = ptr; - auto& p = idZeroCopyBufPtrToBufMap_[ptr]; - p.first++; - CHECK(p.second.get() == nullptr); - p.second = std::move(buf); + auto& p = idZeroCopyBufInfoMap_[ptr]; + p.count_++; + CHECK(p.buf_.get() == nullptr); + p.buf_ = std::move(buf); } -void AsyncSocket::addZeroCopyBuff(folly::IOBuf* ptr) { - uint32_t id = getNextZeroCopyBuffId(); +void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) { + uint32_t id = getNextZeroCopyBufId(); idZeroCopyBufPtrMap_[id] = ptr; - idZeroCopyBufPtrToBufMap_[ptr].first++; + idZeroCopyBufInfoMap_[ptr].count_++; } -void AsyncSocket::releaseZeroCopyBuff(uint32_t id) { +void AsyncSocket::releaseZeroCopyBuf(uint32_t id) { auto iter = idZeroCopyBufPtrMap_.find(id); CHECK(iter != idZeroCopyBufPtrMap_.end()); auto ptr = iter->second; - auto iter1 = idZeroCopyBufPtrToBufMap_.find(ptr); - CHECK(iter1 != idZeroCopyBufPtrToBufMap_.end()); - if (0 == --iter1->second.first) { - idZeroCopyBufPtrToBufMap_.erase(iter1); + auto iter1 = idZeroCopyBufInfoMap_.find(ptr); + CHECK(iter1 != idZeroCopyBufInfoMap_.end()); + if (0 == --iter1->second.count_) { + idZeroCopyBufInfoMap_.erase(iter1); } + + idZeroCopyBufPtrMap_.erase(iter); } -void AsyncSocket::setZeroCopyBuff(std::unique_ptr&& buf) { +void AsyncSocket::setZeroCopyBuf(std::unique_ptr&& buf) { folly::IOBuf* ptr = buf.get(); - auto& p = idZeroCopyBufPtrToBufMap_[ptr]; - CHECK(p.second.get() == nullptr); + auto& p = idZeroCopyBufInfoMap_[ptr]; + CHECK(p.buf_.get() == nullptr); - p.second = std::move(buf); + p.buf_ = std::move(buf); } -bool AsyncSocket::containsZeroCopyBuff(folly::IOBuf* ptr) { - return ( - idZeroCopyBufPtrToBufMap_.find(ptr) != idZeroCopyBufPtrToBufMap_.end()); +bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) { + return (idZeroCopyBufInfoMap_.find(ptr) != idZeroCopyBufInfoMap_.end()); } bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const { -#ifdef MSG_ERRQUEUE +#ifdef FOLLY_HAVE_MSG_ERRQUEUE if (zeroCopyEnabled_ && ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) || (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR))) { @@ -940,14 +937,21 @@ bool AsyncSocket::isZeroCopyMsg(const cmsghdr& cmsg) const { } void AsyncSocket::processZeroCopyMsg(const cmsghdr& cmsg) { -#ifdef MSG_ERRQUEUE +#ifdef FOLLY_HAVE_MSG_ERRQUEUE const struct sock_extended_err* serr = reinterpret_cast(CMSG_DATA(&cmsg)); uint32_t hi = serr->ee_data; uint32_t lo = serr->ee_info; + // disable zero copy if the buffer was actually copied + if ((serr->ee_code & SO_EE_CODE_ZEROCOPY_COPIED) && zeroCopyEnabled_) { + VLOG(2) << "AsyncSocket::processZeroCopyMsg(): setting " + << "zeroCopyEnabled_ = false due to SO_EE_CODE_ZEROCOPY_COPIED " + << "on " << fd_; + zeroCopyEnabled_ = false; + } for (uint32_t i = lo; i <= hi; i++) { - releaseZeroCopyBuff(i); + releaseZeroCopyBuf(i); } #endif } @@ -969,7 +973,7 @@ void AsyncSocket::writev(WriteCallback* callback, void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr&& buf, WriteFlags flags) { - adjustZeroCopyFlags(buf.get(), flags); + adjustZeroCopyFlags(flags); constexpr size_t kSmallSizeMax = 64; size_t count = buf->countChainElements(); @@ -1043,8 +1047,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, return failWrite(__func__, callback, 0, ex); } else if (countWritten == count) { // done, add the whole buffer - if (isZeroCopyRequest(flags)) { - addZeroCopyBuff(std::move(ioBuf)); + if (countWritten && isZeroCopyRequest(flags)) { + addZeroCopyBuf(std::move(ioBuf)); } // We successfully wrote everything. // Invoke the callback and return. @@ -1054,8 +1058,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, return; } else { // continue writing the next writeReq // add just the ptr - if (isZeroCopyRequest(flags)) { - addZeroCopyBuff(ioBuf.get()); + if (bytesWritten && isZeroCopyRequest(flags)) { + addZeroCopyBuf(ioBuf.get()); } if (bufferCallback_) { bufferCallback_->onEgressBuffered(); @@ -1447,6 +1451,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) { eventBase_ = eventBase; ioHandler_.attachEventBase(eventBase); + + updateEventRegistration(); + writeTimeout_.attachEventBase(eventBase); if (evbChangeCb_) { evbChangeCb_->evbAttached(this); @@ -1461,6 +1468,9 @@ void AsyncSocket::detachEventBase() { eventBase_->dcheckIsInEventBaseThread(); eventBase_ = nullptr; + + ioHandler_.unregisterHandler(); + ioHandler_.detachEventBase(); writeTimeout_.detachEventBase(); if (evbChangeCb_) { @@ -1472,7 +1482,7 @@ bool AsyncSocket::isDetachable() const { DCHECK(eventBase_ != nullptr); eventBase_->dcheckIsInEventBaseThread(); - return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled(); + return !writeTimeout_.isScheduled(); } void AsyncSocket::cacheAddresses() { @@ -1501,6 +1511,11 @@ void AsyncSocket::cachePeerAddress() const { } } +bool AsyncSocket::isZeroCopyWriteInProgress() const noexcept { + eventBase_->dcheckIsInEventBaseThread(); + return (!idZeroCopyBufPtrMap_.empty()); +} + void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const { cacheLocalAddress(); *address = localAddr_; @@ -1656,7 +1671,11 @@ void AsyncSocket::ioReady(uint16_t events) noexcept { // EventHandler::WRITE is set. Any of these flags can // indicate that there are messages available in the socket // error message queue. - handleErrMessages(); + // Return if we handle any error messages - this is to avoid + // unnecessary read/write calls + if (handleErrMessages()) { + return; + } // Return now if handleErrMessages() detached us from our EventBase if (eventBase_ != originalEventBase) { @@ -1730,19 +1749,18 @@ void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) { readCallback_->getReadBuffer(buf, buflen); } -void AsyncSocket::handleErrMessages() noexcept { +size_t AsyncSocket::handleErrMessages() noexcept { // This method has non-empty implementation only for platforms // supporting per-socket error queues. VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_ << ", state=" << state_; - if (errMessageCallback_ == nullptr && - (!zeroCopyEnabled_ || idZeroCopyBufPtrMap_.empty())) { + if (errMessageCallback_ == nullptr && idZeroCopyBufPtrMap_.empty()) { VLOG(7) << "AsyncSocket::handleErrMessages(): " << "no callback installed - exiting."; - return; + return 0; } -#ifdef MSG_ERRQUEUE +#ifdef FOLLY_HAVE_MSG_ERRQUEUE uint8_t ctrl[1024]; unsigned char data; struct msghdr msg; @@ -1759,6 +1777,7 @@ void AsyncSocket::handleErrMessages() noexcept { msg.msg_flags = 0; int ret; + size_t num = 0; while (true) { ret = recvmsg(fd_, &msg, MSG_ERRQUEUE); VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret; @@ -1774,12 +1793,14 @@ void AsyncSocket::handleErrMessages() noexcept { errnoCopy); failErrMessageRead(__func__, ex); } - return; + + return num; } for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr && cmsg->cmsg_len != 0; cmsg = CMSG_NXTHDR(&msg, cmsg)) { + ++num; if (isZeroCopyMsg(*cmsg)) { processZeroCopyMsg(*cmsg); } else { @@ -1789,7 +1810,20 @@ void AsyncSocket::handleErrMessages() noexcept { } } } -#endif //MSG_ERRQUEUE +#else + return 0; +#endif // FOLLY_HAVE_MSG_ERRQUEUE +} + +bool AsyncSocket::processZeroCopyWriteInProgress() noexcept { + eventBase_->dcheckIsInEventBaseThread(); + if (idZeroCopyBufPtrMap_.empty()) { + return true; + } + + handleErrMessages(); + + return idZeroCopyBufPtrMap_.empty(); } void AsyncSocket::handleRead() noexcept { @@ -2335,6 +2369,14 @@ AsyncSocket::WriteResult AsyncSocket::performWrite( // this bug is fixed. tryAgain |= (errno == ENOTCONN); #endif + + // workaround for running with zerocopy enabled but without a proper + // memlock value - see ulimit -l + if (zeroCopyEnabled_ && (errno == ENOBUFS)) { + tryAgain = true; + zeroCopyEnabled_ = false; + } + if (!writeResult.exception && tryAgain) { // TCP buffer is full; we can't write any more data right now. *countWritten = 0; @@ -2684,9 +2726,11 @@ void AsyncSocket::invalidState(WriteCallback* callback) { } void AsyncSocket::doClose() { - if (fd_ == -1) return; - if (shutdownSocketSet_) { - shutdownSocketSet_->close(fd_); + if (fd_ == -1) { + return; + } + if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) { + shutdownSocketSet->close(fd_); } else { ::close(fd_); }