X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2Ftest%2FAsyncSSLSocketTest.cpp;h=cdacacada0a75d22c407c03b005446edcf351c63;hp=a3e148868d1242aca0f67f523ffca8a22b779d22;hb=a3bd593ad9374cd3a1db31066874b9bad1cf74b4;hpb=18435bce240108397da4ba1ce0b9db317272b7fd diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index a3e14886..cdacacad 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2015 Facebook, Inc. + * Copyright 2016 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,26 +15,28 @@ */ #include -#include #include +#include +#include #include #include -#include +#include +#include #include +#include +#include #include +#include +#include +#include #include #include #include -#include -#include -#include -#include -#include -#include -#include + +#include using std::string; using std::vector; @@ -43,6 +45,8 @@ using std::cerr; using std::endl; using std::list; +using namespace testing; + namespace folly { uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0; uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0; @@ -55,10 +59,10 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem"; constexpr size_t SSLClient::kMaxReadBufferSz; constexpr size_t SSLClient::kMaxReadsPerEvent; -TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) : -ctx_(new folly::SSLContext), - acb_(acb), - socket_(folly::AsyncServerSocket::newSocket(&evb_)) { +TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO) + : ctx_(new folly::SSLContext), + acb_(acb), + socket_(folly::AsyncServerSocket::newSocket(&evb_)) { // Set up the SSL context ctx_->loadCertificate(testCert); ctx_->loadPrivateKey(testKey); @@ -67,7 +71,13 @@ ctx_(new folly::SSLContext), acb_->ctx_ = ctx_; acb_->base_ = &evb_; - //set up the listening socket + // Enable TFO + if (enableTFO) { + LOG(INFO) << "server TFO enabled"; + socket_->setTFOEnabled(true, 1000); + } + + // set up the listening socket socket_->bind(0); socket_->getAddress(&address_); socket_->listen(100); @@ -144,6 +154,21 @@ bool clientProtoFilterPickNone(unsigned char**, unsigned int*, return false; } +std::string getFileAsBuf(const char* fileName) { + std::string buffer; + folly::readFile(fileName, buffer); + return buffer; +} + +std::string getCommonName(X509* cert) { + X509_NAME* subject = X509_get_subject_name(cert); + std::string cn; + cn.resize(ub_common_name); + X509_NAME_get_text_by_NID( + subject, NID_commonName, const_cast(cn.data()), ub_common_name); + return cn; +} + /** * Test connecting to, writing to, reading from, and closing the * connection to the SSL server. @@ -184,13 +209,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) { cerr << "ConnectWriteReadClose test completed" << endl; } +/** + * Test reading after server close. + */ +TEST(AsyncSSLSocketTest, ReadAfterClose) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadEOFCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + auto server = folly::make_unique(&acceptCallback); + + // Set up SSL context. + auto sslContext = std::make_shared(); + sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + + auto socket = + std::make_shared(server->getAddress(), sslContext); + socket->open(); + + // This should trigger an EOF on the client. + auto evb = handshakeCallback.getSocket()->getEventBase(); + evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); }); + std::array readbuf; + auto bytesRead = socket->read(readbuf.data(), readbuf.size()); + EXPECT_EQ(0, bytesRead); +} + +/** + * Test bad renegotiation + */ +TEST(AsyncSSLSocketTest, Renegotiate) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto dfServerCtx = std::make_shared(); + std::array fds; + getfds(fds.data()); + getctx(clientCtx, dfServerCtx); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); + SSLHandshakeClient client(std::move(clientSock), true, true); + RenegotiatingServer server(std::move(serverSock)); + + while (!client.handshakeSuccess_ && !client.handshakeError_) { + eventBase.loopOnce(); + } + + ASSERT_TRUE(client.handshakeSuccess_); + + auto sslSock = std::move(client).moveSocket(); + sslSock->detachEventBase(); + // This is nasty, however we don't want to add support for + // renegotiation in AsyncSSLSocket. + SSL_renegotiate(const_cast(sslSock->getSSL())); + + auto socket = std::make_shared(std::move(sslSock)); + + std::thread t([&]() { eventBase.loopForever(); }); + + // Trigger the renegotiation. + std::array buf; + memset(buf.data(), 'a', buf.size()); + try { + socket->write(buf.data(), buf.size()); + } catch (AsyncSocketException& e) { + LOG(INFO) << "client got error " << e.what(); + } + eventBase.terminateLoopSoon(); + t.join(); + + eventBase.loop(); + ASSERT_TRUE(server.renegotiationError_); +} + /** * Negative test for handshakeError(). */ TEST(AsyncSSLSocketTest, HandshakeError) { // Start listening on a local port WriteCallbackBase writeCallback; - ReadCallback readCallback(&writeCallback); + WriteErrorCallback readCallback(&writeCallback); HandshakeCallback handshakeCallback(&readCallback); HandshakeErrorCallback acceptCallback(&handshakeCallback); TestSSLServer server(&acceptCallback); @@ -315,191 +416,241 @@ TEST(AsyncSSLSocketTest, SocketWithDelay) { cerr << "SocketWithDelay test completed" << endl; } -TEST(AsyncSSLSocketTest, NpnTestOverlap) { - EventBase eventBase; - std::shared_ptr clientCtx(new SSLContext); - std::shared_ptr serverCtx(new SSLContext);; - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); +using NextProtocolTypePair = + std::pair; - clientCtx->setAdvertisedNextProtocols({"blub","baz"}); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); +class NextProtocolTest : public testing::TestWithParam { + // For matching protos + public: + void SetUp() override { getctx(clientCtx, serverCtx); } - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + void connect(bool unset = false) { + getfds(fds); - eventBase.loop(); + if (unset) { + // unsetting NPN for any of [client, server] is enough to make NPN not + // work + clientCtx->unsetNextProtocols(); + } - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("baz"), 0); -} + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + client = folly::make_unique(std::move(clientSock)); + server = folly::make_unique(std::move(serverSock)); + + eventBase.loop(); + } + + void expectProtocol(const std::string& proto) { + EXPECT_NE(client->nextProtoLength, 0); + EXPECT_EQ(client->nextProtoLength, server->nextProtoLength); + EXPECT_EQ( + memcmp(client->nextProto, server->nextProto, server->nextProtoLength), + 0); + string selected((const char*)client->nextProto, client->nextProtoLength); + EXPECT_EQ(proto, selected); + } + + void expectNoProtocol() { + EXPECT_EQ(client->nextProtoLength, 0); + EXPECT_EQ(server->nextProtoLength, 0); + EXPECT_EQ(client->nextProto, nullptr); + EXPECT_EQ(server->nextProto, nullptr); + } + + void expectProtocolType() { + if (GetParam().first == SSLContext::NextProtocolType::ANY && + GetParam().second == SSLContext::NextProtocolType::ANY) { + EXPECT_EQ(client->protocolType, server->protocolType); + } else if (GetParam().first == SSLContext::NextProtocolType::ANY || + GetParam().second == SSLContext::NextProtocolType::ANY) { + // Well not much we can say + } else { + expectProtocolType(GetParam()); + } + } + + void expectProtocolType(NextProtocolTypePair expected) { + EXPECT_EQ(client->protocolType, expected.first); + EXPECT_EQ(server->protocolType, expected.second); + } -TEST(AsyncSSLSocketTest, NpnTestUnset) { - // Identical to above test, except that we want unset NPN before - // looping. EventBase eventBase; - std::shared_ptr clientCtx(new SSLContext); - std::shared_ptr serverCtx(new SSLContext);; + std::shared_ptr clientCtx{std::make_shared()}; + std::shared_ptr serverCtx{std::make_shared()}; int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); + std::unique_ptr client; + std::unique_ptr server; +}; - clientCtx->setAdvertisedNextProtocols({"blub","baz"}); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); +class NextProtocolTLSExtTest : public NextProtocolTest { + // For extended TLS protos +}; - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); +class NextProtocolNPNOnlyTest : public NextProtocolTest { + // For mismatching protos +}; - // unsetting NPN for any of [client, server] is enought to make NPN not - // work - clientCtx->unsetNextProtocols(); +class NextProtocolMismatchTest : public NextProtocolTest { + // For mismatching protos +}; - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); +TEST_P(NextProtocolTest, NpnTestOverlap) { + clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - eventBase.loop(); + connect(); - EXPECT_TRUE(client.nextProtoLength == 0); - EXPECT_TRUE(server.nextProtoLength == 0); - EXPECT_TRUE(client.nextProto == nullptr); - EXPECT_TRUE(server.nextProto == nullptr); + expectProtocol("baz"); + expectProtocolType(); } -TEST(AsyncSSLSocketTest, NpnTestNoOverlap) { - EventBase eventBase; - std::shared_ptr clientCtx(new SSLContext); - std::shared_ptr serverCtx(new SSLContext);; - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); +TEST_P(NextProtocolTest, NpnTestUnset) { + // Identical to above test, except that we want unset NPN before + // looping. + clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - clientCtx->setAdvertisedNextProtocols({"blub"}); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + connect(true /* unset */); - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + // if alpn negotiation fails, type will appear as npn + expectNoProtocol(); + EXPECT_EQ(client->protocolType, server->protocolType); +} - eventBase.loop(); +TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) { + clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); + + connect(); - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("blub"), 0); + expectNoProtocol(); + expectProtocolType( + {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN}); } -TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) { - EventBase eventBase; - auto clientCtx = std::make_shared(); - auto serverCtx = std::make_shared(); - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); +// Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test +// will fail on 1.0.2 before that. +TEST_P(NextProtocolTest, NpnTestNoOverlap) { + clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); + + connect(); + + if (GetParam().first == SSLContext::NextProtocolType::ALPN || + GetParam().second == SSLContext::NextProtocolType::ALPN) { + // This is arguably incorrect behavior since RFC7301 states an ALPN protocol + // mismatch should result in a fatal alert, but this is OpenSSL's current + // behavior and we want to know if it changes. + expectNoProtocol(); + } else { + expectProtocol("blub"); + expectProtocolType( + {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN}); + } +} - clientCtx->setAdvertisedNextProtocols({"blub"}); +TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) { + clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + connect(); - eventBase.loop(); - - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("ponies"), 0); + expectProtocol("ponies"); + expectProtocolType(); } -TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) { - EventBase eventBase; - auto clientCtx = std::make_shared(); - auto serverCtx = std::make_shared(); - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); - - clientCtx->setAdvertisedNextProtocols({"blub"}); +TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) { + clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); - - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - eventBase.loop(); + connect(); - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("blub"), 0); + expectProtocol("blub"); + expectProtocolType(); } -TEST(AsyncSSLSocketTest, RandomizedNpnTest) { +TEST_P(NextProtocolTest, RandomizedNpnTest) { // Probability that this test will fail is 2^-64, which could be considered // as negligible. const int kTries = 64; + clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().first); + serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}}, + GetParam().second); + std::set selectedProtocols; for (int i = 0; i < kTries; ++i) { - EventBase eventBase; - std::shared_ptr clientCtx = std::make_shared(); - std::shared_ptr serverCtx = std::make_shared(); - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); - - clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); - serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, - {1, {"bar"}}}); - - - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); - - eventBase.loop(); - - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); + connect(); + + EXPECT_NE(client->nextProtoLength, 0); + EXPECT_EQ(client->nextProtoLength, server->nextProtoLength); + EXPECT_EQ( + memcmp(client->nextProto, server->nextProto, server->nextProtoLength), + 0); + string selected((const char*)client->nextProto, client->nextProtoLength); selectedProtocols.insert(selected); + expectProtocolType(); } EXPECT_EQ(selectedProtocols.size(), 2); } +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolTest, + ::testing::Values( + NextProtocolTypePair( + SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::NPN), + NextProtocolTypePair( + SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::ANY), + NextProtocolTypePair( + SSLContext::NextProtocolType::ANY, + SSLContext::NextProtocolType::ANY))); + +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolTLSExtTest, + ::testing::Values( + NextProtocolTypePair( + SSLContext::NextProtocolType::ALPN, + SSLContext::NextProtocolType::ALPN), + NextProtocolTypePair( + SSLContext::NextProtocolType::ALPN, + SSLContext::NextProtocolType::ANY), + NextProtocolTypePair( + SSLContext::NextProtocolType::ANY, + SSLContext::NextProtocolType::ALPN))); +#endif + +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolNPNOnlyTest, + ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::NPN))); + +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolMismatchTest, + ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::ALPN), + NextProtocolTypePair(SSLContext::NextProtocolType::ALPN, + SSLContext::NextProtocolType::NPN))); +#endif #ifndef OPENSSL_NO_TLSEXT /** @@ -657,8 +808,7 @@ TEST(AsyncSSLSocketTest, SSLClientTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 1)); + auto client = std::make_shared(&eventBase, server.getAddress(), 1); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -684,8 +834,8 @@ TEST(AsyncSSLSocketTest, SSLClientTestReuse) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 10)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 10); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -710,8 +860,8 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 1, 10)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 1, 10); client->connect(true /* write before connect completes */); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); @@ -741,8 +891,8 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 10, 500)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 10, 500); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -798,8 +948,7 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 2)); + auto client = std::make_shared(&eventBase, server.getAddress(), 2); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -828,8 +977,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 2, 100)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 2, 100); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -838,7 +987,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { server.getEventBase().runInEventBaseThread([&handshakeCallback]{ handshakeCallback.closeSocket();}); // give time for the cache lookup to come back and find it closed - usleep(500000); + handshakeCallback.waitForHandshake(); EXPECT_EQ(server.getAsyncCallbacks(), 1); EXPECT_EQ(server.getAsyncLookups(), 1); @@ -1339,6 +1488,47 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) { EXPECT_LE(0, server.handshakeTime.count()); } +TEST(AsyncSSLSocketTest, LoadCertFromMemory) { + auto cert = getFileAsBuf(testCert); + auto key = getFileAsBuf(testKey); + + ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem())); + BIO_write(certBio.get(), cert.data(), cert.size()); + ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem())); + BIO_write(keyBio.get(), key.data(), key.size()); + + // Create SSL structs from buffers to get properties + ssl::X509UniquePtr certStruct( + PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr)); + ssl::EvpPkeyUniquePtr keyStruct( + PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr)); + certBio = nullptr; + keyBio = nullptr; + + auto origCommonName = getCommonName(certStruct.get()); + auto origKeySize = EVP_PKEY_bits(keyStruct.get()); + certStruct = nullptr; + keyStruct = nullptr; + + auto ctx = std::make_shared(); + ctx->loadPrivateKeyFromBufferPEM(key); + ctx->loadCertificateFromBufferPEM(cert); + ctx->loadTrustedCertificates(testCA); + + ssl::SSLUniquePtr ssl(ctx->createSSL()); + + auto newCert = SSL_get_certificate(ssl.get()); + auto newKey = SSL_get_privatekey(ssl.get()); + + // Get properties from SSL struct + auto newCommonName = getCommonName(newCert); + auto newKeySize = EVP_PKEY_bits(newKey); + + // Check that the key and cert have the expected properties + EXPECT_EQ(origCommonName, newCommonName); + EXPECT_EQ(origKeySize, newKeySize); +} + TEST(AsyncSSLSocketTest, MinWriteSizeTest) { EventBase eb; @@ -1438,6 +1628,265 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) { EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState()); } +TEST(AsyncSSLSocketTest, ConnResetErrorString) { + // Start listening on a local port + WriteCallbackBase writeCallback; + WriteErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback, + HandshakeCallback::EXPECT_ERROR); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + auto socket = std::make_shared(server.getAddress(), nullptr); + socket->open(); + uint8_t buf[3] = {0x16, 0x03, 0x01}; + socket->write(buf, sizeof(buf)); + socket->closeWithReset(); + + handshakeCallback.waitForHandshake(); + EXPECT_NE( + handshakeCallback.errorString_.find("Network error"), std::string::npos); + EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos); +} + +TEST(AsyncSSLSocketTest, ConnEOFErrorString) { + // Start listening on a local port + WriteCallbackBase writeCallback; + WriteErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback, + HandshakeCallback::EXPECT_ERROR); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + auto socket = std::make_shared(server.getAddress(), nullptr); + socket->open(); + uint8_t buf[3] = {0x16, 0x03, 0x01}; + socket->write(buf, sizeof(buf)); + socket->close(); + + handshakeCallback.waitForHandshake(); + EXPECT_NE( + handshakeCallback.errorString_.find("Connection EOF"), std::string::npos); + EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos); +} + +TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) { + // Start listening on a local port + WriteCallbackBase writeCallback; + WriteErrorCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback, + HandshakeCallback::EXPECT_ERROR); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback); + + auto socket = std::make_shared(server.getAddress(), nullptr); + socket->open(); + uint8_t buf[256] = {0x16, 0x03}; + memset(buf + 2, 'a', sizeof(buf) - 2); + socket->write(buf, sizeof(buf)); + socket->close(); + + handshakeCallback.waitForHandshake(); + EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"), + std::string::npos); + EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"), + std::string::npos); +} + +#if FOLLY_ALLOW_TFO + +class MockAsyncTFOSSLSocket : public AsyncSSLSocket { + public: + using UniquePtr = std::unique_ptr; + + explicit MockAsyncTFOSSLSocket( + std::shared_ptr sslCtx, + EventBase* evb) + : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {} + + MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); +}; + +/** + * Test connecting to, writing to, reading from, and closing the + * connection to the SSL server with TFO. + */ +TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = + std::make_shared(server.getAddress(), sslContext); + socket->enableTFO(); + socket->open(); + + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + socket->write(buf.data(), buf.size()); + + // read() + std::array readbuf; + uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); + + // close() + socket->close(); +} + +/** + * Test connecting to, writing to, reading from, and closing the + * connection to the SSL server with TFO. + */ +TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, false); + + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = + std::make_shared(server.getAddress(), sslContext); + socket->enableTFO(); + socket->open(); + + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + socket->write(buf.data(), buf.size()); + + // read() + std::array readbuf; + uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); + + // close() + socket->close(); +} + +class ConnCallback : public AsyncSocket::ConnectCallback { + public: + virtual void connectSuccess() noexcept override { + state = State::SUCCESS; + } + + virtual void connectErr(const AsyncSocketException&) noexcept override { + state = State::ERROR; + } + + enum class State { WAITING, SUCCESS, ERROR }; + + State state{State::WAITING}; +}; + +template +MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback( + EventBase* evb, + const SocketAddress& address, + Cardinality cardinality) { + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = MockAsyncTFOSSLSocket::UniquePtr( + new MockAsyncTFOSSLSocket(sslContext, evb)); + socket->enableTFO(); + + EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) + .Times(cardinality) + .WillOnce(Invoke([&](int fd, struct msghdr*, int) { + sockaddr_storage addr; + auto len = address.getAddress(&addr); + return connect(fd, (const struct sockaddr*)&addr, len); + })); + return socket; +} + +TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) { + // Start listening on a local port + WriteCallbackBase writeCallback; + ReadCallback readCallback(&writeCallback); + HandshakeCallback handshakeCallback(&readCallback); + SSLServerAcceptCallback acceptCallback(&handshakeCallback); + TestSSLServer server(&acceptCallback, true); + + EventBase evb; + + auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30); + + evb.loop(); + EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state); + + evb.runInEventBaseThread([&] { socket->detachEventBase(); }); + evb.loop(); + + BlockingSocket sock(std::move(socket)); + // write() + std::array buf; + memset(buf.data(), 'a', buf.size()); + sock.write(buf.data(), buf.size()); + + // read() + std::array readbuf; + uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size()); + EXPECT_EQ(bytesRead, 128); + EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0); + + // close() + sock.close(); +} + +TEST(AsyncSSLSocketTest, ConnectTFOTimeout) { + // Start listening on a local port + ConnectTimeoutCallback acceptCallback; + TestSSLServer server(&acceptCallback, true); + + // Set up SSL context. + auto sslContext = std::make_shared(); + + // connect + auto socket = + std::make_shared(server.getAddress(), sslContext); + socket->enableTFO(); + EXPECT_THROW( + socket->open(std::chrono::milliseconds(1)), AsyncSocketException); +} + +TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) { + // Start listening on a local port + ConnectTimeoutCallback acceptCallback; + TestSSLServer server(&acceptCallback, true); + + EventBase evb; + + auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1)); + ConnCallback ccb; + // Set a short timeout + socket->connect(&ccb, server.getAddress(), 1); + + evb.loop(); + EXPECT_EQ(ConnCallback::State::ERROR, ccb.state); +} + +#endif + } // namespace ///////////////////////////////////////////////////////////////////////////