Add ability to set custom SSLContext on TestSSLServer
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
index 386f51ba21265681810d36f2cbdb4c5107949c48..b5c6430ac268d3f18f7f56cb2154efde08b45641 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  */
 #include <folly/io/async/test/AsyncSSLSocketTest.h>
 
-#include <pthread.h>
 #include <signal.h>
 
 #include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/EventBase.h>
+#include <folly/portability/GMock.h>
+#include <folly/portability/GTest.h>
+#include <folly/portability/OpenSSL.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/Unistd.h>
 
@@ -28,7 +30,6 @@
 
 #include <fcntl.h>
 #include <folly/io/Cursor.h>
-#include <gtest/gtest.h>
 #include <openssl/bio.h>
 #include <sys/types.h>
 #include <fstream>
@@ -37,8 +38,6 @@
 #include <set>
 #include <thread>
 
-#include <gmock/gmock.h>
-
 using std::string;
 using std::vector;
 using std::min;
@@ -53,45 +52,9 @@ uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
 
-const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
-const char* testKey = "folly/io/async/test/certs/tests-key.pem";
-const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
-
 constexpr size_t SSLClient::kMaxReadBufferSz;
 constexpr size_t SSLClient::kMaxReadsPerEvent;
 
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
-    : ctx_(new folly::SSLContext),
-      acb_(acb),
-      socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
-  // Set up the SSL context
-  ctx_->loadCertificate(testCert);
-  ctx_->loadPrivateKey(testKey);
-  ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-
-  acb_->ctx_ = ctx_;
-  acb_->base_ = &evb_;
-
-  // Enable TFO
-  if (enableTFO) {
-    LOG(INFO) << "server TFO enabled";
-    socket_->setTFOEnabled(true, 1000);
-  }
-
-  // set up the listening socket
-  socket_->bind(0);
-  socket_->getAddress(&address_);
-  socket_->listen(100);
-  socket_->addAcceptCallback(acb_, &evb_);
-  socket_->startAccepting();
-
-  int ret = pthread_create(&thread_, nullptr, Main, this);
-  assert(ret == 0);
-  (void)ret;
-
-  std::cerr << "Accepting connections on " << address_ << std::endl;
-}
-
 void getfds(int fds[2]) {
   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
     FAIL() << "failed to create socketpair: " << strerror(errno);
@@ -115,10 +78,8 @@ void getctx(
   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
 
   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-  serverCtx->loadCertificate(
-      testCert);
-  serverCtx->loadPrivateKey(
-      testKey);
+  serverCtx->loadCertificate(kTestCert);
+  serverCtx->loadPrivateKey(kTestKey);
 }
 
 void sslsocketpair(
@@ -240,6 +201,7 @@ TEST(AsyncSSLSocketTest, ReadAfterClose) {
 /**
  * Test bad renegotiation
  */
+#if !defined(OPENSSL_IS_BORINGSSL)
 TEST(AsyncSSLSocketTest, Renegotiate) {
   EventBase eventBase;
   auto clientCtx = std::make_shared<SSLContext>();
@@ -285,6 +247,7 @@ TEST(AsyncSSLSocketTest, Renegotiate) {
   eventBase.loop();
   ASSERT_TRUE(server.renegotiationError_);
 }
+#endif
 
 /**
  * Negative test for handshakeError().
@@ -312,7 +275,7 @@ TEST(AsyncSSLSocketTest, HandshakeError) {
     uint8_t readbuf[128];
     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
-  } catch (AsyncSocketException &e) {
+  } catch (AsyncSocketException&) {
     ex = true;
   }
   EXPECT_TRUE(ex);
@@ -550,7 +513,18 @@ TEST_P(NextProtocolTest, NpnTestNoOverlap) {
     // mismatch should result in a fatal alert, but this is OpenSSL's current
     // behavior and we want to know if it changes.
     expectNoProtocol();
-  } else {
+  }
+#if defined(OPENSSL_IS_BORINGSSL)
+  // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
+  // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
+  // NPN *and* ALPN
+  else if (
+      GetParam().first == SSLContext::NextProtocolType::ANY &&
+      GetParam().second == SSLContext::NextProtocolType::ANY) {
+    expectNoProtocol();
+  }
+#endif
+  else {
     expectProtocol("blub");
     expectProtocolType(
         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
@@ -621,7 +595,7 @@ INSTANTIATE_TEST_CASE_P(
             SSLContext::NextProtocolType::ANY,
             SSLContext::NextProtocolType::ANY)));
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
 INSTANTIATE_TEST_CASE_P(
     AsyncSSLSocketTest,
     NextProtocolTLSExtTest,
@@ -643,7 +617,7 @@ INSTANTIATE_TEST_CASE_P(
     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
                                            SSLContext::NextProtocolType::NPN)));
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
 INSTANTIATE_TEST_CASE_P(
     AsyncSSLSocketTest,
     NextProtocolMismatchTest,
@@ -878,7 +852,8 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
   cerr << "SSLClientTimeoutTest test completed" << endl;
 }
 
-
+// The next 3 tests need an FB-only extension, and will fail without it
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
 /**
  * Test SSL server async cache
  */
@@ -907,7 +882,6 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
   cerr << "SSLServerAsyncCacheTest test completed" << endl;
 }
 
-
 /**
  * Test SSL server accept timeout with cache path
  */
@@ -915,7 +889,6 @@ TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
   // Start listening on a local port
   WriteCallbackBase writeCallback;
   ReadCallback readCallback(&writeCallback);
-  EmptyReadCallback clientReadCallback;
   HandshakeCallback handshakeCallback(&readCallback);
   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
   TestSSLAsyncCacheServer server(&acceptCallback);
@@ -925,6 +898,8 @@ TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
   // only do a TCP connect
   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
   sock->connect(nullptr, server.getAddress());
+
+  EmptyReadCallback clientReadCallback;
   clientReadCallback.tcpSocket_ = sock;
   sock->setReadCB(&clientReadCallback);
 
@@ -998,6 +973,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
 
   cerr << "SSLServerCacheCloseTest test completed" << endl;
 }
+#endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
 
 /**
  * Verify Client Ciphers obtained using SSL MSG Callback.
@@ -1007,17 +983,17 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
   auto clientCtx = std::make_shared<SSLContext>();
   auto serverCtx = std::make_shared<SSLContext>();
   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
-  serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
-  serverCtx->loadPrivateKey(testKey);
-  serverCtx->loadCertificate(testCert);
-  serverCtx->loadTrustedCertificates(testCA);
-  serverCtx->loadClientCAList(testCA);
+  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
+  serverCtx->loadPrivateKey(kTestKey);
+  serverCtx->loadCertificate(kTestCert);
+  serverCtx->loadTrustedCertificates(kTestCA);
+  serverCtx->loadClientCAList(kTestCA);
 
   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
-  clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
-  clientCtx->loadPrivateKey(testKey);
-  clientCtx->loadCertificate(testCert);
-  clientCtx->loadTrustedCertificates(testCA);
+  clientCtx->ciphers("AES256-SHA:AES128-SHA");
+  clientCtx->loadPrivateKey(kTestKey);
+  clientCtx->loadCertificate(kTestCert);
+  clientCtx->loadTrustedCertificates(kTestCA);
 
   int fds[2];
   getfds(fds);
@@ -1032,8 +1008,12 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
 
   eventBase.loop();
 
-  EXPECT_EQ(server.clientCiphers_,
-            "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+#if defined(OPENSSL_IS_BORINGSSL)
+  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
+#else
+  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
+#endif
+  EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_TRUE(!client.handshakeError_);
@@ -1195,7 +1175,7 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
 
   SSLHandshakeClient client(std::move(clientSock), true, true);
-  clientCtx->loadTrustedCertificates(testCA);
+  clientCtx->loadTrustedCertificates(kTestCA);
 
   SSLHandshakeServer server(std::move(serverSock), true, true);
 
@@ -1233,7 +1213,7 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
 
   SSLHandshakeClient client(std::move(clientSock), true, false);
-  clientCtx->loadTrustedCertificates(testCA);
+  clientCtx->loadTrustedCertificates(kTestCA);
 
   SSLHandshakeServer server(std::move(serverSock), true, true);
 
@@ -1273,7 +1253,7 @@ TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
 
   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
-  clientCtx->loadTrustedCertificates(testCA);
+  clientCtx->loadTrustedCertificates(kTestCA);
 
   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
 
@@ -1300,16 +1280,16 @@ TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
   auto serverCtx = std::make_shared<SSLContext>();
   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-  serverCtx->loadPrivateKey(testKey);
-  serverCtx->loadCertificate(testCert);
-  serverCtx->loadTrustedCertificates(testCA);
-  serverCtx->loadClientCAList(testCA);
+  serverCtx->loadPrivateKey(kTestKey);
+  serverCtx->loadCertificate(kTestCert);
+  serverCtx->loadTrustedCertificates(kTestCA);
+  serverCtx->loadClientCAList(kTestCA);
 
   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-  clientCtx->loadPrivateKey(testKey);
-  clientCtx->loadCertificate(testCert);
-  clientCtx->loadTrustedCertificates(testCA);
+  clientCtx->loadPrivateKey(kTestKey);
+  clientCtx->loadCertificate(kTestCert);
+  clientCtx->loadTrustedCertificates(kTestCA);
 
   int fds[2];
   getfds(fds);
@@ -1417,16 +1397,16 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
   serverCtx->setVerificationOption(
       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-  serverCtx->loadPrivateKey(testKey);
-  serverCtx->loadCertificate(testCert);
-  serverCtx->loadTrustedCertificates(testCA);
-  serverCtx->loadClientCAList(testCA);
+  serverCtx->loadPrivateKey(kTestKey);
+  serverCtx->loadCertificate(kTestCert);
+  serverCtx->loadTrustedCertificates(kTestCA);
+  serverCtx->loadClientCAList(kTestCA);
 
   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-  clientCtx->loadPrivateKey(testKey);
-  clientCtx->loadCertificate(testCert);
-  clientCtx->loadTrustedCertificates(testCA);
+  clientCtx->loadPrivateKey(kTestKey);
+  clientCtx->loadCertificate(kTestCert);
+  clientCtx->loadTrustedCertificates(kTestCA);
 
   int fds[2];
   getfds(fds);
@@ -1462,10 +1442,10 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
   serverCtx->setVerificationOption(
       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
-  serverCtx->loadPrivateKey(testKey);
-  serverCtx->loadCertificate(testCert);
-  serverCtx->loadTrustedCertificates(testCA);
-  serverCtx->loadClientCAList(testCA);
+  serverCtx->loadPrivateKey(kTestKey);
+  serverCtx->loadCertificate(kTestCert);
+  serverCtx->loadTrustedCertificates(kTestCA);
+  serverCtx->loadClientCAList(kTestCA);
   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
 
@@ -1490,8 +1470,8 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
 }
 
 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
-  auto cert = getFileAsBuf(testCert);
-  auto key = getFileAsBuf(testKey);
+  auto cert = getFileAsBuf(kTestCert);
+  auto key = getFileAsBuf(kTestKey);
 
   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
   BIO_write(certBio.get(), cert.data(), cert.size());
@@ -1514,7 +1494,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
   auto ctx = std::make_shared<SSLContext>();
   ctx->loadPrivateKeyFromBufferPEM(key);
   ctx->loadCertificateFromBufferPEM(cert);
-  ctx->loadTrustedCertificates(testCA);
+  ctx->loadTrustedCertificates(kTestCA);
 
   ssl::SSLUniquePtr ssl(ctx->createSSL());
 
@@ -1690,8 +1670,24 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
   handshakeCallback.waitForHandshake();
   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
             std::string::npos);
+#if defined(OPENSSL_IS_BORINGSSL)
+  EXPECT_NE(
+      handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
+      std::string::npos);
+#else
   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
             std::string::npos);
+#endif
+}
+
+TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
+  using folly::ssl::OpenSSLUtils;
+  EXPECT_EQ(
+      OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
+  // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
+  EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
+  // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
+  EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
 }
 
 #if FOLLY_ALLOW_TFO
@@ -1786,13 +1782,15 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
     state = State::SUCCESS;
   }
 
-  virtual void connectErr(const AsyncSocketException&) noexcept override {
+  virtual void connectErr(const AsyncSocketException& ex) noexcept override {
     state = State::ERROR;
+    error = ex.what();
   }
 
   enum class State { WAITING, SUCCESS, ERROR };
 
   State state{State::WAITING};
+  std::string error;
 };
 
 template <class Cardinality>
@@ -1854,6 +1852,7 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
   sock.close();
 }
 
+#if !defined(OPENSSL_IS_BORINGSSL)
 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
   // Start listening on a local port
   ConnectTimeoutCallback acceptCallback;
@@ -1867,9 +1866,11 @@ TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
   socket->enableTFO();
   EXPECT_THROW(
-      socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+      socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
 }
+#endif
 
+#if !defined(OPENSSL_IS_BORINGSSL)
 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
   // Start listening on a local port
   ConnectTimeoutCallback acceptCallback;
@@ -1885,6 +1886,76 @@ TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
   evb.loop();
   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
 }
+#endif
+
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+  // Start listening on a local port
+  EmptyReadCallback readCallback;
+  HandshakeCallback handshakeCallback(
+      &readCallback, HandshakeCallback::EXPECT_ERROR);
+  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 100);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+  EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+  // Start listening on a local port
+  EventBase evb;
+
+  // Hopefully nothing is listening on this address
+  SocketAddress addr("127.0.0.1", 65535);
+  auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+  ConnCallback ccb;
+  socket->connect(&ccb, addr, 100);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+  EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
+TEST(AsyncSSLSocketTest, TestPreReceivedData) {
+  EventBase clientEventBase;
+  EventBase serverEventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+  std::array<int, 2> fds;
+  getfds(fds.data());
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSockPtr(
+      new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSockPtr(
+      new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
+  auto clientSock = clientSockPtr.get();
+  auto serverSock = serverSockPtr.get();
+  SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+  // Steal some data from the server.
+  clientEventBase.loopOnce();
+  std::array<uint8_t, 10> buf;
+  recv(fds[1], buf.data(), buf.size(), 0);
+
+  serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+  SSLHandshakeServer server(std::move(serverSockPtr), true, true);
+  while (!client.handshakeSuccess_ && !client.handshakeError_) {
+    serverEventBase.loopOnce();
+    clientEventBase.loopOnce();
+  }
+
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_EQ(
+      serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
 
 #endif