Don't declare a variable for exceptions we discard
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
index d4078e56b01818c2d5f74234c6bd5c00a98f1c96..181a33451e226b0f227f9cc2b9688988a592ed7d 100644 (file)
  */
 #include <folly/io/async/test/AsyncSSLSocketTest.h>
 
-#include <signal.h>
 #include <pthread.h>
+#include <signal.h>
 
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
+#include <folly/portability/GMock.h>
+#include <folly/portability/GTest.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/Unistd.h>
 
 #include <folly/io/async/test/BlockingSocket.h>
 
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
 #include <fstream>
-#include <gtest/gtest.h>
 #include <iostream>
 #include <list>
 #include <set>
-#include <fcntl.h>
-#include <openssl/bio.h>
-#include <sys/types.h>
-#include <folly/io/Cursor.h>
+#include <thread>
 
 using std::string;
 using std::vector;
@@ -43,6 +45,8 @@ using std::cerr;
 using std::endl;
 using std::list;
 
+using namespace testing;
+
 namespace folly {
 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
@@ -55,7 +59,7 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
 constexpr size_t SSLClient::kMaxReadBufferSz;
 constexpr size_t SSLClient::kMaxReadsPerEvent;
 
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
     : ctx_(new folly::SSLContext),
       acb_(acb),
       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
@@ -67,7 +71,13 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
   acb_->ctx_ = ctx_;
   acb_->base_ = &evb_;
 
-  //set up the listening socket
+  // Enable TFO
+  if (enableTFO) {
+    LOG(INFO) << "server TFO enabled";
+    socket_->setTFOEnabled(true, 1000);
+  }
+
+  // set up the listening socket
   socket_->bind(0);
   socket_->getAddress(&address_);
   socket_->listen(100);
@@ -301,7 +311,7 @@ TEST(AsyncSSLSocketTest, HandshakeError) {
     uint8_t readbuf[128];
     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
-  } catch (AsyncSocketException &e) {
+  } catch (AsyncSocketException&) {
     ex = true;
   }
   EXPECT_TRUE(ex);
@@ -475,6 +485,10 @@ class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
   std::unique_ptr<NpnServer> server;
 };
 
+class NextProtocolTLSExtTest : public NextProtocolTest {
+  // For extended TLS protos
+};
+
 class NextProtocolNPNOnlyTest : public NextProtocolTest {
   // For mismatching protos
 };
@@ -599,7 +613,18 @@ INSTANTIATE_TEST_CASE_P(
         NextProtocolTypePair(
             SSLContext::NextProtocolType::NPN,
             SSLContext::NextProtocolType::NPN),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::NPN,
+            SSLContext::NextProtocolType::ANY),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ANY,
+            SSLContext::NextProtocolType::ANY)));
+
 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolTLSExtTest,
+    ::testing::Values(
         NextProtocolTypePair(
             SSLContext::NextProtocolType::ALPN,
             SSLContext::NextProtocolType::ALPN),
@@ -608,14 +633,8 @@ INSTANTIATE_TEST_CASE_P(
             SSLContext::NextProtocolType::ANY),
         NextProtocolTypePair(
             SSLContext::NextProtocolType::ANY,
-            SSLContext::NextProtocolType::ALPN),
+            SSLContext::NextProtocolType::ALPN)));
 #endif
-        NextProtocolTypePair(
-            SSLContext::NextProtocolType::NPN,
-            SSLContext::NextProtocolType::ANY),
-        NextProtocolTypePair(
-            SSLContext::NextProtocolType::ANY,
-            SSLContext::NextProtocolType::ANY)));
 
 INSTANTIATE_TEST_CASE_P(
     AsyncSSLSocketTest,
@@ -858,7 +877,8 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
   cerr << "SSLClientTimeoutTest test completed" << endl;
 }
 
-
+// This is a FB-only extension, and the tests will fail without it
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
 /**
  * Test SSL server async cache
  */
@@ -895,7 +915,6 @@ TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
   // Start listening on a local port
   WriteCallbackBase writeCallback;
   ReadCallback readCallback(&writeCallback);
-  EmptyReadCallback clientReadCallback;
   HandshakeCallback handshakeCallback(&readCallback);
   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
   TestSSLAsyncCacheServer server(&acceptCallback);
@@ -905,6 +924,8 @@ TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
   // only do a TCP connect
   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
   sock->connect(nullptr, server.getAddress());
+
+  EmptyReadCallback clientReadCallback;
   clientReadCallback.tcpSocket_ = sock;
   sock->setReadCB(&clientReadCallback);
 
@@ -978,6 +999,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
 
   cerr << "SSLServerCacheCloseTest test completed" << endl;
 }
+#endif
 
 /**
  * Verify Client Ciphers obtained using SSL MSG Callback.
@@ -987,14 +1009,14 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
   auto clientCtx = std::make_shared<SSLContext>();
   auto serverCtx = std::make_shared<SSLContext>();
   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
-  serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
+  serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
   serverCtx->loadPrivateKey(testKey);
   serverCtx->loadCertificate(testCert);
   serverCtx->loadTrustedCertificates(testCA);
   serverCtx->loadClientCAList(testCA);
 
   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
-  clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
+  clientCtx->ciphers("AES256-SHA:RC4-MD5");
   clientCtx->loadPrivateKey(testKey);
   clientCtx->loadCertificate(testCert);
   clientCtx->loadTrustedCertificates(testCA);
@@ -1012,8 +1034,8 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
 
   eventBase.loop();
 
-  EXPECT_EQ(server.clientCiphers_,
-            "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
+  EXPECT_EQ(server.clientCiphers_, "AES256-SHA:RC4-MD5:00ff");
+  EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_TRUE(!client.handshakeError_);
@@ -1674,8 +1696,249 @@ TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
             std::string::npos);
 }
 
+TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
+  using folly::ssl::OpenSSLUtils;
+  EXPECT_EQ(
+      OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
+  // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
+  EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
+  // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
+  EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
+}
+
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+  using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+  explicit MockAsyncTFOSSLSocket(
+      std::shared_ptr<folly::SSLContext> sslCtx,
+      EventBase* evb)
+      : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+  MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, false);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+  virtual void connectSuccess() noexcept override {
+    state = State::SUCCESS;
+  }
+
+  virtual void connectErr(const AsyncSocketException& ex) noexcept override {
+    state = State::ERROR;
+    error = ex.what();
+  }
+
+  enum class State { WAITING, SUCCESS, ERROR };
+
+  State state{State::WAITING};
+  std::string error;
+};
+
+template <class Cardinality>
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+    EventBase* evb,
+    const SocketAddress& address,
+    Cardinality cardinality) {
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+      new MockAsyncTFOSSLSocket(sslContext, evb));
+  socket->enableTFO();
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .Times(cardinality)
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+        sockaddr_storage addr;
+        auto len = address.getAddress(&addr);
+        return connect(fd, (const struct sockaddr*)&addr, len);
+      }));
+  return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+  evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+  evb.loop();
+
+  BlockingSocket sock(std::move(socket));
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  sock.write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  sock.close();
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+  // Start listening on a local port
+  ConnectTimeoutCallback acceptCallback;
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  EXPECT_THROW(
+      socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+  // Start listening on a local port
+  ConnectTimeoutCallback acceptCallback;
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+  ConnCallback ccb;
+  // Set a short timeout
+  socket->connect(&ccb, server.getAddress(), 1);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
+  // Start listening on a local port
+  EmptyReadCallback readCallback;
+  HandshakeCallback handshakeCallback(
+      &readCallback, HandshakeCallback::EXPECT_ERROR);
+  HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 100);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+  EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
+}
+
+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+  // Start listening on a local port
+  EventBase evb;
+
+  // Hopefully nothing is listening on this address
+  SocketAddress addr("127.0.0.1", 65535);
+  auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+  ConnCallback ccb;
+  socket->connect(&ccb, addr, 100);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+  EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
+#endif
+
 } // namespace
 
+#ifdef SIGPIPE
 ///////////////////////////////////////////////////////////////////////////
 // init_unit_test_suite
 ///////////////////////////////////////////////////////////////////////////
@@ -1687,3 +1950,4 @@ struct Initializer {
 };
 Initializer initializer;
 } // anonymous
+#endif