From 8ecc23d91b78319b9d65747c4a689b6473c52ce1 Mon Sep 17 00:00:00 2001 From: Maxim Georgiev Date: Thu, 9 Mar 2017 22:19:15 -0800 Subject: [PATCH] Implementing a callback interface for folly::AsyncSocket allowing to supply an ancillary data buffer with msghdr structure to sendmsg() system call Summary: Implementing a callback interface for folly::AsyncSocket allowing to supply an ancillary data buffer with msghdr structure to sendmsg() system call. Reviewed By: afrind Differential Revision: D4422168 fbshipit-source-id: 29a23b05f704aff796d368f4ac9514c49b7ce578 --- folly/io/async/AsyncSSLSocket.cpp | 26 +-- folly/io/async/AsyncSocket.cpp | 64 +++++-- folly/io/async/AsyncSocket.h | 89 ++++++++- folly/io/async/test/AsyncSSLSocketTest.cpp | 115 ++++++++++++ folly/io/async/test/AsyncSSLSocketTest.h | 199 ++++++++++++++++++++- folly/io/async/test/AsyncSocketTest.h | 58 ++++++ folly/io/async/test/AsyncSocketTest2.cpp | 174 ++++++++++++++++-- 7 files changed, 682 insertions(+), 43 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 3b99a4c4..26090b7b 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -34,7 +34,6 @@ #include #include #include -#include using folly::SocketAddress; using folly::SSLContext; @@ -59,6 +58,7 @@ using folly::SSLContext; using namespace folly::ssl; using folly::ssl::OpenSSLUtils; + // We have one single dummy SSL context so that we can implement attach // and detach methods in a thread safe fashion without modifying opnessl. static SSLContext *dummyCtx = nullptr; @@ -1624,7 +1624,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { struct msghdr msg; struct iovec iov; - int flags = 0; AsyncSSLSocket* tsslSock; iov.iov_base = const_cast(in); @@ -1639,23 +1638,28 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { tsslSock = reinterpret_cast(appData); CHECK(tsslSock); + WriteFlags flags = WriteFlags::NONE; if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ && tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { - flags = MSG_EOR; + flags |= WriteFlags::EOR; } -#ifdef MSG_NOSIGNAL - flags |= MSG_NOSIGNAL; -#endif - -#ifdef MSG_MORE if (tsslSock->corkCurrentWrite_) { - flags |= MSG_MORE; + flags |= WriteFlags::CORK; + } + + int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags); + msg.msg_controllen = + tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags); + CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize, + msg.msg_controllen); + if (msg.msg_controllen != 0) { + msg.msg_control = reinterpret_cast(alloca(msg.msg_controllen)); + tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control); } -#endif auto result = tsslSock->sendSocketMessage( - OpenSSLUtils::getBioFd(b, nullptr), &msg, flags); + OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags); BIO_clear_retry_flags(b); if (!result.exception && result.writeReturn <= 0) { if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) { diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 827ee8e9..ddbab294 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -185,6 +185,33 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { struct iovec writeOps_[]; ///< write operation(s) list }; +int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags) + noexcept { + int msg_flags = MSG_DONTWAIT; + +#ifdef MSG_NOSIGNAL // Linux-only + msg_flags |= MSG_NOSIGNAL; +#ifdef MSG_MORE + if (isSet(flags, WriteFlags::CORK)) { + // MSG_MORE tells the kernel we have more data to send, so wait for us to + // give it the rest of the data rather than immediately sending a partial + // frame, even when TCP_NODELAY is enabled. + msg_flags |= MSG_MORE; + } +#endif // MSG_MORE +#endif // MSG_NOSIGNAL + if (isSet(flags, WriteFlags::EOR)) { + // marks that this is the last byte of a record (response) + msg_flags |= MSG_EOR; + } + + return msg_flags; +} + +namespace { +static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback; +} + AsyncSocket::AsyncSocket() : eventBase_(nullptr), writeTimeout_(this, nullptr), @@ -254,6 +281,7 @@ void AsyncSocket::init() { shutdownSocketSet_ = nullptr; appBytesWritten_ = 0; appBytesReceived_ = 0; + sendMsgParamCallback_ = &defaultSendMsgParamsCallback; } AsyncSocket::~AsyncSocket() { @@ -625,6 +653,14 @@ AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const { return errMessageCallback_; } +void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) { + sendMsgParamCallback_ = callback; +} + +AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const { + return sendMsgParamCallback_; +} + void AsyncSocket::setReadCB(ReadCallback *callback) { VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_ << ", callback=" << callback << ", state=" << state_; @@ -1363,7 +1399,7 @@ int AsyncSocket::setTCPProfile(int profd) { } void AsyncSocket::ioReady(uint16_t events) noexcept { - VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_ + VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_ << ", events=" << std::hex << events << ", state=" << state_; DestructorGuard dg(this); assert(events & EventHandler::READ_WRITE); @@ -2023,25 +2059,19 @@ AsyncSocket::WriteResult AsyncSocket::performWrite( msg.msg_namelen = 0; msg.msg_iov = const_cast(vec); msg.msg_iovlen = std::min(count, kIovMax); - msg.msg_control = nullptr; - msg.msg_controllen = 0; msg.msg_flags = 0; + msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags); + CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize, + msg.msg_controllen); - int msg_flags = MSG_DONTWAIT; - -#ifdef MSG_NOSIGNAL // Linux-only - msg_flags |= MSG_NOSIGNAL; - if (isSet(flags, WriteFlags::CORK)) { - // MSG_MORE tells the kernel we have more data to send, so wait for us to - // give it the rest of the data rather than immediately sending a partial - // frame, even when TCP_NODELAY is enabled. - msg_flags |= MSG_MORE; - } -#endif - if (isSet(flags, WriteFlags::EOR)) { - // marks that this is the last byte of a record (response) - msg_flags |= MSG_EOR; + if (msg.msg_controllen != 0) { + msg.msg_control = reinterpret_cast(alloca(msg.msg_controllen)); + sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control); + } else { + msg.msg_control = nullptr; } + int msg_flags = sendMsgParamCallback_->getFlags(flags); + auto writeResult = sendSocketMessage(fd_, &msg, msg_flags); auto totalWritten = writeResult.writeReturn; if (totalWritten < 0) { diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index 3e2adbce..fca560bb 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -139,6 +139,77 @@ class AsyncSocket : virtual public AsyncTransportWrapper { virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0; }; + class SendMsgParamsCallback { + public: + virtual ~SendMsgParamsCallback() = default; + + /** + * getFlags() will be invoked to retrieve the desired flags to be passed + * to ::sendmsg() system call. This method was intentionally declared + * non-virtual, so there is no way to override it. Instead feel free to + * override getFlagsImpl(flags, defaultFlags) method instead, and enjoy + * the convenience of defaultFlags passed there. + * + * @param flags Write flags requested for the given write operation + */ + int getFlags(folly::WriteFlags flags) noexcept { + return getFlagsImpl(flags, getDefaultFlags(flags)); + } + + /** + * getAncillaryData() will be invoked to initialize ancillary data + * buffer referred by "msg_control" field of msghdr structure passed to + * ::sendmsg() system call. The function assumes that the size of buffer + * is not smaller than the value returned by getAncillaryDataSize() method + * for the same combination of flags. + * + * @param flags Write flags requested for the given write operation + * @param data Pointer to ancillary data buffer to initialize. + */ + virtual void getAncillaryData( + folly::WriteFlags /*flags*/, + void* /*data*/) noexcept {} + + /** + * getAncillaryDataSize() will be invoked to retrieve the size of + * ancillary data buffer which should be passed to ::sendmsg() system call + * + * @param flags Write flags requested for the given write operation + */ + virtual uint32_t getAncillaryDataSize(folly::WriteFlags /*flags*/) + noexcept { + return 0; + } + + static const size_t maxAncillaryDataSize{0x5000}; + + private: + /** + * getFlagsImpl() will be invoked by getFlags(folly::WriteFlags flags) + * method to retrieve the flags to be passed to ::sendmsg() system call. + * SendMsgParamsCallback::getFlags() is calling this method, and returns + * its results directly to the caller in AsyncSocket. + * Classes inheriting from SendMsgParamsCallback are welcome to override + * this method to force SendMsgParamsCallback to return its own set + * of flags. + * + * @param flags Write flags requested for the given write operation + * @param defaultflags A set of message flags returned by getDefaultFlags() + * method for the given "flags" mask. + */ + virtual int getFlagsImpl(folly::WriteFlags /*flags*/, int defaultFlags) { + return defaultFlags; + } + + /** + * getDefaultFlags() will be invoked by getFlags(folly::WriteFlags flags) + * to retrieve the default set of flags, and pass them to getFlagsImpl(...) + * + * @param flags Write flags requested for the given write operation + */ + int getDefaultFlags(folly::WriteFlags flags) noexcept; + }; + explicit AsyncSocket(); /** * Create a new unconnected AsyncSocket. @@ -411,6 +482,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper { */ ErrMessageCallback* getErrMessageCallback() const; + /** + * Set a pointer to SendMsgParamsCallback implementation which + * will be used to form ::sendmsg() system call parameters + * + */ + void setSendMsgParamCB(SendMsgParamsCallback* callback); + + /** + * Get a pointer to SendMsgParamsCallback implementation currently + * registered with this socket. + * + */ + SendMsgParamsCallback* getSendMsgParamsCB() const; + // Read and write methods void setReadCB(ReadCallback* callback) override; ReadCallback* getReadCallback() const override; @@ -1010,6 +1095,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ConnectCallback* connectCallback_; ///< ConnectCallback ErrMessageCallback* errMessageCallback_; ///< TimestampCallback + SendMsgParamsCallback* ///< Callback for retreaving + sendMsgParamCallback_; ///< ::sendmsg() parameters ReadCallback* readCallback_; ///< ReadCallback WriteRequest* writeReqHead_; ///< Chain of WriteRequests WriteRequest* writeReqTail_; ///< End of WriteRequest chain @@ -1022,7 +1109,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { // socket. std::unique_ptr preReceivedData_; - int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any. + int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any std::chrono::steady_clock::time_point connectStartTime_; std::chrono::steady_clock::time_point connectEndTime_; diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index 41284b6e..fbbf999b 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -1958,6 +1959,120 @@ TEST(AsyncSSLSocketTest, TestPreReceivedData) { serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten()); } +/** + * Test overriding the flags passed to "sendmsg()" system call, + * and verifying that write requests fail properly. + */ +TEST(AsyncSSLSocketTest, SendMsgParamsCallback) { + // Start listening on a local port + SendMsgFlagsCallback msgCallback; + ExpectWriteErrorCallback writeCallback(&msgCallback); + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + auto sslContext = std::make_shared(); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + socket->open(); + + // Setting flags to "-1" to trigger "Invalid argument" error + // on attempt to use this flags in sendmsg() system call. + msgCallback.resetFlags(-1); + + // write() + std::vector buf(128, 'a'); + ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size()); + + // close() + socket->close(); + + cerr << "SendMsgParamsCallback test completed" << endl; +} + +#ifdef MSG_ERRQUEUE +/** + * Test connecting to, writing to, reading from, and closing the + * connection to the SSL server. + */ +TEST(AsyncSSLSocketTest, SendMsgDataCallback) { + // This test requires Linux kernel v4.6 or later + struct utsname s_uname; + memset(&s_uname, 0, sizeof(s_uname)); + ASSERT_EQ(uname(&s_uname), 0); + int major, minor; + folly::StringPiece extra; + if (folly::split( + '.', std::string(s_uname.release) + ".", major, minor, extra)) { + if (major < 4 || (major == 4 && minor < 6)) { + LOG(INFO) << "Kernel version: 4.6 and newer required for this test (" + << "kernel ver. " << s_uname.release << " detected)."; + return; + } + } + + // Start listening on a local port + SendMsgDataCallback msgCallback; + WriteCheckTimestampCallback writeCallback(&msgCallback); + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + // Set up SSL context. + auto sslContext = std::make_shared(); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + // connect + auto socket = std::make_shared(server.getAddress(), + sslContext); + socket->open(); + + // Adding MSG_EOR flag to the message flags - it'll trigger + // timestamp generation for the last byte of the message. + msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR); + + // Init ancillary data buffer to trigger timestamp notification + union { + uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))]; + struct cmsghdr cmsg; + } u; + u.cmsg.cmsg_level = SOL_SOCKET; + u.cmsg.cmsg_type = SO_TIMESTAMPING; + u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t)); + uint32_t flags = + SOF_TIMESTAMPING_TX_SCHED | + SOF_TIMESTAMPING_TX_SOFTWARE | + SOF_TIMESTAMPING_TX_ACK; + memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t)); + std::vector ctrl(CMSG_LEN(sizeof(uint32_t))); + memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t))); + msgCallback.resetData(std::move(ctrl)); + + // write() + std::vector buf(128, 'a'); + socket->write(buf.data(), buf.size()); + + // read() + std::vector readbuf(buf.size()); + uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, buf.size()); + EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin())); + + writeCallback.checkForTimestampNotifications(); + + // close() + socket->close(); + + cerr << "SendMsgDataCallback test completed" << endl; +} +#endif // MSG_ERRQUEUE + #endif } // namespace diff --git a/folly/io/async/test/AsyncSSLSocketTest.h b/folly/io/async/test/AsyncSSLSocketTest.h index 7965f8e4..ff70d3b8 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.h +++ b/folly/io/async/test/AsyncSSLSocketTest.h @@ -46,24 +46,106 @@ namespace folly { // are responsible for setting the succeeded state properly before the // destructors are called. +class SendMsgParamsCallbackBase : + public folly::AsyncSocket::SendMsgParamsCallback { + public: + SendMsgParamsCallbackBase() {} + + void setSocket( + const std::shared_ptr &socket) { + socket_ = socket; + oldCallback_ = socket_->getSendMsgParamsCB(); + socket_->setSendMsgParamCB(this); + } + + int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept + override { + return oldCallback_->getFlags(flags); + } + + void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override { + oldCallback_->getAncillaryData(flags, data); + } + + uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override { + return oldCallback_->getAncillaryDataSize(flags); + } + + std::shared_ptr socket_; + folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr}; +}; + +class SendMsgFlagsCallback : public SendMsgParamsCallbackBase { + public: + SendMsgFlagsCallback() {} + + void resetFlags(int flags) { + flags_ = flags; + } + + int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept + override { + if (flags_) { + return flags_; + } else { + return oldCallback_->getFlags(flags); + } + } + + int flags_{0}; +}; + +class SendMsgDataCallback : public SendMsgFlagsCallback { + public: + SendMsgDataCallback() {} + + void resetData(std::vector&& data) { + ancillaryData_.swap(data); + } + + void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override { + if (ancillaryData_.size()) { + std::cerr << "getAncillaryData: copying data" << std::endl; + memcpy(data, ancillaryData_.data(), ancillaryData_.size()); + } else { + oldCallback_->getAncillaryData(flags, data); + } + } + + uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override { + if (ancillaryData_.size()) { + std::cerr << "getAncillaryDataSize: returning size" << std::endl; + return ancillaryData_.size(); + } else { + return oldCallback_->getAncillaryDataSize(flags); + } + } + + std::vector ancillaryData_; +}; + class WriteCallbackBase : public AsyncTransportWrapper::WriteCallback { public: - WriteCallbackBase() + explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr) : state(STATE_WAITING) , bytesWritten(0) - , exception(AsyncSocketException::UNKNOWN, "none") {} + , exception(AsyncSocketException::UNKNOWN, "none") + , mcb_(mcb) {} ~WriteCallbackBase() { EXPECT_EQ(STATE_SUCCEEDED, state); } - void setSocket( + virtual void setSocket( const std::shared_ptr &socket) { socket_ = socket; + if (mcb_) { + mcb_->setSocket(socket); + } } - void writeSuccess() noexcept override { + virtual void writeSuccess() noexcept override { std::cerr << "writeSuccess" << std::endl; state = STATE_SUCCEEDED; } @@ -84,7 +166,116 @@ public: StateEnum state; size_t bytesWritten; AsyncSocketException exception; + SendMsgParamsCallbackBase* mcb_; +}; + +class ExpectWriteErrorCallback : +public WriteCallbackBase { +public: + explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr) + : WriteCallbackBase(mcb) {} + + ~ExpectWriteErrorCallback() { + EXPECT_EQ(STATE_FAILED, state); + EXPECT_EQ(exception.type_, + AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR); + EXPECT_EQ(exception.errno_, 22); + // Suppress the assert in ~WriteCallbackBase() + state = STATE_SUCCEEDED; + } +}; + +#ifdef MSG_ERRQUEUE +/* copied from include/uapi/linux/net_tstamp.h */ +/* SO_TIMESTAMPING gets an integer bit field comprised of these values */ +enum SOF_TIMESTAMPING { + SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1), + SOF_TIMESTAMPING_SOFTWARE = (1 << 4), + SOF_TIMESTAMPING_OPT_ID = (1 << 7), + SOF_TIMESTAMPING_TX_SCHED = (1 << 8), + SOF_TIMESTAMPING_TX_ACK = (1 << 9), + SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11), +}; + +class WriteCheckTimestampCallback : + public WriteCallbackBase { +public: + explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr) + : WriteCallbackBase(mcb) {} + + ~WriteCheckTimestampCallback() { + EXPECT_EQ(STATE_SUCCEEDED, state); + EXPECT_TRUE(gotTimestamp_); + EXPECT_TRUE(gotByteSeq_); + } + + void setSocket( + const std::shared_ptr &socket) override { + WriteCallbackBase::setSocket(socket); + + EXPECT_NE(socket_->getFd(), 0); + int flags = SOF_TIMESTAMPING_OPT_ID + | SOF_TIMESTAMPING_OPT_TSONLY + | SOF_TIMESTAMPING_SOFTWARE; + AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING}; + int ret = tstampingOpt.apply(socket_->getFd(), flags); + EXPECT_EQ(ret, 0); + } + + void checkForTimestampNotifications() noexcept { + int fd = socket_->getFd(); + std::vector ctrl(1024, 0); + unsigned char data; + struct msghdr msg; + iovec entry; + + memset(&msg, 0, sizeof(msg)); + entry.iov_base = &data; + entry.iov_len = sizeof(data); + msg.msg_iov = &entry; + msg.msg_iovlen = 1; + msg.msg_control = ctrl.data(); + msg.msg_controllen = ctrl.size(); + + int ret; + while (true) { + ret = recvmsg(fd, &msg, MSG_ERRQUEUE); + if (ret < 0) { + if (errno != EAGAIN) { + auto errnoCopy = errno; + std::cerr << "::recvmsg exited with code " << ret + << ", errno: " << errnoCopy << std::endl; + AsyncSocketException ex( + AsyncSocketException::INTERNAL_ERROR, + "recvmsg() failed", + errnoCopy); + exception = ex; + } + return; + } + + for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg != nullptr && cmsg->cmsg_len != 0; + cmsg = CMSG_NXTHDR(&msg, cmsg)) { + if (cmsg->cmsg_level == SOL_SOCKET && + cmsg->cmsg_type == SCM_TIMESTAMPING) { + gotTimestamp_ = true; + continue; + } + + if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) || + (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) { + gotByteSeq_ = true; + continue; + } + } + } + } + + bool gotTimestamp_{false}; + bool gotByteSeq_{false}; }; +#endif // MSG_ERRQUEUE class ReadCallbackBase : public AsyncTransportWrapper::ReadCallback { diff --git a/folly/io/async/test/AsyncSocketTest.h b/folly/io/async/test/AsyncSocketTest.h index d69c851b..254c14b2 100644 --- a/folly/io/async/test/AsyncSocketTest.h +++ b/folly/io/async/test/AsyncSocketTest.h @@ -229,6 +229,64 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback { bool gotByteSeq_{false}; }; +class TestSendMsgParamsCallback : + public folly::AsyncSocket::SendMsgParamsCallback { + public: + TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data) + : flags_(flags), + writeFlags_(folly::WriteFlags::NONE), + dataSize_(dataSize), + data_(data), + queriedFlags_(false), + queriedData_(false) + {} + + void reset(int flags) { + flags_ = flags; + writeFlags_ = folly::WriteFlags::NONE; + queriedFlags_ = false; + queriedData_ = false; + } + + int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept + override { + queriedFlags_ = true; + if (writeFlags_ == folly::WriteFlags::NONE) { + writeFlags_ = flags; + } else { + assert(flags == writeFlags_); + } + return flags_; + } + + void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override { + queriedData_ = true; + if (writeFlags_ == folly::WriteFlags::NONE) { + writeFlags_ = flags; + } else { + assert(flags == writeFlags_); + } + assert(data != nullptr); + memcpy(data, data_, dataSize_); + } + + uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override { + if (writeFlags_ == folly::WriteFlags::NONE) { + writeFlags_ = flags; + } else { + assert(flags == writeFlags_); + } + return dataSize_; + } + + int flags_; + folly::WriteFlags writeFlags_; + uint32_t dataSize_; + void* data_; + bool queriedFlags_; + bool queriedData_; +}; + class TestServer { public: // Create a TestServer. diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 864a4031..a1988722 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -2823,21 +2823,11 @@ TEST(AsyncSocketTest, EvbCallbacks) { /* copied from include/uapi/linux/net_tstamp.h */ /* SO_TIMESTAMPING gets an integer bit field comprised of these values */ enum SOF_TIMESTAMPING { - // SOF_TIMESTAMPING_TX_HARDWARE = (1 << 0), - // SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1), - // SOF_TIMESTAMPING_RX_HARDWARE = (1 << 2), - // SOF_TIMESTAMPING_RX_SOFTWARE = (1 << 3), SOF_TIMESTAMPING_SOFTWARE = (1 << 4), - // SOF_TIMESTAMPING_SYS_HARDWARE = (1 << 5), - // SOF_TIMESTAMPING_RAW_HARDWARE = (1 << 6), SOF_TIMESTAMPING_OPT_ID = (1 << 7), SOF_TIMESTAMPING_TX_SCHED = (1 << 8), - // SOF_TIMESTAMPING_TX_ACK = (1 << 9), SOF_TIMESTAMPING_OPT_CMSG = (1 << 10), SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11), - - // SOF_TIMESTAMPING_LAST = SOF_TIMESTAMPING_OPT_TSONLY, - // SOF_TIMESTAMPING_MASK = (SOF_TIMESTAMPING_LAST - 1) | SOF_TIMESTAMPING_LAST, }; TEST(AsyncSocketTest, ErrMessageCallback) { TestServer server; @@ -3039,3 +3029,167 @@ TEST(AsyncSocket, PreReceivedDataTakeover) { evb.loop(); } + +TEST(AsyncSocketTest, SendMessageFlags) { + TestServer server; + TestSendMsgParamsCallback sendMsgCB( + MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr); + + // connect() + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + std::shared_ptr acceptedSocket = server.accept(); + + evb.loop(); + ASSERT_EQ(ccb.state, STATE_SUCCEEDED); + + // Set SendMsgParamsCallback + socket->setSendMsgParamCB(&sendMsgCB); + ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB); + + // Write the first portion of data. This data is expected to be + // sent out immediately. + std::vector buf(128, 'a'); + WriteCallback wcb; + sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL); + socket->write(&wcb, buf.data(), buf.size()); + ASSERT_EQ(wcb.state, STATE_SUCCEEDED); + ASSERT_TRUE(sendMsgCB.queriedFlags_); + ASSERT_FALSE(sendMsgCB.queriedData_); + + // Using different flags for the second write operation. + // MSG_MORE flag is expected to delay sending this + // data to the wire. + sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE); + socket->write(&wcb, buf.data(), buf.size()); + ASSERT_EQ(wcb.state, STATE_SUCCEEDED); + ASSERT_TRUE(sendMsgCB.queriedFlags_); + ASSERT_FALSE(sendMsgCB.queriedData_); + + // Make sure the accepted socket saw only the data from + // the first write request. + std::vector readbuf(2 * buf.size()); + uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size()); + ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin())); + ASSERT_EQ(bytesRead, buf.size()); + + // Make sure the server got a connection and received the data + acceptedSocket->close(); + socket->close(); + + ASSERT_TRUE(socket->isClosedBySelf()); + ASSERT_FALSE(socket->isClosedByPeer()); +} + +TEST(AsyncSocketTest, SendMessageAncillaryData) { + struct sockaddr_un addr = {AF_UNIX, + "AsyncSocketTest.SendMessageAncillaryData\0"}; + + // Clean up the name in the name space we're going to use + ASSERT_FALSE(remove(addr.sun_path) == -1 && errno != ENOENT); + + // Set up listening socket + int lfd = fsp::socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_NE(lfd, -1); + ASSERT_NE(bind(lfd, (struct sockaddr*)&addr, sizeof(addr)), -1) + << "Bind failed: " << errno; + + // Create the connecting socket + int csd = fsp::socket(AF_UNIX, SOCK_STREAM, 0); + ASSERT_NE(csd, -1); + + // Listen for incoming connect + ASSERT_NE(listen(lfd, 5), -1); + + // Connect to the listening socket + ASSERT_NE(fsp::connect(csd, (struct sockaddr*)&addr, sizeof(addr)), -1) + << "Connect request failed: " << errno; + + // Accept the connection + int sfd = accept(lfd, nullptr, nullptr); + ASSERT_NE(sfd, -1); + + // Instantiate AsyncSocket object for the connected socket + EventBase evb; + std::shared_ptr socket = AsyncSocket::newSocket(&evb, csd); + + // Open a temporary file and write a magic string to it + // We'll transfer the file handle to test the message parameters + // callback logic. + int tmpfd = open("/var/tmp", O_RDWR | O_TMPFILE); + ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file"; + std::string magicString("Magic string"); + ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()), + magicString.length()); + + // Send message + union { + // Space large enough to hold an 'int' + char control[CMSG_SPACE(sizeof(int))]; + struct cmsghdr cmh; + } s_u; + s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int)); + s_u.cmh.cmsg_level = SOL_SOCKET; + s_u.cmh.cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int)); + + // Set up the callback providing message parameters + TestSendMsgParamsCallback sendMsgCB( + MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control); + socket->setSendMsgParamCB(&sendMsgCB); + + // We must transmit at least 1 byte of real data in order + // to send ancillary data + int s_data = 12345; + WriteCallback wcb; + socket->write(&wcb, &s_data, sizeof(s_data)); + ASSERT_EQ(wcb.state, STATE_SUCCEEDED); + + // Receive the message + union { + // Space large enough to hold an 'int' + char control[CMSG_SPACE(sizeof(int))]; + struct cmsghdr cmh; + } r_u; + struct msghdr msgh; + struct iovec iov; + int r_data = 0; + + msgh.msg_control = r_u.control; + msgh.msg_controllen = sizeof(r_u.control); + msgh.msg_name = nullptr; + msgh.msg_namelen = 0; + msgh.msg_iov = &iov; + msgh.msg_iovlen = 1; + iov.iov_base = &r_data; + iov.iov_len = sizeof(r_data); + + // Receive data + ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno; + + // Validate the received message + ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int))); + ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET); + ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS); + ASSERT_EQ(r_data, s_data); + int fd = 0; + memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int)); + ASSERT_NE(fd, 0); + + std::vector transferredMagicString(magicString.length() + 1, 0); + + // Reposition to the beginning of the file + ASSERT_EQ(0, lseek(fd, 0, SEEK_SET)); + + // Read the magic string back, and compare it with the original + ASSERT_EQ( + magicString.length(), + read(fd, transferredMagicString.data(), transferredMagicString.size())); + ASSERT_TRUE(std::equal( + magicString.begin(), + magicString.end(), + transferredMagicString.begin())); +} -- 2.34.1