/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <folly/io/async/test/AsyncSSLSocketTest.h>
#include <signal.h>
-#include <pthread.h>
+#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
+#include <folly/portability/GMock.h>
+#include <folly/portability/GTest.h>
+#include <folly/portability/OpenSSL.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
#include <folly/io/async/test/BlockingSocket.h>
-#include <gtest/gtest.h>
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
+#include <fstream>
#include <iostream>
#include <list>
#include <set>
-#include <unistd.h>
-#include <fcntl.h>
-#include <poll.h>
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
-#include <folly/io/Cursor.h>
+#include <thread>
using std::string;
using std::vector;
using std::endl;
using std::list;
+using namespace testing;
+
namespace folly {
uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
-const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
-const char* testKey = "folly/io/async/test/certs/tests-key.pem";
-const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
-
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
-ctx_(new folly::SSLContext),
- acb_(acb),
- socket_(new folly::AsyncServerSocket(&evb_)) {
- // Set up the SSL context
- ctx_->loadCertificate(testCert);
- ctx_->loadPrivateKey(testKey);
- ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-
- acb_->ctx_ = ctx_;
- acb_->base_ = &evb_;
-
- //set up the listening socket
- socket_->bind(0);
- socket_->getAddress(&address_);
- socket_->listen(100);
- socket_->addAcceptCallback(acb_, &evb_);
- socket_->startAccepting();
-
- int ret = pthread_create(&thread_, nullptr, Main, this);
- assert(ret == 0);
-
- std::cerr << "Accepting connections on " << address_ << std::endl;
-}
-
void getfds(int fds[2]) {
if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
FAIL() << "failed to create socketpair: " << strerror(errno);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadCertificate(
- testCert);
- serverCtx->loadPrivateKey(
- testKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadPrivateKey(kTestKey);
}
void sslsocketpair(
// (*serverSock)->setSendTimeout(100);
}
+// client protocol filters
+bool clientProtoFilterPickPony(unsigned char** client,
+ unsigned int* client_len, const unsigned char*, unsigned int ) {
+ //the protocol string in length prefixed byte string. the
+ //length byte is not included in the length
+ static unsigned char p[7] = {6,'p','o','n','i','e','s'};
+ *client = p;
+ *client_len = 7;
+ return true;
+}
+
+bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
+ const unsigned char*, unsigned int) {
+ return false;
+}
+
+std::string getFileAsBuf(const char* fileName) {
+ std::string buffer;
+ folly::readFile(fileName, buffer);
+ return buffer;
+}
+
+std::string getCommonName(X509* cert) {
+ X509_NAME* subject = X509_get_subject_name(cert);
+ std::string cn;
+ cn.resize(ub_common_name);
+ X509_NAME_get_text_by_NID(
+ subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
+ return cn;
+}
/**
* Test connecting to, writing to, reading from, and closing the
cerr << "ConnectWriteReadClose test completed" << endl;
}
+/**
+ * Test reading after server close.
+ */
+TEST(AsyncSSLSocketTest, ReadAfterClose) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadEOFCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ auto socket =
+ std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
+ socket->open();
+
+ // This should trigger an EOF on the client.
+ auto evb = handshakeCallback.getSocket()->getEventBase();
+ evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
+ std::array<uint8_t, 128> readbuf;
+ auto bytesRead = socket->read(readbuf.data(), readbuf.size());
+ EXPECT_EQ(0, bytesRead);
+}
+
+/**
+ * Test bad renegotiation
+ */
+#if !defined(OPENSSL_IS_BORINGSSL)
+TEST(AsyncSSLSocketTest, Renegotiate) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ SSLHandshakeClient client(std::move(clientSock), true, true);
+ RenegotiatingServer server(std::move(serverSock));
+
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ eventBase.loopOnce();
+ }
+
+ ASSERT_TRUE(client.handshakeSuccess_);
+
+ auto sslSock = std::move(client).moveSocket();
+ sslSock->detachEventBase();
+ // This is nasty, however we don't want to add support for
+ // renegotiation in AsyncSSLSocket.
+ SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
+
+ auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
+
+ std::thread t([&]() { eventBase.loopForever(); });
+
+ // Trigger the renegotiation.
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ try {
+ socket->write(buf.data(), buf.size());
+ } catch (AsyncSocketException& e) {
+ LOG(INFO) << "client got error " << e.what();
+ }
+ eventBase.terminateLoopSoon();
+ t.join();
+
+ eventBase.loop();
+ ASSERT_TRUE(server.renegotiationError_);
+}
+#endif
+
/**
* Negative test for handshakeError().
*/
TEST(AsyncSSLSocketTest, HandshakeError) {
// Start listening on a local port
WriteCallbackBase writeCallback;
- ReadCallback readCallback(&writeCallback);
+ WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
HandshakeErrorCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
uint8_t readbuf[128];
uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
- } catch (AsyncSocketException &e) {
+ LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
+ } catch (AsyncSocketException&) {
ex = true;
}
EXPECT_TRUE(ex);
cerr << "SocketWithDelay test completed" << endl;
}
-TEST(AsyncSSLSocketTest, NpnTestOverlap) {
+using NextProtocolTypePair =
+ std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
+
+class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
+ // For matching protos
+ public:
+ void SetUp() override { getctx(clientCtx, serverCtx); }
+
+ void connect(bool unset = false) {
+ getfds(fds);
+
+ if (unset) {
+ // unsetting NPN for any of [client, server] is enough to make NPN not
+ // work
+ clientCtx->unsetNextProtocols();
+ }
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ client = folly::make_unique<NpnClient>(std::move(clientSock));
+ server = folly::make_unique<NpnServer>(std::move(serverSock));
+
+ eventBase.loop();
+ }
+
+ void expectProtocol(const std::string& proto) {
+ EXPECT_NE(client->nextProtoLength, 0);
+ EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
+ EXPECT_EQ(
+ memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
+ 0);
+ string selected((const char*)client->nextProto, client->nextProtoLength);
+ EXPECT_EQ(proto, selected);
+ }
+
+ void expectNoProtocol() {
+ EXPECT_EQ(client->nextProtoLength, 0);
+ EXPECT_EQ(server->nextProtoLength, 0);
+ EXPECT_EQ(client->nextProto, nullptr);
+ EXPECT_EQ(server->nextProto, nullptr);
+ }
+
+ void expectProtocolType() {
+ if (GetParam().first == SSLContext::NextProtocolType::ANY &&
+ GetParam().second == SSLContext::NextProtocolType::ANY) {
+ EXPECT_EQ(client->protocolType, server->protocolType);
+ } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
+ GetParam().second == SSLContext::NextProtocolType::ANY) {
+ // Well not much we can say
+ } else {
+ expectProtocolType(GetParam());
+ }
+ }
+
+ void expectProtocolType(NextProtocolTypePair expected) {
+ EXPECT_EQ(client->protocolType, expected.first);
+ EXPECT_EQ(server->protocolType, expected.second);
+ }
+
EventBase eventBase;
- std::shared_ptr<SSLContext> clientCtx(new SSLContext);
- std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+ std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
+ std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
int fds[2];
- getfds(fds);
- getctx(clientCtx, serverCtx);
+ std::unique_ptr<NpnClient> client;
+ std::unique_ptr<NpnServer> server;
+};
- clientCtx->setAdvertisedNextProtocols({"blub","baz"});
- serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+class NextProtocolTLSExtTest : public NextProtocolTest {
+ // For extended TLS protos
+};
- AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
- AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
- NpnClient client(std::move(clientSock));
- NpnServer server(std::move(serverSock));
+class NextProtocolNPNOnlyTest : public NextProtocolTest {
+ // For mismatching protos
+};
- eventBase.loop();
+class NextProtocolMismatchTest : public NextProtocolTest {
+ // For mismatching protos
+};
+
+TEST_P(NextProtocolTest, NpnTestOverlap) {
+ clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
+ serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().second);
+
+ connect();
- EXPECT_TRUE(client.nextProtoLength != 0);
- EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
- EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
- server.nextProtoLength), 0);
- string selected((const char*)client.nextProto, client.nextProtoLength);
- EXPECT_EQ(selected.compare("baz"), 0);
+ expectProtocol("baz");
+ expectProtocolType();
}
-TEST(AsyncSSLSocketTest, NpnTestUnset) {
+TEST_P(NextProtocolTest, NpnTestUnset) {
// Identical to above test, except that we want unset NPN before
// looping.
- EventBase eventBase;
- std::shared_ptr<SSLContext> clientCtx(new SSLContext);
- std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
- int fds[2];
- getfds(fds);
- getctx(clientCtx, serverCtx);
+ clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
+ serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().second);
- clientCtx->setAdvertisedNextProtocols({"blub","baz"});
- serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+ connect(true /* unset */);
- AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
- AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ // if alpn negotiation fails, type will appear as npn
+ expectNoProtocol();
+ EXPECT_EQ(client->protocolType, server->protocolType);
+}
- // unsetting NPN for any of [client, server] is enought to make NPN not
- // work
- clientCtx->unsetNextProtocols();
+TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
+ clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
+ serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().second);
- NpnClient client(std::move(clientSock));
- NpnServer server(std::move(serverSock));
+ connect();
- eventBase.loop();
+ expectNoProtocol();
+ expectProtocolType(
+ {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
+}
- EXPECT_TRUE(client.nextProtoLength == 0);
- EXPECT_TRUE(server.nextProtoLength == 0);
- EXPECT_TRUE(client.nextProto == nullptr);
- EXPECT_TRUE(server.nextProto == nullptr);
+// Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
+// will fail on 1.0.2 before that.
+TEST_P(NextProtocolTest, NpnTestNoOverlap) {
+ clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+ serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().second);
+
+ connect();
+
+ if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
+ GetParam().second == SSLContext::NextProtocolType::ALPN) {
+ // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
+ // mismatch should result in a fatal alert, but this is OpenSSL's current
+ // behavior and we want to know if it changes.
+ expectNoProtocol();
+ }
+#if defined(OPENSSL_IS_BORINGSSL)
+ // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
+ // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
+ // NPN *and* ALPN
+ else if (
+ GetParam().first == SSLContext::NextProtocolType::ANY &&
+ GetParam().second == SSLContext::NextProtocolType::ANY) {
+ expectNoProtocol();
+ }
+#endif
+ else {
+ expectProtocol("blub");
+ expectProtocolType(
+ {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
+ }
}
-TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
- EventBase eventBase;
- std::shared_ptr<SSLContext> clientCtx(new SSLContext);
- std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
- int fds[2];
- getfds(fds);
- getctx(clientCtx, serverCtx);
+TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
+ clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+ clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
+ serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().second);
- clientCtx->setAdvertisedNextProtocols({"blub"});
- serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+ connect();
- AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
- AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
- NpnClient client(std::move(clientSock));
- NpnServer server(std::move(serverSock));
+ expectProtocol("ponies");
+ expectProtocolType();
+}
- eventBase.loop();
+TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
+ clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+ clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
+ serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().second);
- EXPECT_TRUE(client.nextProtoLength != 0);
- EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
- EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
- server.nextProtoLength), 0);
- string selected((const char*)client.nextProto, client.nextProtoLength);
- EXPECT_EQ(selected.compare("blub"), 0);
+ connect();
+
+ expectProtocol("blub");
+ expectProtocolType();
}
-TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
+TEST_P(NextProtocolTest, RandomizedNpnTest) {
// Probability that this test will fail is 2^-64, which could be considered
// as negligible.
const int kTries = 64;
+ clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+ GetParam().first);
+ serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
+ GetParam().second);
+
std::set<string> selectedProtocols;
for (int i = 0; i < kTries; ++i) {
- EventBase eventBase;
- std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
- std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
- int fds[2];
- getfds(fds);
- getctx(clientCtx, serverCtx);
-
- clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
- serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
- {1, {"bar"}}});
-
-
- AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
- AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
- NpnClient client(std::move(clientSock));
- NpnServer server(std::move(serverSock));
-
- eventBase.loop();
-
- EXPECT_TRUE(client.nextProtoLength != 0);
- EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
- EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
- server.nextProtoLength), 0);
- string selected((const char*)client.nextProto, client.nextProtoLength);
+ connect();
+
+ EXPECT_NE(client->nextProtoLength, 0);
+ EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
+ EXPECT_EQ(
+ memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
+ 0);
+ string selected((const char*)client->nextProto, client->nextProtoLength);
selectedProtocols.insert(selected);
+ expectProtocolType();
}
EXPECT_EQ(selectedProtocols.size(), 2);
}
+INSTANTIATE_TEST_CASE_P(
+ AsyncSSLSocketTest,
+ NextProtocolTest,
+ ::testing::Values(
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::NPN),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::ANY),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::ANY,
+ SSLContext::NextProtocolType::ANY)));
+
+#if FOLLY_OPENSSL_HAS_ALPN
+INSTANTIATE_TEST_CASE_P(
+ AsyncSSLSocketTest,
+ NextProtocolTLSExtTest,
+ ::testing::Values(
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::ALPN,
+ SSLContext::NextProtocolType::ALPN),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::ALPN,
+ SSLContext::NextProtocolType::ANY),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::ANY,
+ SSLContext::NextProtocolType::ALPN)));
+#endif
+
+INSTANTIATE_TEST_CASE_P(
+ AsyncSSLSocketTest,
+ NextProtocolNPNOnlyTest,
+ ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::NPN)));
+
+#if FOLLY_OPENSSL_HAS_ALPN
+INSTANTIATE_TEST_CASE_P(
+ AsyncSSLSocketTest,
+ NextProtocolMismatchTest,
+ ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::ALPN),
+ NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
+ SSLContext::NextProtocolType::NPN)));
+#endif
#ifndef OPENSSL_NO_TLSEXT
/**
// Set up SSL client
EventBase eventBase;
- std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
- 1));
+ auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
client->connect();
EventBaseAborter eba(&eventBase, 3000);
// Set up SSL client
EventBase eventBase;
- std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
- 10));
+ auto client =
+ std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
client->connect();
EventBaseAborter eba(&eventBase, 3000);
// Set up SSL client
EventBase eventBase;
- std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
- 1, 10));
+ auto client =
+ std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
client->connect(true /* write before connect completes */);
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
cerr << "SSLClientTimeoutTest test completed" << endl;
}
-
+// The next 3 tests need an FB-only extension, and will fail without it
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
/**
* Test SSL server async cache
*/
// Set up SSL client
EventBase eventBase;
- std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
- 10, 500));
+ auto client =
+ std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
client->connect();
EventBaseAborter eba(&eventBase, 3000);
cerr << "SSLServerAsyncCacheTest test completed" << endl;
}
-
/**
* Test SSL server accept timeout with cache path
*/
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
- EmptyReadCallback clientReadCallback;
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
TestSSLAsyncCacheServer server(&acceptCallback);
// only do a TCP connect
std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
sock->connect(nullptr, server.getAddress());
+
+ EmptyReadCallback clientReadCallback;
clientReadCallback.tcpSocket_ = sock;
sock->setReadCB(&clientReadCallback);
// Set up SSL client
EventBase eventBase;
- std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
- 2));
+ auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
client->connect();
EventBaseAborter eba(&eventBase, 3000);
// Set up SSL client
EventBase eventBase;
- std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
- 2, 100));
+ auto client =
+ std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
client->connect();
EventBaseAborter eba(&eventBase, 3000);
server.getEventBase().runInEventBaseThread([&handshakeCallback]{
handshakeCallback.closeSocket();});
// give time for the cache lookup to come back and find it closed
- usleep(500000);
+ handshakeCallback.waitForHandshake();
EXPECT_EQ(server.getAsyncCallbacks(), 1);
EXPECT_EQ(server.getAsyncLookups(), 1);
cerr << "SSLServerCacheCloseTest test completed" << endl;
}
+#endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
/**
* Verify Client Ciphers obtained using SSL MSG Callback.
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
- serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
- clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
- clientCtx->loadPrivateKey(testKey);
- clientCtx->loadCertificate(testCert);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->ciphers("AES256-SHA:AES128-SHA");
+ clientCtx->loadPrivateKey(kTestKey);
+ clientCtx->loadCertificate(kTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
eventBase.loop();
- EXPECT_EQ(server.clientCiphers_,
- "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+#if defined(OPENSSL_IS_BORINGSSL)
+ EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
+#else
+ EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
+#endif
+ EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
+ SCOPE_EXIT { SSL_free(ssl); };
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
+ SCOPE_EXIT { SSL_free(ssl); };
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
+ SCOPE_EXIT { SSL_free(ssl); };
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadTrustedCertificates(kTestCA);
SSLHandshakeServer server(std::move(serverSock), true, true);
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
/**
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, false);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadTrustedCertificates(kTestCA);
SSLHandshakeServer server(std::move(serverSock), true, true);
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(!client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(!server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
/**
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadTrustedCertificates(kTestCA);
SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
EXPECT_TRUE(!client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
/**
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- clientCtx->loadPrivateKey(testKey);
- clientCtx->loadCertificate(testCert);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadPrivateKey(kTestKey);
+ clientCtx->loadCertificate(kTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
/**
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
/**
EXPECT_TRUE(!client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
/**
serverCtx->setVerificationOption(
SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- clientCtx->loadPrivateKey(testKey);
- clientCtx->loadCertificate(testCert);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadPrivateKey(kTestKey);
+ clientCtx->loadCertificate(kTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
+ EXPECT_LE(0, server.handshakeTime.count());
}
serverCtx->setVerificationOption(
SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
EXPECT_FALSE(server.handshakeVerify_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
+ EXPECT_LE(0, client.handshakeTime.count());
+ EXPECT_LE(0, server.handshakeTime.count());
+}
+
+TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
+ auto cert = getFileAsBuf(kTestCert);
+ auto key = getFileAsBuf(kTestKey);
+
+ ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
+ BIO_write(certBio.get(), cert.data(), cert.size());
+ ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
+ BIO_write(keyBio.get(), key.data(), key.size());
+
+ // Create SSL structs from buffers to get properties
+ ssl::X509UniquePtr certStruct(
+ PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
+ ssl::EvpPkeyUniquePtr keyStruct(
+ PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
+ certBio = nullptr;
+ keyBio = nullptr;
+
+ auto origCommonName = getCommonName(certStruct.get());
+ auto origKeySize = EVP_PKEY_bits(keyStruct.get());
+ certStruct = nullptr;
+ keyStruct = nullptr;
+
+ auto ctx = std::make_shared<SSLContext>();
+ ctx->loadPrivateKeyFromBufferPEM(key);
+ ctx->loadCertificateFromBufferPEM(cert);
+ ctx->loadTrustedCertificates(kTestCA);
+
+ ssl::SSLUniquePtr ssl(ctx->createSSL());
+
+ auto newCert = SSL_get_certificate(ssl.get());
+ auto newKey = SSL_get_privatekey(ssl.get());
+
+ // Get properties from SSL struct
+ auto newCommonName = getCommonName(newCert);
+ auto newKeySize = EVP_PKEY_bits(newKey);
+
+ // Check that the key and cert have the expected properties
+ EXPECT_EQ(origCommonName, newCommonName);
+ EXPECT_EQ(origKeySize, newKeySize);
}
TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
}
+TEST(AsyncSSLSocketTest, ConnResetErrorString) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+ socket->open();
+ uint8_t buf[3] = {0x16, 0x03, 0x01};
+ socket->write(buf, sizeof(buf));
+ socket->closeWithReset();
+
+ handshakeCallback.waitForHandshake();
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("Network error"), std::string::npos);
+ EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+ socket->open();
+ uint8_t buf[3] = {0x16, 0x03, 0x01};
+ socket->write(buf, sizeof(buf));
+ socket->close();
+
+ handshakeCallback.waitForHandshake();
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
+ EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ WriteErrorCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback,
+ HandshakeCallback::EXPECT_ERROR);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+ socket->open();
+ uint8_t buf[256] = {0x16, 0x03};
+ memset(buf + 2, 'a', sizeof(buf) - 2);
+ socket->write(buf, sizeof(buf));
+ socket->close();
+
+ handshakeCallback.waitForHandshake();
+ EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
+ std::string::npos);
+#if defined(OPENSSL_IS_BORINGSSL)
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
+ std::string::npos);
+#else
+ EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
+ std::string::npos);
+#endif
+}
+
+TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
+ using folly::ssl::OpenSSLUtils;
+ EXPECT_EQ(
+ OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
+ // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
+ EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
+ // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
+ EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
+}
+
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+ using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+ explicit MockAsyncTFOSSLSocket(
+ std::shared_ptr<folly::SSLContext> sslCtx,
+ EventBase* evb)
+ : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+ MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, false);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+ virtual void connectSuccess() noexcept override {
+ state = State::SUCCESS;
+ }
+
+ virtual void connectErr(const AsyncSocketException& ex) noexcept override {
+ state = State::ERROR;
+ error = ex.what();
+ }
+
+ enum class State { WAITING, SUCCESS, ERROR };
+
+ State state{State::WAITING};
+ std::string error;
+};
+
+template <class Cardinality>
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+ EventBase* evb,
+ const SocketAddress& address,
+ Cardinality cardinality) {
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+ new MockAsyncTFOSSLSocket(sslContext, evb));
+ socket->enableTFO();
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .Times(cardinality)
+ .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+ sockaddr_storage addr;
+ auto len = address.getAddress(&addr);
+ return connect(fd, (const struct sockaddr*)&addr, len);
+ }));
+ return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+ evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+ evb.loop();
+
+ BlockingSocket sock(std::move(socket));
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ sock.write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ sock.close();
+}
+
+#if !defined(OPENSSL_IS_BORINGSSL)
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+ // Start listening on a local port
+ ConnectTimeoutCallback acceptCallback;
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ EXPECT_THROW(
+ socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
+}
+#endif
+
+#if !defined(OPENSSL_IS_BORINGSSL)
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+ // Start listening on a local port
+ ConnectTimeoutCallback acceptCallback;
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+ ConnCallback ccb;
+ // Set a short timeout
+ socket->connect(&ccb, server.getAddress(), 1);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+#endif
+
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+ // Start listening on a local port
+ EmptyReadCallback readCallback;
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
+ HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+ // Start listening on a local port
+ EventBase evb;
+
+ // Hopefully nothing is listening on this address
+ SocketAddress addr("127.0.0.1", 65535);
+ auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, addr, 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
+TEST(AsyncSSLSocketTest, TestPreReceivedData) {
+ EventBase clientEventBase;
+ EventBase serverEventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSockPtr(
+ new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSockPtr(
+ new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
+ auto clientSock = clientSockPtr.get();
+ auto serverSock = serverSockPtr.get();
+ SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+ // Steal some data from the server.
+ clientEventBase.loopOnce();
+ std::array<uint8_t, 10> buf;
+ recv(fds[1], buf.data(), buf.size(), 0);
+
+ serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+ SSLHandshakeServer server(std::move(serverSockPtr), true, true);
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ serverEventBase.loopOnce();
+ clientEventBase.loopOnce();
+ }
+
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_EQ(
+ serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
+#endif
+
} // namespace
+#ifdef SIGPIPE
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
};
Initializer initializer;
} // anonymous
+#endif