return AsyncSocket::handleConnect();
}
- assert(state_ == StateEnum::ESTABLISHED &&
- sslState_ == STATE_CONNECTING);
+ assert(
+ (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
+ sslState_ == STATE_CONNECTING);
assert(ssl_);
int ret = SSL_connect(ssl_);
AsyncSocket::handleInitialReadWrite();
}
+void AsyncSSLSocket::invokeConnectSuccess() {
+ if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+ // If we failed TFO, we'd fall back to trying to connect the socket,
+ // when we succeed we should handle the writes that caused us to start
+ // TFO.
+ handleWrite();
+ }
+ AsyncSocket::invokeConnectSuccess();
+}
+
void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
// turn on the buffer movable in openssl
}
int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
- int ret;
struct msghdr msg;
struct iovec iov;
int flags = 0;
flags = MSG_EOR;
}
- ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
+ auto result =
+ tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags);
BIO_clear_retry_flags(b);
- if (ret <= 0) {
- if (BIO_sock_should_retry(ret))
+ if (!result.exception && result.writeReturn <= 0) {
+ if (BIO_sock_should_retry(result.writeReturn)) {
BIO_set_retry_write(b);
+ }
}
- return ret;
+ return result.writeReturn;
}
-int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
- X509_STORE_CTX* x509Ctx) {
+int AsyncSSLSocket::sslVerifyCallback(
+ int preverifyOk,
+ X509_STORE_CTX* x509Ctx) {
SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
void invokeHandshakeErr(const AsyncSocketException& ex);
void invokeHandshakeCB();
+ void invokeConnectSuccess() override;
+
void cacheLocalPeerAddr();
static void sslInfoCallback(const SSL *ssl, int type, int val);
return detail::tfo_sendmsg(fd, msg, msg_flags);
}
-AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
- struct msghdr* msg,
- int msg_flags) {
+AsyncSocket::WriteResult
+AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
ssize_t totalWritten = 0;
if (state_ == StateEnum::FAST_OPEN) {
sockaddr_storage addr;
return WriteResult(
WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
}
- // Let's fake it that no bytes were written.
- // Some clients check errno even if return code is 0, so we
- // set it just in case.
+ // Let's fake it that no bytes were written and return an errno.
errno = EAGAIN;
- totalWritten = 0;
+ totalWritten = -1;
} else if (errno == EOPNOTSUPP) {
VLOG(4) << "TFO not supported";
// Try falling back to connecting.
}
// If there was no exception during connections,
// we would return that no bytes were written.
- // Some clients check errno even if return code is 0, so we
- // set it just in case.
errno = EAGAIN;
- totalWritten = 0;
+ totalWritten = -1;
} catch (const AsyncSocketException& ex) {
return WriteResult(
WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
AsyncSocketException::UNKNOWN, "No more free local ports"));
}
} else {
- totalWritten = ::sendmsg(fd_, msg, msg_flags);
+ totalWritten = ::sendmsg(fd, msg, msg_flags);
}
return WriteResult(totalWritten);
}
// marks that this is the last byte of a record (response)
msg_flags |= MSG_EOR;
}
- auto writeResult = sendSocketMessage(&msg, msg_flags);
+ auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
auto totalWritten = writeResult.writeReturn;
if (totalWritten < 0) {
if (!writeResult.exception && errno == EAGAIN) {
* @param msg Message to send
* @param msg_flags Flags to pass to sendmsg
*/
- AsyncSocket::WriteResult sendSocketMessage(struct msghdr* msg, int msg_flags);
+ AsyncSocket::WriteResult
+ sendSocketMessage(int fd, struct msghdr* msg, int msg_flags);
virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
void failWrite(const char* fn, const AsyncSocketException& ex);
void failAllWrites(const AsyncSocketException& ex);
void invokeConnectErr(const AsyncSocketException& ex);
- void invokeConnectSuccess();
+ virtual void invokeConnectSuccess();
void invalidState(ConnectCallback* callback);
void invalidState(ReadCallback* callback);
void invalidState(WriteCallback* callback);
*/
#include <folly/io/async/test/AsyncSSLSocketTest.h>
-#include <signal.h>
#include <pthread.h>
+#include <signal.h>
+#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/io/async/test/BlockingSocket.h>
-#include <fstream>
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
#include <gtest/gtest.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
+#include <fstream>
#include <iostream>
#include <list>
#include <set>
-#include <fcntl.h>
-#include <openssl/bio.h>
-#include <sys/types.h>
-#include <folly/io/Cursor.h>
+
+#include <gmock/gmock.h>
using std::string;
using std::vector;
using std::endl;
using std::list;
+using namespace testing;
+
namespace folly {
uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
: ctx_(new folly::SSLContext),
acb_(acb),
socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
acb_->ctx_ = ctx_;
acb_->base_ = &evb_;
- //set up the listening socket
+ // Enable TFO
+ if (enableTFO) {
+ LOG(INFO) << "server TFO enabled";
+ socket_->setTFOEnabled(true, 1000);
+ }
+
+ // set up the listening socket
socket_->bind(0);
socket_->getAddress(&address_);
socket_->listen(100);
std::string::npos);
}
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+ using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+ explicit MockAsyncTFOSSLSocket(
+ std::shared_ptr<folly::SSLContext> sslCtx,
+ EventBase* evb)
+ : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+ MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, false);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+ virtual void connectSuccess() noexcept override {
+ state = State::SUCCESS;
+ }
+
+ virtual void connectErr(const AsyncSocketException&) noexcept override {
+ state = State::ERROR;
+ }
+
+ enum class State { WAITING, SUCCESS, ERROR };
+
+ State state{State::WAITING};
+};
+
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+ EventBase* evb,
+ const SocketAddress& address) {
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+ new MockAsyncTFOSSLSocket(sslContext, evb));
+ socket->enableTFO();
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+ sockaddr_storage addr;
+ auto len = address.getAddress(&addr);
+ return connect(fd, (const struct sockaddr*)&addr, len);
+ }));
+ return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress());
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+ evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+ evb.loop();
+
+ BlockingSocket sock(std::move(socket));
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ sock.write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ sock.close();
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ EXPECT_THROW(
+ socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress());
+ ConnCallback ccb;
+ // Set a short timeout
+ socket->connect(&ccb, server.getAddress(), 1);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+
+#endif
+
} // namespace
///////////////////////////////////////////////////////////////////////////
public:
// Create a TestSSLServer.
// This immediately starts listening on the given port.
- explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+ explicit TestSSLServer(
+ SSLServerAcceptCallbackBase* acb,
+ bool enableTFO = false);
// Kill the thread.
~TestSSLServer() {
#pragma once
#include <folly/Optional.h>
-#include <folly/io/async/SSLContext.h>
-#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/SSLContext.h>
class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
public folly::AsyncTransportWrapper::ReadCallback,
- public folly::AsyncTransportWrapper::WriteCallback
-{
+ public folly::AsyncTransportWrapper::WriteCallback {
public:
explicit BlockingSocket(int fd)
- : sock_(new folly::AsyncSocket(&eventBase_, fd)) {
- }
+ : sock_(new folly::AsyncSocket(&eventBase_, fd)) {}
- BlockingSocket(folly::SocketAddress address,
- std::shared_ptr<folly::SSLContext> sslContext)
- : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
- new folly::AsyncSocket(&eventBase_)),
- address_(address) {}
+ BlockingSocket(
+ folly::SocketAddress address,
+ std::shared_ptr<folly::SSLContext> sslContext)
+ : sock_(
+ sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_)
+ : new folly::AsyncSocket(&eventBase_)),
+ address_(address) {}
explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
: sock_(std::move(socket)) {
sock_->attachEventBase(&eventBase_);
}
+ void enableTFO() {
+ sock_->enableTFO();
+ }
+
void setAddress(folly::SocketAddress address) {
address_ = address;
}
- void open() {
- sock_->connect(this, address_);
+ void open(
+ std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) {
+ sock_->connect(this, address_, timeout.count());
eventBase_.loop();
if (err_.hasValue()) {
throw err_.value();
void close() {
sock_->close();
}
- void closeWithReset() { sock_->closeWithReset(); }
+ void closeWithReset() {
+ sock_->closeWithReset();
+ }
int32_t write(uint8_t const* buf, size_t len) {
sock_->write(this, buf, len);
void flush() {}
- int32_t readAll(uint8_t *buf, size_t len) {
+ int32_t readAll(uint8_t* buf, size_t len) {
return readHelper(buf, len, true);
}
- int32_t read(uint8_t *buf, size_t len) {
+ int32_t read(uint8_t* buf, size_t len) {
return readHelper(buf, len, false);
}
folly::EventBase eventBase_;
folly::AsyncSocket::UniquePtr sock_;
folly::Optional<folly::AsyncSocketException> err_;
- uint8_t *readBuf_{nullptr};
+ uint8_t* readBuf_{nullptr};
size_t readLen_{0};
folly::SocketAddress address_;
sock_->setReadCB(nullptr);
}
}
- void readEOF() noexcept override {
- }
+ void readEOF() noexcept override {}
void readErr(const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
void writeSuccess() noexcept override {}
- void writeErr(size_t /* bytesWritten */,
- const folly::AsyncSocketException& ex) noexcept override {
+ void writeErr(
+ size_t /* bytesWritten */,
+ const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
- int32_t readHelper(uint8_t *buf, size_t len, bool all) {
+ int32_t readHelper(uint8_t* buf, size_t len, bool all) {
if (!sock_->good()) {
return 0;
}
throw err_.value();
}
if (all && readLen_ > 0) {
- throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
- "eof");
+ throw folly::AsyncSocketException(
+ folly::AsyncSocketException::UNKNOWN, "eof");
}
return len - readLen_;
}
DEFINE_int32(port, 0, "port");
DEFINE_bool(tfo, false, "enable tfo");
DEFINE_string(msg, "", "Message to send");
+DEFINE_bool(ssl, false, "use ssl");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
// Prep the socket
EventBase evb;
- AsyncSocket::UniquePtr socket(new AsyncSocket(&evb));
+ AsyncSocket::UniquePtr socket;
+ if (FLAGS_ssl) {
+ auto sslContext = std::make_shared<SSLContext>();
+ socket = AsyncSocket::UniquePtr(new AsyncSSLSocket(sslContext, &evb));
+ } else {
+ socket = AsyncSocket::UniquePtr(new AsyncSocket(&evb));
+ }
socket->detachEventBase();
if (FLAGS_tfo) {