From 889fe563fac92a64fd1a850bb09e9f6b7c431ac2 Mon Sep 17 00:00:00 2001 From: Subodh Iyengar Date: Tue, 7 Jun 2016 07:43:37 -0700 Subject: [PATCH] Add TFO support to AsyncSSLSocket Summary: This adds TFO support to AsyncSSLSocket which uses the support for TFO from AsyncSocket. Because of the way AsyncSSLSocket inherits from AsyncSocket it is tricky. The following changes were made: 1. Openssl internally will treat only errors with return code -1 as READ_REQUIRED or WRITE_REQUIRED errors. So this diff changes the return value of the errors in the TFO fallback cases to -1. 2. In case we fallback after SSL_connect() to a normal connect, we would have to restart the connection process after connect succeeds. To do this this overrides the connection success callback and restarts the connection before sending the callback to AsyncSocket because sometimes callbacks might synchronously call sslConn() in the normal connect cases. 3. Delegated bioWrite to call sendSocketMessage instead of sendmsg directly. Reviewed By: djwatson Differential Revision: D3391735 fbshipit-source-id: 61434f6de4a9c3d03973c9ab9e51eb49e751e5cf --- folly/io/async/AsyncSSLSocket.cpp | 31 ++- folly/io/async/AsyncSSLSocket.h | 2 + folly/io/async/AsyncSocket.cpp | 19 +- folly/io/async/AsyncSocket.h | 5 +- folly/io/async/test/AsyncSSLSocketTest.cpp | 225 ++++++++++++++++++++- folly/io/async/test/AsyncSSLSocketTest.h | 4 +- folly/io/async/test/BlockingSocket.h | 55 ++--- folly/io/async/test/SocketClient.cpp | 9 +- 8 files changed, 292 insertions(+), 58 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 4fa6b6fe..f91416ed 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -1084,8 +1084,9 @@ AsyncSSLSocket::handleConnect() noexcept { return AsyncSocket::handleConnect(); } - assert(state_ == StateEnum::ESTABLISHED && - sslState_ == STATE_CONNECTING); + assert( + (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) && + sslState_ == STATE_CONNECTING); assert(ssl_); int ret = SSL_connect(ssl_); @@ -1138,6 +1139,16 @@ AsyncSSLSocket::handleConnect() noexcept { AsyncSocket::handleInitialReadWrite(); } +void AsyncSSLSocket::invokeConnectSuccess() { + if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + // If we failed TFO, we'd fall back to trying to connect the socket, + // when we succeed we should handle the writes that caused us to start + // TFO. + handleWrite(); + } + AsyncSocket::invokeConnectSuccess(); +} + void AsyncSSLSocket::setReadCB(ReadCallback *callback) { #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP // turn on the buffer movable in openssl @@ -1498,7 +1509,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { } int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { - int ret; struct msghdr msg; struct iovec iov; int flags = 0; @@ -1521,17 +1531,20 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { flags = MSG_EOR; } - ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags); + auto result = + tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags); BIO_clear_retry_flags(b); - if (ret <= 0) { - if (BIO_sock_should_retry(ret)) + if (!result.exception && result.writeReturn <= 0) { + if (BIO_sock_should_retry(result.writeReturn)) { BIO_set_retry_write(b); + } } - return ret; + return result.writeReturn; } -int AsyncSSLSocket::sslVerifyCallback(int preverifyOk, - X509_STORE_CTX* x509Ctx) { +int AsyncSSLSocket::sslVerifyCallback( + int preverifyOk, + X509_STORE_CTX* x509Ctx) { SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data( x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl); diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 40ceb87a..bd4f76d1 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -798,6 +798,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { void invokeHandshakeErr(const AsyncSocketException& ex); void invokeHandshakeCB(); + void invokeConnectSuccess() override; + void cacheLocalPeerAddr(); static void sslInfoCallback(const SSL *ssl, int type, int val); diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index cffe1f82..6bb9126e 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -1752,9 +1752,8 @@ ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) { return detail::tfo_sendmsg(fd, msg, msg_flags); } -AsyncSocket::WriteResult AsyncSocket::sendSocketMessage( - struct msghdr* msg, - int msg_flags) { +AsyncSocket::WriteResult +AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { ssize_t totalWritten = 0; if (state_ == StateEnum::FAST_OPEN) { sockaddr_storage addr; @@ -1778,11 +1777,9 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage( return WriteResult( WRITE_ERROR, folly::make_unique(ex)); } - // Let's fake it that no bytes were written. - // Some clients check errno even if return code is 0, so we - // set it just in case. + // Let's fake it that no bytes were written and return an errno. errno = EAGAIN; - totalWritten = 0; + totalWritten = -1; } else if (errno == EOPNOTSUPP) { VLOG(4) << "TFO not supported"; // Try falling back to connecting. @@ -1797,10 +1794,8 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage( } // If there was no exception during connections, // we would return that no bytes were written. - // Some clients check errno even if return code is 0, so we - // set it just in case. errno = EAGAIN; - totalWritten = 0; + totalWritten = -1; } catch (const AsyncSocketException& ex) { return WriteResult( WRITE_ERROR, folly::make_unique(ex)); @@ -1816,7 +1811,7 @@ AsyncSocket::WriteResult AsyncSocket::sendSocketMessage( AsyncSocketException::UNKNOWN, "No more free local ports")); } } else { - totalWritten = ::sendmsg(fd_, msg, msg_flags); + totalWritten = ::sendmsg(fd, msg, msg_flags); } return WriteResult(totalWritten); } @@ -1855,7 +1850,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite( // marks that this is the last byte of a record (response) msg_flags |= MSG_EOR; } - auto writeResult = sendSocketMessage(&msg, msg_flags); + auto writeResult = sendSocketMessage(fd_, &msg, msg_flags); auto totalWritten = writeResult.writeReturn; if (totalWritten < 0) { if (!writeResult.exception && errno == EAGAIN) { diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index f3e605d2..36949725 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -817,7 +817,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { * @param msg Message to send * @param msg_flags Flags to pass to sendmsg */ - AsyncSocket::WriteResult sendSocketMessage(struct msghdr* msg, int msg_flags); + AsyncSocket::WriteResult + sendSocketMessage(int fd, struct msghdr* msg, int msg_flags); virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags); @@ -855,7 +856,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { void failWrite(const char* fn, const AsyncSocketException& ex); void failAllWrites(const AsyncSocketException& ex); void invokeConnectErr(const AsyncSocketException& ex); - void invokeConnectSuccess(); + virtual void invokeConnectSuccess(); void invalidState(ConnectCallback* callback); void invalidState(ReadCallback* callback); void invalidState(WriteCallback* callback); diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index d4078e56..fd2b2521 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -15,26 +15,28 @@ */ #include -#include #include +#include +#include #include #include -#include #include #include #include -#include +#include +#include #include +#include +#include +#include #include #include #include -#include -#include -#include -#include + +#include using std::string; using std::vector; @@ -43,6 +45,8 @@ using std::cerr; using std::endl; using std::list; +using namespace testing; + namespace folly { uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0; uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0; @@ -55,7 +59,7 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem"; constexpr size_t SSLClient::kMaxReadBufferSz; constexpr size_t SSLClient::kMaxReadsPerEvent; -TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb) +TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO) : ctx_(new folly::SSLContext), acb_(acb), socket_(folly::AsyncServerSocket::newSocket(&evb_)) { @@ -67,7 +71,13 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb) acb_->ctx_ = ctx_; acb_->base_ = &evb_; - //set up the listening socket + // Enable TFO + if (enableTFO) { + LOG(INFO) << "server TFO enabled"; + socket_->setTFOEnabled(true, 1000); + } + + // set up the listening socket socket_->bind(0); socket_->getAddress(&address_); socket_->listen(100); @@ -1674,6 +1684,203 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) { std::string::npos); } +#if FOLLY_ALLOW_TFO + +class MockAsyncTFOSSLSocket : public AsyncSSLSocket { + public: + using UniquePtr = std::unique_ptr; + + explicit MockAsyncTFOSSLSocket( + std::shared_ptr sslCtx, + EventBase* evb) + : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {} + + MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); +}; + +/** + * Test connecting to, writing to, reading from, and closing the + * connection to the SSL server with TFO. + */ +TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = + std::make_shared(server.getAddress(), sslContext); + socket->enableTFO(); + socket->open(); + + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + socket->write(buf.data(), buf.size()); + + // read() + std::array readbuf; + uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); + + // close() + socket->close(); +} + +/** + * Test connecting to, writing to, reading from, and closing the + * connection to the SSL server with TFO. + */ +TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, false); + + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = + std::make_shared(server.getAddress(), sslContext); + socket->enableTFO(); + socket->open(); + + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + socket->write(buf.data(), buf.size()); + + // read() + std::array readbuf; + uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); + + // close() + socket->close(); +} + +class ConnCallback : public AsyncSocket::ConnectCallback { + public: + virtual void connectSuccess() noexcept override { + state = State::SUCCESS; + } + + virtual void connectErr(const AsyncSocketException&) noexcept override { + state = State::ERROR; + } + + enum class State { WAITING, SUCCESS, ERROR }; + + State state{State::WAITING}; +}; + +MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback( + EventBase* evb, + const SocketAddress& address) { + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = MockAsyncTFOSSLSocket::UniquePtr( + new MockAsyncTFOSSLSocket(sslContext, evb)); + socket->enableTFO(); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .WillOnce(Invoke([&](int fd, struct msghdr*, int) { + sockaddr_storage addr; + auto len = address.getAddress(&addr); + return connect(fd, (const struct sockaddr*)&addr, len); + })); + return socket; +} + +TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + EventBase evb; + + auto socket = setupSocketWithFallback(&evb, server.getAddress()); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + evb.loop(); + EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state); + + evb.runInEventBaseThread([&] { socket->detachEventBase(); }); + evb.loop(); + + BlockingSocket sock(std::move(socket)); + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + sock.write(buf.data(), buf.size()); + + // read() + std::array readbuf; + uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); + + // close() + sock.close(); +} + +TEST(AsyncSSLSocketTest, ConnectTFOTimeout) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = + std::make_shared(server.getAddress(), sslContext); + socket->enableTFO(); + EXPECT_THROW( + socket->open(std::chrono::milliseconds(1)), AsyncSocketException); +} + +TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + EventBase evb; + + auto socket = setupSocketWithFallback(&evb, server.getAddress()); + ConnCallback ccb; + // Set a short timeout + socket->connect(&ccb, server.getAddress(), 1); + + evb.loop(); + EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); +} + +#endif + } // namespace /////////////////////////////////////////////////////////////////////////// diff --git a/folly/io/async/test/AsyncSSLSocketTest.h b/folly/io/async/test/AsyncSSLSocketTest.h index 43ada249..e4b51249 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.h +++ b/folly/io/async/test/AsyncSSLSocketTest.h @@ -607,7 +607,9 @@ class TestSSLServer { public: // Create a TestSSLServer. // This immediately starts listening on the given port. - explicit TestSSLServer(SSLServerAcceptCallbackBase *acb); + explicit TestSSLServer( + SSLServerAcceptCallbackBase* acb, + bool enableTFO = false); // Kill the thread. ~TestSSLServer() { diff --git a/folly/io/async/test/BlockingSocket.h b/folly/io/async/test/BlockingSocket.h index 3830648e..b3713fc2 100644 --- a/folly/io/async/test/BlockingSocket.h +++ b/folly/io/async/test/BlockingSocket.h @@ -16,36 +16,41 @@ #pragma once #include -#include -#include #include +#include +#include class BlockingSocket : public folly::AsyncSocket::ConnectCallback, public folly::AsyncTransportWrapper::ReadCallback, - public folly::AsyncTransportWrapper::WriteCallback -{ + public folly::AsyncTransportWrapper::WriteCallback { public: explicit BlockingSocket(int fd) - : sock_(new folly::AsyncSocket(&eventBase_, fd)) { - } + : sock_(new folly::AsyncSocket(&eventBase_, fd)) {} - BlockingSocket(folly::SocketAddress address, - std::shared_ptr sslContext) - : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) : - new folly::AsyncSocket(&eventBase_)), - address_(address) {} + BlockingSocket( + folly::SocketAddress address, + std::shared_ptr sslContext) + : sock_( + sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) + : new folly::AsyncSocket(&eventBase_)), + address_(address) {} explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket) : sock_(std::move(socket)) { sock_->attachEventBase(&eventBase_); } + void enableTFO() { + sock_->enableTFO(); + } + void setAddress(folly::SocketAddress address) { address_ = address; } - void open() { - sock_->connect(this, address_); + void open( + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) { + sock_->connect(this, address_, timeout.count()); eventBase_.loop(); if (err_.hasValue()) { throw err_.value(); @@ -54,7 +59,9 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, void close() { sock_->close(); } - void closeWithReset() { sock_->closeWithReset(); } + void closeWithReset() { + sock_->closeWithReset(); + } int32_t write(uint8_t const* buf, size_t len) { sock_->write(this, buf, len); @@ -67,11 +74,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, void flush() {} - int32_t readAll(uint8_t *buf, size_t len) { + int32_t readAll(uint8_t* buf, size_t len) { return readHelper(buf, len, true); } - int32_t read(uint8_t *buf, size_t len) { + int32_t read(uint8_t* buf, size_t len) { return readHelper(buf, len, false); } @@ -83,7 +90,7 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, folly::EventBase eventBase_; folly::AsyncSocket::UniquePtr sock_; folly::Optional err_; - uint8_t *readBuf_{nullptr}; + uint8_t* readBuf_{nullptr}; size_t readLen_{0}; folly::SocketAddress address_; @@ -102,18 +109,18 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, sock_->setReadCB(nullptr); } } - void readEOF() noexcept override { - } + void readEOF() noexcept override {} void readErr(const folly::AsyncSocketException& ex) noexcept override { err_ = ex; } void writeSuccess() noexcept override {} - void writeErr(size_t /* bytesWritten */, - const folly::AsyncSocketException& ex) noexcept override { + void writeErr( + size_t /* bytesWritten */, + const folly::AsyncSocketException& ex) noexcept override { err_ = ex; } - int32_t readHelper(uint8_t *buf, size_t len, bool all) { + int32_t readHelper(uint8_t* buf, size_t len, bool all) { if (!sock_->good()) { return 0; } @@ -132,8 +139,8 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, throw err_.value(); } if (all && readLen_ > 0) { - throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN, - "eof"); + throw folly::AsyncSocketException( + folly::AsyncSocketException::UNKNOWN, "eof"); } return len - readLen_; } diff --git a/folly/io/async/test/SocketClient.cpp b/folly/io/async/test/SocketClient.cpp index 7f20d480..23bef722 100644 --- a/folly/io/async/test/SocketClient.cpp +++ b/folly/io/async/test/SocketClient.cpp @@ -24,6 +24,7 @@ DEFINE_string(host, "localhost", "Host"); DEFINE_int32(port, 0, "port"); DEFINE_bool(tfo, false, "enable tfo"); DEFINE_string(msg, "", "Message to send"); +DEFINE_bool(ssl, false, "use ssl"); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -35,7 +36,13 @@ int main(int argc, char** argv) { // Prep the socket EventBase evb; - AsyncSocket::UniquePtr socket(new AsyncSocket(&evb)); + AsyncSocket::UniquePtr socket; + if (FLAGS_ssl) { + auto sslContext = std::make_shared(); + socket = AsyncSocket::UniquePtr(new AsyncSSLSocket(sslContext, &evb)); + } else { + socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb)); + } socket->detachEventBase(); if (FLAGS_tfo) { -- 2.34.1