/*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2011-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*/
#include <folly/io/async/test/AsyncSSLSocketTest.h>
-#include <pthread.h>
-#include <signal.h>
-
#include <folly/SocketAddress.h>
+#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
+#include <folly/portability/GMock.h>
+#include <folly/portability/GTest.h>
+#include <folly/portability/OpenSSL.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <fcntl.h>
-#include <folly/io/Cursor.h>
-#include <gtest/gtest.h>
-#include <openssl/bio.h>
+#include <signal.h>
#include <sys/types.h>
+
#include <fstream>
#include <iostream>
#include <list>
#include <set>
#include <thread>
-#include <gmock/gmock.h>
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
+#include <sys/utsname.h>
+#endif
using std::string;
using std::vector;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
-const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
-const char* testKey = "folly/io/async/test/certs/tests-key.pem";
-const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
-
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
- : ctx_(new folly::SSLContext),
- acb_(acb),
- socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
- // Set up the SSL context
- ctx_->loadCertificate(testCert);
- ctx_->loadPrivateKey(testKey);
- ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-
- acb_->ctx_ = ctx_;
- acb_->base_ = &evb_;
-
- // Enable TFO
- if (enableTFO) {
- LOG(INFO) << "server TFO enabled";
- socket_->setTFOEnabled(true, 1000);
- }
-
- // set up the listening socket
- socket_->bind(0);
- socket_->getAddress(&address_);
- socket_->listen(100);
- socket_->addAcceptCallback(acb_, &evb_);
- socket_->startAccepting();
-
- int ret = pthread_create(&thread_, nullptr, Main, this);
- assert(ret == 0);
- (void)ret;
-
- std::cerr << "Accepting connections on " << address_ << std::endl;
-}
-
void getfds(int fds[2]) {
if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
FAIL() << "failed to create socketpair: " << strerror(errno);
<< strerror(errno);
}
if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
- FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
- << strerror(errno);
+ FAIL() << "failed to put socket " << idx
+ << " in non-blocking mode: " << strerror(errno);
}
}
}
void getctx(
- std::shared_ptr<folly::SSLContext> clientCtx,
- std::shared_ptr<folly::SSLContext> serverCtx) {
+ std::shared_ptr<folly::SSLContext> clientCtx,
+ std::shared_ptr<folly::SSLContext> serverCtx) {
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadCertificate(
- testCert);
- serverCtx->loadPrivateKey(
- testKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadPrivateKey(kTestKey);
}
void sslsocketpair(
- EventBase* eventBase,
- AsyncSSLSocket::UniquePtr* clientSock,
- AsyncSSLSocket::UniquePtr* serverSock) {
+ EventBase* eventBase,
+ AsyncSSLSocket::UniquePtr* clientSock,
+ AsyncSSLSocket::UniquePtr* serverSock) {
auto clientCtx = std::make_shared<folly::SSLContext>();
auto serverCtx = std::make_shared<folly::SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
- clientSock->reset(new AsyncSSLSocket(
- clientCtx, eventBase, fds[0], false));
- serverSock->reset(new AsyncSSLSocket(
- serverCtx, eventBase, fds[1], true));
+ clientSock->reset(new AsyncSSLSocket(clientCtx, eventBase, fds[0], false));
+ serverSock->reset(new AsyncSSLSocket(serverCtx, eventBase, fds[1], true));
// (*clientSock)->setSendTimeout(100);
// (*serverSock)->setSendTimeout(100);
}
// client protocol filters
-bool clientProtoFilterPickPony(unsigned char** client,
- unsigned int* client_len, const unsigned char*, unsigned int ) {
- //the protocol string in length prefixed byte string. the
- //length byte is not included in the length
- static unsigned char p[7] = {6,'p','o','n','i','e','s'};
+bool clientProtoFilterPickPony(
+ unsigned char** client,
+ unsigned int* client_len,
+ const unsigned char*,
+ unsigned int) {
+ // the protocol string in length prefixed byte string. the
+ // length byte is not included in the length
+ static unsigned char p[7] = {6, 'p', 'o', 'n', 'i', 'e', 's'};
*client = p;
*client_len = 7;
return true;
}
-bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
- const unsigned char*, unsigned int) {
+bool clientProtoFilterPickNone(
+ unsigned char**,
+ unsigned int*,
+ const unsigned char*,
+ unsigned int) {
return false;
}
return cn;
}
+TEST(AsyncSSLSocketTest, ClientCertValidationResultTest) {
+ EventBase ev;
+ int fd = 0;
+
+ AsyncSSLSocket::UniquePtr sock(
+ new AsyncSSLSocket(std::make_shared<SSLContext>(), &ev, fd, false));
+
+ // Initially the cert is not validated, so no result is available.
+ EXPECT_EQ(nullptr, get_pointer(sock->getClientCertValidationResult()));
+
+ sock->setClientCertValidationResult(
+ make_optional(AsyncSSLSocket::CertValidationResult::CERT_VALID));
+
+ EXPECT_EQ(
+ AsyncSSLSocket::CertValidationResult::CERT_VALID,
+ *sock->getClientCertValidationResult());
+}
+
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server.
// Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
- //sslContext->authenticate(true, false);
+ // sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
+ // sslContext->authenticate(true, false);
// connect
- auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
- sslContext);
- socket->open();
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->open(std::chrono::milliseconds(10000));
// write()
uint8_t buf[128];
socket->close();
cerr << "ConnectWriteReadClose test completed" << endl;
+ EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
}
/**
ReadEOFCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
- auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
+ auto server = std::make_unique<TestSSLServer>(&acceptCallback);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
/**
* Test bad renegotiation
*/
+#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, Renegotiate) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
eventBase.loop();
ASSERT_TRUE(server.renegotiationError_);
}
+#endif
/**
* Negative test for handshakeError().
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
- auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
- sslContext);
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
// read()
bool ex = false;
try {
uint8_t readbuf[128];
uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
- } catch (AsyncSocketException &e) {
+ } catch (AsyncSocketException&) {
ex = true;
}
EXPECT_TRUE(ex);
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
- auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
- sslContext);
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open();
// write something to trigger ssl handshake
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
- auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
- sslContext);
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open();
// write something to trigger ssl handshake
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
- auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
- sslContext);
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open();
// write()
class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
// For matching protos
public:
- void SetUp() override { getctx(clientCtx, serverCtx); }
+ void SetUp() override {
+ getctx(clientCtx, serverCtx);
+ }
void connect(bool unset = false) {
getfds(fds);
}
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
- client = folly::make_unique<NpnClient>(std::move(clientSock));
- server = folly::make_unique<NpnServer>(std::move(serverSock));
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+ client = std::make_unique<NpnClient>(std::move(clientSock));
+ server = std::make_unique<NpnServer>(std::move(serverSock));
eventBase.loop();
}
void expectProtocol(const std::string& proto) {
+ expectHandshakeSuccess();
EXPECT_NE(client->nextProtoLength, 0);
EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
EXPECT_EQ(
}
void expectNoProtocol() {
+ expectHandshakeSuccess();
EXPECT_EQ(client->nextProtoLength, 0);
EXPECT_EQ(server->nextProtoLength, 0);
EXPECT_EQ(client->nextProto, nullptr);
}
void expectProtocolType() {
+ expectHandshakeSuccess();
if (GetParam().first == SSLContext::NextProtocolType::ANY &&
GetParam().second == SSLContext::NextProtocolType::ANY) {
EXPECT_EQ(client->protocolType, server->protocolType);
}
void expectProtocolType(NextProtocolTypePair expected) {
+ expectHandshakeSuccess();
EXPECT_EQ(client->protocolType, expected.first);
EXPECT_EQ(server->protocolType, expected.second);
}
+ void expectHandshakeSuccess() {
+ EXPECT_FALSE(client->except.hasValue())
+ << "client handshake error: " << client->except->what();
+ EXPECT_FALSE(server->except.hasValue())
+ << "server handshake error: " << server->except->what();
+ }
+
+ void expectHandshakeError() {
+ EXPECT_TRUE(client->except.hasValue())
+ << "Expected client handshake error!";
+ EXPECT_TRUE(server->except.hasValue())
+ << "Expected server handshake error!";
+ }
+
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
TEST_P(NextProtocolTest, NpnTestOverlap) {
clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
- serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().second);
+ serverCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().second);
connect();
// Identical to above test, except that we want unset NPN before
// looping.
clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
- serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().second);
+ serverCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().second);
connect(true /* unset */);
TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
- serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().second);
+ serverCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().second);
connect();
// will fail on 1.0.2 before that.
TEST_P(NextProtocolTest, NpnTestNoOverlap) {
clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
- serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().second);
-
+ serverCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().second);
connect();
if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
GetParam().second == SSLContext::NextProtocolType::ALPN) {
// This is arguably incorrect behavior since RFC7301 states an ALPN protocol
- // mismatch should result in a fatal alert, but this is OpenSSL's current
- // behavior and we want to know if it changes.
+ // mismatch should result in a fatal alert, but this is the current behavior
+ // on all OpenSSL versions/variants, and we want to know if it changes.
+ expectNoProtocol();
+ }
+#if FOLLY_OPENSSL_IS_110 || defined(OPENSSL_IS_BORINGSSL)
+ else if (
+ GetParam().first == SSLContext::NextProtocolType::ANY &&
+ GetParam().second == SSLContext::NextProtocolType::ANY) {
+#if FOLLY_OPENSSL_IS_110
+ // OpenSSL 1.1.0 sends a fatal alert on mismatch, which is probavbly the
+ // correct behavior per RFC7301
+ expectHandshakeError();
+#else
+ // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
+ // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
+ // NPN *and* ALPN
expectNoProtocol();
- } else {
+#endif
+ }
+#endif
+ else {
expectProtocol("blub");
expectProtocolType(
{SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
- serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().second);
+ serverCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().second);
connect();
TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
- serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().second);
+ serverCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().second);
connect();
// as negligible.
const int kTries = 64;
- clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
- GetParam().first);
- serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
- GetParam().second);
+ clientCtx->setAdvertisedNextProtocols(
+ {"foo", "bar", "baz"}, GetParam().first);
+ serverCtx->setRandomizedAdvertisedNextProtocols(
+ {{1, {"foo"}}, {1, {"bar"}}}, GetParam().second);
std::set<string> selectedProtocols;
for (int i = 0; i < kTries; ++i) {
SSLContext::NextProtocolType::ANY,
SSLContext::NextProtocolType::ANY)));
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
INSTANTIATE_TEST_CASE_P(
AsyncSSLSocketTest,
NextProtocolTLSExtTest,
INSTANTIATE_TEST_CASE_P(
AsyncSSLSocketTest,
NextProtocolNPNOnlyTest,
- ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
- SSLContext::NextProtocolType::NPN)));
+ ::testing::Values(NextProtocolTypePair(
+ SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::NPN)));
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
INSTANTIATE_TEST_CASE_P(
AsyncSSLSocketTest,
NextProtocolMismatchTest,
- ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
- SSLContext::NextProtocolType::ALPN),
- NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
- SSLContext::NextProtocolType::NPN)));
+ ::testing::Values(
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::ALPN),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::ALPN,
+ SSLContext::NextProtocolType::NPN)));
#endif
#ifndef OPENSSL_NO_TLSEXT
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
- SNIServer server(std::move(serverSock),
- dfServerCtx,
- hskServerCtx,
- serverName);
+ SNIServer server(
+ std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
eventBase.loop();
getfds(fds);
getctx(clientCtx, dfServerCtx);
- AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx,
- &eventBase,
- fds[0],
- clientRequestingServerName));
+ AsyncSSLSocket::UniquePtr clientSock(new AsyncSSLSocket(
+ clientCtx, &eventBase, fds[0], clientRequestingServerName));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
- SNIServer server(std::move(serverSock),
- dfServerCtx,
- hskServerCtx,
- serverExpectedServerName);
+ SNIServer server(
+ std::move(serverSock),
+ dfServerCtx,
+ hskServerCtx,
+ serverExpectedServerName);
eventBase.loop();
*/
TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
- EventBase eventBase;
+ EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
// Use the same SSLContext to continue the handshake after
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
- //Change the server name
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
+ // Change the server name
std::string newName("new.com");
clientSock->setServerName(newName);
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
- SNIServer server(std::move(serverSock),
- dfServerCtx,
- hskServerCtx,
- serverName);
+ SNIServer server(
+ std::move(serverSock), dfServerCtx, hskServerCtx, serverName);
eventBase.loop();
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
- SNIServer server(std::move(serverSock),
- dfServerCtx,
- hskServerCtx,
- serverExpectedServerName);
+ SNIServer server(
+ std::move(serverSock),
+ dfServerCtx,
+ hskServerCtx,
+ serverExpectedServerName);
eventBase.loop();
cerr << "SSLClientTest test completed" << endl;
}
-
/**
* Test SSL client socket session re-use
*/
TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
// Start listening on a local port
EmptyReadCallback readCallback;
- HandshakeCallback handshakeCallback(&readCallback,
- HandshakeCallback::EXPECT_ERROR);
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
cerr << "SSLClientTimeoutTest test completed" << endl;
}
-
+// The next 3 tests need an FB-only extension, and will fail without it
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
/**
* Test SSL server async cache
*/
cerr << "SSLServerAsyncCacheTest test completed" << endl;
}
-
/**
* Test SSL server accept timeout with cache path
*/
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
- EmptyReadCallback clientReadCallback;
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
TestSSLAsyncCacheServer server(&acceptCallback);
// only do a TCP connect
std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
sock->connect(nullptr, server.getAddress());
+
+ EmptyReadCallback clientReadCallback;
clientReadCallback.tcpSocket_ = sock;
sock->setReadCB(&clientReadCallback);
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
- HandshakeCallback handshakeCallback(&readCallback,
- HandshakeCallback::EXPECT_ERROR);
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
TestSSLAsyncCacheServer server(&acceptCallback, 500);
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
- server.getEventBase().runInEventBaseThread([&handshakeCallback]{
- handshakeCallback.closeSocket();});
+ server.getEventBase().runInEventBaseThread(
+ [&handshakeCallback] { handshakeCallback.closeSocket(); });
// give time for the cache lookup to come back and find it closed
handshakeCallback.waitForHandshake();
cerr << "SSLServerCacheCloseTest test completed" << endl;
}
+#endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
/**
* Verify Client Ciphers obtained using SSL MSG Callback.
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
- serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
- clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
- clientCtx->loadPrivateKey(testKey);
- clientCtx->loadCertificate(testCert);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->ciphers("AES256-SHA:AES128-SHA");
+ clientCtx->loadPrivateKey(kTestKey);
+ clientCtx->loadCertificate(kTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
eventBase.loop();
- EXPECT_EQ(server.clientCiphers_,
- "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+#if defined(OPENSSL_IS_BORINGSSL)
+ EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
+#else
+ EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
+#endif
+ EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_TRUE(!server.handshakeError_);
}
+/**
+ * Verify that server is able to get client cert by getPeerCert() API.
+ */
+TEST(AsyncSSLSocketTest, GetClientCertificate) {
+ EventBase eventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto serverCtx = std::make_shared<SSLContext>();
+ serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kClientTestCA);
+ serverCtx->loadClientCAList(kClientTestCA);
+
+ clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
+ clientCtx->ciphers("AES256-SHA:AES128-SHA");
+ clientCtx->loadPrivateKey(kClientTestKey);
+ clientCtx->loadCertificate(kClientTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
+
+ std::array<int, 2> fds;
+ getfds(fds.data());
+
+ AsyncSSLSocket::UniquePtr clientSock(
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSock(
+ new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+
+ SSLHandshakeClient client(std::move(clientSock), true, true);
+ SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
+
+ eventBase.loop();
+
+ // Handshake should succeed.
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+
+ // Reclaim the sockets from SSLHandshakeBase.
+ auto cliSocket = std::move(client).moveSocket();
+ auto srvSocket = std::move(server).moveSocket();
+
+ // Client cert retrieved from server side.
+ folly::ssl::X509UniquePtr serverPeerCert = srvSocket->getPeerCert();
+ CHECK(serverPeerCert);
+
+ // Client cert retrieved from client side.
+ const X509* clientSelfCert = cliSocket->getSelfCert();
+ CHECK(clientSelfCert);
+
+ // The two certs should be the same.
+ EXPECT_EQ(0, X509_cmp(clientSelfCert, serverPeerCert.get()));
+}
+
TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
EventBase eventBase;
auto ctx = std::make_shared<SSLContext>();
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
- SCOPE_EXIT { SSL_free(ssl); };
+ SCOPE_EXIT {
+ SSL_free(ssl);
+ };
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
- SCOPE_EXIT { SSL_free(ssl); };
+ SCOPE_EXIT {
+ SSL_free(ssl);
+ };
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
// Test parsing with two packets with first packet size < 3
auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
AsyncSSLSocket::clientHelloParsingCallback(
- 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
- ssl, sock.get());
+ 0,
+ 0,
+ SSL3_RT_HANDSHAKE,
+ bufCopy->data(),
+ bufCopy->length(),
+ ssl,
+ sock.get());
bufCopy.reset();
bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
AsyncSSLSocket::clientHelloParsingCallback(
- 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
- ssl, sock.get());
+ 0,
+ 0,
+ SSL3_RT_HANDSHAKE,
+ bufCopy->data(),
+ bufCopy->length(),
+ ssl,
+ sock.get());
bufCopy.reset();
auto parsedClientHello = sock->getClientHelloInfo();
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
- SCOPE_EXIT { SSL_free(ssl); };
+ SCOPE_EXIT {
+ SSL_free(ssl);
+ };
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
auto bufCopy = folly::IOBuf::copyBuffer(
buf->data() + i, std::min((uint64_t)3, buf->length() - i));
AsyncSSLSocket::clientHelloParsingCallback(
- 0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
- ssl, sock.get());
+ 0,
+ 0,
+ SSL3_RT_HANDSHAKE,
+ bufCopy->data(),
+ bufCopy->length(),
+ ssl,
+ sock.get());
bufCopy.reset();
}
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadTrustedCertificates(kTestCA);
SSLHandshakeServer server(std::move(serverSock), true, true);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, false);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadTrustedCertificates(kTestCA);
SSLHandshakeServer server(std::move(serverSock), true, true);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadTrustedCertificates(kTestCA);
SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- clientCtx->loadPrivateKey(testKey);
- clientCtx->loadCertificate(testCert);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadPrivateKey(kTestKey);
+ clientCtx->loadCertificate(kTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
- new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+ new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
- new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+ new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServer server(std::move(serverSock), false, false);
serverCtx->setVerificationOption(
SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- clientCtx->loadPrivateKey(testKey);
- clientCtx->loadCertificate(testCert);
- clientCtx->loadTrustedCertificates(testCA);
+ clientCtx->loadPrivateKey(kTestKey);
+ clientCtx->loadCertificate(kTestCert);
+ clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
EXPECT_LE(0, server.handshakeTime.count());
}
-
/**
* Test requireClientCert with no client cert
*/
serverCtx->setVerificationOption(
SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
- serverCtx->loadPrivateKey(testKey);
- serverCtx->loadCertificate(testCert);
- serverCtx->loadTrustedCertificates(testCA);
- serverCtx->loadClientCAList(testCA);
+ serverCtx->loadPrivateKey(kTestKey);
+ serverCtx->loadCertificate(kTestCert);
+ serverCtx->loadTrustedCertificates(kTestCA);
+ serverCtx->loadClientCAList(kTestCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
}
TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
- auto cert = getFileAsBuf(testCert);
- auto key = getFileAsBuf(testKey);
+ auto cert = getFileAsBuf(kTestCert);
+ auto key = getFileAsBuf(kTestKey);
ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
BIO_write(certBio.get(), cert.data(), cert.size());
auto ctx = std::make_shared<SSLContext>();
ctx->loadPrivateKeyFromBufferPEM(key);
ctx->loadCertificateFromBufferPEM(cert);
- ctx->loadTrustedCertificates(testCA);
+ ctx->loadTrustedCertificates(kTestCA);
ssl::SSLUniquePtr ssl(ctx->createSSL());
class ReadCallbackTerminator : public ReadCallback {
public:
- ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
- : ReadCallback(wcb)
- , base_(base) {}
+ ReadCallbackTerminator(EventBase* base, WriteCallbackBase* wcb)
+ : ReadCallback(wcb), base_(base) {}
// Do not write data back, terminate the loop.
void readDataAvailable(size_t len) noexcept override {
socket_->setReadCB(nullptr);
base_->terminateLoopSoon();
}
+
private:
EventBase* base_;
};
-
/**
* Test a full unencrypted codepath
*/
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
- auto client = AsyncSSLSocket::newSocket(
- clientCtx, &base, fds[0], false, true);
- auto server = AsyncSSLSocket::newSocket(
- serverCtx, &base, fds[1], true, true);
+ auto client =
+ AsyncSSLSocket::newSocket(clientCtx, &base, fds[0], false, true);
+ auto server = AsyncSSLSocket::newSocket(serverCtx, &base, fds[1], true, true);
ReadCallbackTerminator readCallback(&base, nullptr);
server->setReadCB(&readCallback);
EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
EXPECT_NE('a', c2);
-
base.loop();
EXPECT_EQ(2, readCallback.buffers.size());
EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
}
+TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
+ auto clientCtx = std::make_shared<folly::SSLContext>();
+ auto serverCtx = std::make_shared<folly::SSLContext>();
+ getctx(clientCtx, serverCtx);
+
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ EventBase evb;
+ std::shared_ptr<AsyncSSLSocket> socket =
+ AsyncSSLSocket::newSocket(clientCtx, &evb, true);
+ socket->connect(nullptr, server.getAddress(), 0);
+
+ evb.loop();
+
+ EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, socket->getSSLState());
+ socket->sslConn(nullptr);
+ evb.loop();
+ EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, socket->getSSLState());
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(nullptr, buf.data(), buf.size());
+
+ socket->close();
+}
+
TEST(AsyncSSLSocketTest, ConnResetErrorString) {
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
- HandshakeCallback handshakeCallback(&readCallback,
- HandshakeCallback::EXPECT_ERROR);
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
- HandshakeCallback handshakeCallback(&readCallback,
- HandshakeCallback::EXPECT_ERROR);
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
socket->close();
handshakeCallback.waitForHandshake();
+#if FOLLY_OPENSSL_IS_110
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("Network error"), std::string::npos);
+#else
EXPECT_NE(
handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
- EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
+#endif
}
TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
- HandshakeCallback handshakeCallback(&readCallback,
- HandshakeCallback::EXPECT_ERROR);
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
socket->close();
handshakeCallback.waitForHandshake();
- EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
- std::string::npos);
- EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
- std::string::npos);
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("SSL routines"), std::string::npos);
+#if defined(OPENSSL_IS_BORINGSSL)
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
+ std::string::npos);
+#elif FOLLY_OPENSSL_IS_110
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("packet length too long"),
+ std::string::npos);
+#else
+ EXPECT_NE(
+ handshakeCallback.errorString_.find("unknown protocol"),
+ std::string::npos);
+#endif
+}
+
+TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
+ using folly::ssl::OpenSSLUtils;
+ EXPECT_EQ(
+ OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
+ // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
+ EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
+ // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
+ EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
}
#if FOLLY_ALLOW_TFO
class ConnCallback : public AsyncSocket::ConnectCallback {
public:
- virtual void connectSuccess() noexcept override {
+ void connectSuccess() noexcept override {
state = State::SUCCESS;
}
- virtual void connectErr(const AsyncSocketException&) noexcept override {
+ void connectErr(const AsyncSocketException& ex) noexcept override {
state = State::ERROR;
+ error = ex.what();
}
enum class State { WAITING, SUCCESS, ERROR };
State state{State::WAITING};
+ std::string error;
};
template <class Cardinality>
sock.close();
}
+#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
// Start listening on a local port
ConnectTimeoutCallback acceptCallback;
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO();
EXPECT_THROW(
- socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+ socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
}
+#endif
+#if !defined(OPENSSL_IS_BORINGSSL)
TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
// Start listening on a local port
ConnectTimeoutCallback acceptCallback;
evb.loop();
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
}
+#endif
+
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+ // Start listening on a local port
+ EmptyReadCallback readCallback;
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
+ HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+ // Start listening on a local port
+ EventBase evb;
+
+ // Hopefully nothing is listening on this address
+ SocketAddress addr("127.0.0.1", 65535);
+ auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, addr, 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
+TEST(AsyncSSLSocketTest, TestPreReceivedData) {
+ EventBase clientEventBase;
+ EventBase serverEventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSockPtr(
+ new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+ AsyncSSLSocket::UniquePtr serverSockPtr(
+ new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
+ auto clientSock = clientSockPtr.get();
+ auto serverSock = serverSockPtr.get();
+ SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+ // Steal some data from the server.
+ clientEventBase.loopOnce();
+ std::array<uint8_t, 10> buf;
+ recv(fds[1], buf.data(), buf.size(), 0);
+
+ serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+ SSLHandshakeServer server(std::move(serverSockPtr), true, true);
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ serverEventBase.loopOnce();
+ clientEventBase.loopOnce();
+ }
+
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_EQ(
+ serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
+TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
+ EventBase clientEventBase;
+ EventBase serverEventBase;
+ auto clientCtx = std::make_shared<SSLContext>();
+ auto dfServerCtx = std::make_shared<SSLContext>();
+ std::array<int, 2> fds;
+ getfds(fds.data());
+ getctx(clientCtx, dfServerCtx);
+
+ AsyncSSLSocket::UniquePtr clientSockPtr(
+ new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+ AsyncSocket::UniquePtr serverSockPtr(
+ new AsyncSocket(&serverEventBase, fds[1]));
+ auto clientSock = clientSockPtr.get();
+ auto serverSock = serverSockPtr.get();
+ SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+ // Steal some data from the server.
+ clientEventBase.loopOnce();
+ std::array<uint8_t, 10> buf;
+ recv(fds[1], buf.data(), buf.size(), 0);
+ serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+ AsyncSSLSocket::UniquePtr serverSSLSockPtr(
+ new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
+ auto serverSSLSock = serverSSLSockPtr.get();
+ SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
+ while (!client.handshakeSuccess_ && !client.handshakeError_) {
+ serverEventBase.loopOnce();
+ clientEventBase.loopOnce();
+ }
+
+ EXPECT_TRUE(client.handshakeSuccess_);
+ EXPECT_TRUE(server.handshakeSuccess_);
+ EXPECT_EQ(
+ serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
+/**
+ * Test overriding the flags passed to "sendmsg()" system call,
+ * and verifying that write requests fail properly.
+ */
+TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
+ // Start listening on a local port
+ SendMsgFlagsCallback msgCallback;
+ ExpectWriteErrorCallback writeCallback(&msgCallback);
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->open();
+
+ // Setting flags to "-1" to trigger "Invalid argument" error
+ // on attempt to use this flags in sendmsg() system call.
+ msgCallback.resetFlags(-1);
+
+ // write()
+ std::vector<uint8_t> buf(128, 'a');
+ ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
+
+ // close()
+ socket->close();
+
+ cerr << "SendMsgParamsCallback test completed" << endl;
+}
+
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server.
+ */
+TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
+ // This test requires Linux kernel v4.6 or later
+ struct utsname s_uname;
+ memset(&s_uname, 0, sizeof(s_uname));
+ ASSERT_EQ(uname(&s_uname), 0);
+ int major, minor;
+ folly::StringPiece extra;
+ if (folly::split<false>(
+ '.', std::string(s_uname.release) + ".", major, minor, extra)) {
+ if (major < 4 || (major == 4 && minor < 6)) {
+ LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
+ << "kernel ver. " << s_uname.release << " detected).";
+ return;
+ }
+ }
+
+ // Start listening on a local port
+ SendMsgDataCallback msgCallback;
+ WriteCheckTimestampCallback writeCallback(&msgCallback);
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+ sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->open();
+
+ // Adding MSG_EOR flag to the message flags - it'll trigger
+ // timestamp generation for the last byte of the message.
+ msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
+
+ // Init ancillary data buffer to trigger timestamp notification
+ union {
+ uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
+ struct cmsghdr cmsg;
+ } u;
+ u.cmsg.cmsg_level = SOL_SOCKET;
+ u.cmsg.cmsg_type = SO_TIMESTAMPING;
+ u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
+ uint32_t flags = SOF_TIMESTAMPING_TX_SCHED | SOF_TIMESTAMPING_TX_SOFTWARE |
+ SOF_TIMESTAMPING_TX_ACK;
+ memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
+ std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
+ memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
+ msgCallback.resetData(std::move(ctrl));
+
+ // write()
+ std::vector<uint8_t> buf(128, 'a');
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::vector<uint8_t> readbuf(buf.size());
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, buf.size());
+ EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
+
+ writeCallback.checkForTimestampNotifications();
+
+ // close()
+ socket->close();
+
+ cerr << "SendMsgDataCallback test completed" << endl;
+}
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
#endif
-} // namespace
+} // namespace folly
+#ifdef SIGPIPE
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
}
};
Initializer initializer;
-} // anonymous
+} // namespace
+#endif