*/
#include <folly/io/async/test/AsyncSSLSocketTest.h>
-#include <signal.h>
#include <pthread.h>
+#include <signal.h>
+#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
+#include <folly/portability/GMock.h>
+#include <folly/portability/GTest.h>
#include <folly/portability/Sockets.h>
#include <folly/portability/Unistd.h>
#include <folly/io/async/test/BlockingSocket.h>
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
#include <fstream>
-#include <gtest/gtest.h>
#include <iostream>
#include <list>
#include <set>
-#include <fcntl.h>
-#include <openssl/bio.h>
-#include <sys/types.h>
-#include <folly/io/Cursor.h>
+#include <thread>
using std::string;
using std::vector;
using std::endl;
using std::list;
+using namespace testing;
+
namespace folly {
uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
constexpr size_t SSLClient::kMaxReadBufferSz;
constexpr size_t SSLClient::kMaxReadsPerEvent;
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
: ctx_(new folly::SSLContext),
acb_(acb),
socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
acb_->ctx_ = ctx_;
acb_->base_ = &evb_;
- //set up the listening socket
+ // Enable TFO
+ if (enableTFO) {
+ LOG(INFO) << "server TFO enabled";
+ socket_->setTFOEnabled(true, 1000);
+ }
+
+ // set up the listening socket
socket_->bind(0);
socket_->getAddress(&address_);
socket_->listen(100);
uint8_t readbuf[128];
uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
- } catch (AsyncSocketException &e) {
+ } catch (AsyncSocketException&) {
ex = true;
}
EXPECT_TRUE(ex);
std::unique_ptr<NpnServer> server;
};
+class NextProtocolTLSExtTest : public NextProtocolTest {
+ // For extended TLS protos
+};
+
class NextProtocolNPNOnlyTest : public NextProtocolTest {
// For mismatching protos
};
NextProtocolTypePair(
SSLContext::NextProtocolType::NPN,
SSLContext::NextProtocolType::NPN),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::NPN,
+ SSLContext::NextProtocolType::ANY),
+ NextProtocolTypePair(
+ SSLContext::NextProtocolType::ANY,
+ SSLContext::NextProtocolType::ANY)));
+
#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+INSTANTIATE_TEST_CASE_P(
+ AsyncSSLSocketTest,
+ NextProtocolTLSExtTest,
+ ::testing::Values(
NextProtocolTypePair(
SSLContext::NextProtocolType::ALPN,
SSLContext::NextProtocolType::ALPN),
SSLContext::NextProtocolType::ANY),
NextProtocolTypePair(
SSLContext::NextProtocolType::ANY,
- SSLContext::NextProtocolType::ALPN),
+ SSLContext::NextProtocolType::ALPN)));
#endif
- NextProtocolTypePair(
- SSLContext::NextProtocolType::NPN,
- SSLContext::NextProtocolType::ANY),
- NextProtocolTypePair(
- SSLContext::NextProtocolType::ANY,
- SSLContext::NextProtocolType::ANY)));
INSTANTIATE_TEST_CASE_P(
AsyncSSLSocketTest,
cerr << "SSLClientTimeoutTest test completed" << endl;
}
-
+// This is a FB-only extension, and the tests will fail without it
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
/**
* Test SSL server async cache
*/
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
- EmptyReadCallback clientReadCallback;
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
TestSSLAsyncCacheServer server(&acceptCallback);
// only do a TCP connect
std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
sock->connect(nullptr, server.getAddress());
+
+ EmptyReadCallback clientReadCallback;
clientReadCallback.tcpSocket_ = sock;
sock->setReadCB(&clientReadCallback);
cerr << "SSLServerCacheCloseTest test completed" << endl;
}
+#endif
/**
* Verify Client Ciphers obtained using SSL MSG Callback.
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
- serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
+ serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
serverCtx->loadPrivateKey(testKey);
serverCtx->loadCertificate(testCert);
serverCtx->loadTrustedCertificates(testCA);
serverCtx->loadClientCAList(testCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
- clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
+ clientCtx->ciphers("AES256-SHA:RC4-MD5");
clientCtx->loadPrivateKey(testKey);
clientCtx->loadCertificate(testCert);
clientCtx->loadTrustedCertificates(testCA);
eventBase.loop();
- EXPECT_EQ(server.clientCiphers_,
- "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+ EXPECT_EQ(server.clientCiphers_, "AES256-SHA:RC4-MD5:00ff");
+ EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
std::string::npos);
}
+TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
+ using folly::ssl::OpenSSLUtils;
+ EXPECT_EQ(
+ OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
+ // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
+ EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
+ // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
+ EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
+}
+
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+ using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+ explicit MockAsyncTFOSSLSocket(
+ std::shared_ptr<folly::SSLContext> sslCtx,
+ EventBase* evb)
+ : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+ MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, false);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ socket->open();
+
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ socket->write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+ virtual void connectSuccess() noexcept override {
+ state = State::SUCCESS;
+ }
+
+ virtual void connectErr(const AsyncSocketException& ex) noexcept override {
+ state = State::ERROR;
+ error = ex.what();
+ }
+
+ enum class State { WAITING, SUCCESS, ERROR };
+
+ State state{State::WAITING};
+ std::string error;
+};
+
+template <class Cardinality>
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+ EventBase* evb,
+ const SocketAddress& address,
+ Cardinality cardinality) {
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+ new MockAsyncTFOSSLSocket(sslContext, evb));
+ socket->enableTFO();
+
+ EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+ .Times(cardinality)
+ .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+ sockaddr_storage addr;
+ auto len = address.getAddress(&addr);
+ return connect(fd, (const struct sockaddr*)&addr, len);
+ }));
+ return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+ // Start listening on a local port
+ WriteCallbackBase writeCallback;
+ ReadCallback readCallback(&writeCallback);
+ HandshakeCallback handshakeCallback(&readCallback);
+ SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 30);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+ evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+ evb.loop();
+
+ BlockingSocket sock(std::move(socket));
+ // write()
+ std::array<uint8_t, 128> buf;
+ memset(buf.data(), 'a', buf.size());
+ sock.write(buf.data(), buf.size());
+
+ // read()
+ std::array<uint8_t, 128> readbuf;
+ uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+ EXPECT_EQ(bytesRead, 128);
+ EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+ // close()
+ sock.close();
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+ // Start listening on a local port
+ ConnectTimeoutCallback acceptCallback;
+ TestSSLServer server(&acceptCallback, true);
+
+ // Set up SSL context.
+ auto sslContext = std::make_shared<SSLContext>();
+
+ // connect
+ auto socket =
+ std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+ socket->enableTFO();
+ EXPECT_THROW(
+ socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+ // Start listening on a local port
+ ConnectTimeoutCallback acceptCallback;
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+ ConnCallback ccb;
+ // Set a short timeout
+ socket->connect(&ccb, server.getAddress(), 1);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+ // Start listening on a local port
+ EmptyReadCallback readCallback;
+ HandshakeCallback handshakeCallback(
+ &readCallback, HandshakeCallback::EXPECT_ERROR);
+ HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+ TestSSLServer server(&acceptCallback, true);
+
+ EventBase evb;
+
+ auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, server.getAddress(), 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+ // Start listening on a local port
+ EventBase evb;
+
+ // Hopefully nothing is listening on this address
+ SocketAddress addr("127.0.0.1", 65535);
+ auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+ ConnCallback ccb;
+ socket->connect(&ccb, addr, 100);
+
+ evb.loop();
+ EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+ EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
+#endif
+
} // namespace
+#ifdef SIGPIPE
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
};
Initializer initializer;
} // anonymous
+#endif