fix flaky ConnectTFOTimeout and ConnectTFOFallbackTimeout tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
index 5f1ca7304083b800fa45cc1ca06844b41b41e556..cdacacada0a75d22c407c03b005446edcf351c63 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2016 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  */
 #include <folly/io/async/test/AsyncSSLSocketTest.h>
 
-#include <signal.h>
 #include <pthread.h>
+#include <signal.h>
 
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
 
 #include <folly/io/async/test/BlockingSocket.h>
 
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
 #include <gtest/gtest.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
+#include <fstream>
 #include <iostream>
 #include <list>
 #include <set>
-#include <unistd.h>
-#include <fcntl.h>
-#include <poll.h>
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
-#include <folly/io/Cursor.h>
+
+#include <gmock/gmock.h>
 
 using std::string;
 using std::vector;
@@ -43,6 +45,8 @@ using std::cerr;
 using std::endl;
 using std::list;
 
+using namespace testing;
+
 namespace folly {
 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
@@ -52,10 +56,13 @@ const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
 
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
-ctx_(new folly::SSLContext),
-    acb_(acb),
-  socket_(new folly::AsyncServerSocket(&evb_)) {
+constexpr size_t SSLClient::kMaxReadBufferSz;
+constexpr size_t SSLClient::kMaxReadsPerEvent;
+
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
+    : ctx_(new folly::SSLContext),
+      acb_(acb),
+      socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
   // Set up the SSL context
   ctx_->loadCertificate(testCert);
   ctx_->loadPrivateKey(testKey);
@@ -64,7 +71,13 @@ ctx_(new folly::SSLContext),
   acb_->ctx_ = ctx_;
   acb_->base_ = &evb_;
 
-  //set up the listening socket
+  // Enable TFO
+  if (enableTFO) {
+    LOG(INFO) << "server TFO enabled";
+    socket_->setTFOEnabled(true, 1000);
+  }
+
+  // set up the listening socket
   socket_->bind(0);
   socket_->getAddress(&address_);
   socket_->listen(100);
@@ -73,6 +86,7 @@ ctx_(new folly::SSLContext),
 
   int ret = pthread_create(&thread_, nullptr, Main, this);
   assert(ret == 0);
+  (void)ret;
 
   std::cerr << "Accepting connections on " << address_ << std::endl;
 }
@@ -124,6 +138,36 @@ void sslsocketpair(
   // (*serverSock)->setSendTimeout(100);
 }
 
+// client protocol filters
+bool clientProtoFilterPickPony(unsigned char** client,
+  unsigned int* client_len, const unsigned char*, unsigned int ) {
+  //the protocol string in length prefixed byte string. the
+  //length byte is not included in the length
+  static unsigned char p[7] = {6,'p','o','n','i','e','s'};
+  *client = p;
+  *client_len = 7;
+  return true;
+}
+
+bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
+  const unsigned char*, unsigned int) {
+  return false;
+}
+
+std::string getFileAsBuf(const char* fileName) {
+  std::string buffer;
+  folly::readFile(fileName, buffer);
+  return buffer;
+}
+
+std::string getCommonName(X509* cert) {
+  X509_NAME* subject = X509_get_subject_name(cert);
+  std::string cn;
+  cn.resize(ub_common_name);
+  X509_NAME_get_text_by_NID(
+      subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
+  return cn;
+}
 
 /**
  * Test connecting to, writing to, reading from, and closing the
@@ -165,13 +209,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
   cerr << "ConnectWriteReadClose test completed" << endl;
 }
 
+/**
+ * Test reading after server close.
+ */
+TEST(AsyncSSLSocketTest, ReadAfterClose) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadEOFCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  auto socket =
+      std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
+  socket->open();
+
+  // This should trigger an EOF on the client.
+  auto evb = handshakeCallback.getSocket()->getEventBase();
+  evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
+  std::array<uint8_t, 128> readbuf;
+  auto bytesRead = socket->read(readbuf.data(), readbuf.size());
+  EXPECT_EQ(0, bytesRead);
+}
+
+/**
+ * Test bad renegotiation
+ */
+TEST(AsyncSSLSocketTest, Renegotiate) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+  std::array<int, 2> fds;
+  getfds(fds.data());
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SSLHandshakeClient client(std::move(clientSock), true, true);
+  RenegotiatingServer server(std::move(serverSock));
+
+  while (!client.handshakeSuccess_ && !client.handshakeError_) {
+    eventBase.loopOnce();
+  }
+
+  ASSERT_TRUE(client.handshakeSuccess_);
+
+  auto sslSock = std::move(client).moveSocket();
+  sslSock->detachEventBase();
+  // This is nasty, however we don't want to add support for
+  // renegotiation in AsyncSSLSocket.
+  SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
+
+  auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
+
+  std::thread t([&]() { eventBase.loopForever(); });
+
+  // Trigger the renegotiation.
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  try {
+    socket->write(buf.data(), buf.size());
+  } catch (AsyncSocketException& e) {
+    LOG(INFO) << "client got error " << e.what();
+  }
+  eventBase.terminateLoopSoon();
+  t.join();
+
+  eventBase.loop();
+  ASSERT_TRUE(server.renegotiationError_);
+}
+
 /**
  * Negative test for handshakeError().
  */
 TEST(AsyncSSLSocketTest, HandshakeError) {
   // Start listening on a local port
   WriteCallbackBase writeCallback;
-  ReadCallback readCallback(&writeCallback);
+  WriteErrorCallback readCallback(&writeCallback);
   HandshakeCallback handshakeCallback(&readCallback);
   HandshakeErrorCallback acceptCallback(&handshakeCallback);
   TestSSLServer server(&acceptCallback);
@@ -190,6 +310,7 @@ TEST(AsyncSSLSocketTest, HandshakeError) {
 
     uint8_t readbuf[128];
     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
+    LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
   } catch (AsyncSocketException &e) {
     ex = true;
   }
@@ -295,133 +416,241 @@ TEST(AsyncSSLSocketTest, SocketWithDelay) {
   cerr << "SocketWithDelay test completed" << endl;
 }
 
-TEST(AsyncSSLSocketTest, NpnTestOverlap) {
+using NextProtocolTypePair =
+    std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
+
+class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
+  // For matching protos
+ public:
+  void SetUp() override { getctx(clientCtx, serverCtx); }
+
+  void connect(bool unset = false) {
+    getfds(fds);
+
+    if (unset) {
+      // unsetting NPN for any of [client, server] is enough to make NPN not
+      // work
+      clientCtx->unsetNextProtocols();
+    }
+
+    AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+    AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+    client = folly::make_unique<NpnClient>(std::move(clientSock));
+    server = folly::make_unique<NpnServer>(std::move(serverSock));
+
+    eventBase.loop();
+  }
+
+  void expectProtocol(const std::string& proto) {
+    EXPECT_NE(client->nextProtoLength, 0);
+    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
+    EXPECT_EQ(
+        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
+        0);
+    string selected((const char*)client->nextProto, client->nextProtoLength);
+    EXPECT_EQ(proto, selected);
+  }
+
+  void expectNoProtocol() {
+    EXPECT_EQ(client->nextProtoLength, 0);
+    EXPECT_EQ(server->nextProtoLength, 0);
+    EXPECT_EQ(client->nextProto, nullptr);
+    EXPECT_EQ(server->nextProto, nullptr);
+  }
+
+  void expectProtocolType() {
+    if (GetParam().first == SSLContext::NextProtocolType::ANY &&
+        GetParam().second == SSLContext::NextProtocolType::ANY) {
+      EXPECT_EQ(client->protocolType, server->protocolType);
+    } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
+               GetParam().second == SSLContext::NextProtocolType::ANY) {
+      // Well not much we can say
+    } else {
+      expectProtocolType(GetParam());
+    }
+  }
+
+  void expectProtocolType(NextProtocolTypePair expected) {
+    EXPECT_EQ(client->protocolType, expected.first);
+    EXPECT_EQ(server->protocolType, expected.second);
+  }
+
   EventBase eventBase;
-  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
-  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+  std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
+  std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
   int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+  std::unique_ptr<NpnClient> client;
+  std::unique_ptr<NpnServer> server;
+};
 
-  clientCtx->setAdvertisedNextProtocols({"blub","baz"});
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+class NextProtocolTLSExtTest : public NextProtocolTest {
+  // For extended TLS protos
+};
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+class NextProtocolNPNOnlyTest : public NextProtocolTest {
+  // For mismatching protos
+};
 
-  eventBase.loop();
+class NextProtocolMismatchTest : public NextProtocolTest {
+  // For mismatching protos
+};
+
+TEST_P(NextProtocolTest, NpnTestOverlap) {
+  clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
+
+  connect();
 
-  EXPECT_TRUE(client.nextProtoLength != 0);
-  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                           server.nextProtoLength), 0);
-  string selected((const char*)client.nextProto, client.nextProtoLength);
-  EXPECT_EQ(selected.compare("baz"), 0);
+  expectProtocol("baz");
+  expectProtocolType();
 }
 
-TEST(AsyncSSLSocketTest, NpnTestUnset) {
+TEST_P(NextProtocolTest, NpnTestUnset) {
   // Identical to above test, except that we want unset NPN before
   // looping.
-  EventBase eventBase;
-  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
-  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
-  int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+  clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  clientCtx->setAdvertisedNextProtocols({"blub","baz"});
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+  connect(true /* unset */);
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+  // if alpn negotiation fails, type will appear as npn
+  expectNoProtocol();
+  EXPECT_EQ(client->protocolType, server->protocolType);
+}
 
-  // unsetting NPN for any of [client, server] is enought to make NPN not
-  // work
-  clientCtx->unsetNextProtocols();
+TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
+  clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+  connect();
 
-  eventBase.loop();
+  expectNoProtocol();
+  expectProtocolType(
+      {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
+}
 
-  EXPECT_TRUE(client.nextProtoLength == 0);
-  EXPECT_TRUE(server.nextProtoLength == 0);
-  EXPECT_TRUE(client.nextProto == nullptr);
-  EXPECT_TRUE(server.nextProto == nullptr);
+// Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
+// will fail on 1.0.2 before that.
+TEST_P(NextProtocolTest, NpnTestNoOverlap) {
+  clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
+
+  connect();
+
+  if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
+      GetParam().second == SSLContext::NextProtocolType::ALPN) {
+    // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
+    // mismatch should result in a fatal alert, but this is OpenSSL's current
+    // behavior and we want to know if it changes.
+    expectNoProtocol();
+  } else {
+    expectProtocol("blub");
+    expectProtocolType(
+        {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
+  }
 }
 
-TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
-  EventBase eventBase;
-  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
-  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
-  int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
+  clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+  clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  clientCtx->setAdvertisedNextProtocols({"blub"});
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+  connect();
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+  expectProtocol("ponies");
+  expectProtocolType();
+}
 
-  eventBase.loop();
+TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
+  clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+  clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  EXPECT_TRUE(client.nextProtoLength != 0);
-  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                           server.nextProtoLength), 0);
-  string selected((const char*)client.nextProto, client.nextProtoLength);
-  EXPECT_EQ(selected.compare("blub"), 0);
+  connect();
+
+  expectProtocol("blub");
+  expectProtocolType();
 }
 
-TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
+TEST_P(NextProtocolTest, RandomizedNpnTest) {
   // Probability that this test will fail is 2^-64, which could be considered
   // as negligible.
   const int kTries = 64;
 
+  clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().first);
+  serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
+                                                  GetParam().second);
+
   std::set<string> selectedProtocols;
   for (int i = 0; i < kTries; ++i) {
-    EventBase eventBase;
-    std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
-    std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
-    int fds[2];
-    getfds(fds);
-    getctx(clientCtx, serverCtx);
-
-    clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
-    serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
-        {1, {"bar"}}});
-
-
-    AsyncSSLSocket::UniquePtr clientSock(
-      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-    AsyncSSLSocket::UniquePtr serverSock(
-      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-    NpnClient client(std::move(clientSock));
-    NpnServer server(std::move(serverSock));
-
-    eventBase.loop();
-
-    EXPECT_TRUE(client.nextProtoLength != 0);
-    EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-    EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                             server.nextProtoLength), 0);
-    string selected((const char*)client.nextProto, client.nextProtoLength);
+    connect();
+
+    EXPECT_NE(client->nextProtoLength, 0);
+    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
+    EXPECT_EQ(
+        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
+        0);
+    string selected((const char*)client->nextProto, client->nextProtoLength);
     selectedProtocols.insert(selected);
+    expectProtocolType();
   }
   EXPECT_EQ(selectedProtocols.size(), 2);
 }
 
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolTest,
+    ::testing::Values(
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::NPN,
+            SSLContext::NextProtocolType::NPN),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::NPN,
+            SSLContext::NextProtocolType::ANY),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ANY,
+            SSLContext::NextProtocolType::ANY)));
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolTLSExtTest,
+    ::testing::Values(
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ALPN,
+            SSLContext::NextProtocolType::ALPN),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ALPN,
+            SSLContext::NextProtocolType::ANY),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ANY,
+            SSLContext::NextProtocolType::ALPN)));
+#endif
+
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolNPNOnlyTest,
+    ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+                                           SSLContext::NextProtocolType::NPN)));
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolMismatchTest,
+    ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+                                           SSLContext::NextProtocolType::ALPN),
+                      NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
+                                           SSLContext::NextProtocolType::NPN)));
+#endif
 
 #ifndef OPENSSL_NO_TLSEXT
 /**
@@ -496,6 +725,41 @@ TEST(AsyncSSLSocketTest, SNITestNotMatch) {
   EXPECT_TRUE(!client.serverNameMatch);
   EXPECT_TRUE(!server.serverNameMatch);
 }
+/**
+ * 1. Client sends TLSEXT_HOSTNAME in client hello.
+ * 2. We then change the serverName.
+ * 3. We expect that we get 'false' as the result for serNameMatch.
+ */
+
+TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
+   EventBase eventBase;
+  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
+  std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
+  // Use the same SSLContext to continue the handshake after
+  // tlsext_hostname match.
+  std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
+  const std::string serverName("xyz.newdev.facebook.com");
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
+  //Change the server name
+  std::string newName("new.com");
+  clientSock->setServerName(newName);
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SNIClient client(std::move(clientSock));
+  SNIServer server(std::move(serverSock),
+                   dfServerCtx,
+                   hskServerCtx,
+                   serverName);
+
+  eventBase.loop();
+
+  EXPECT_TRUE(!client.serverNameMatch);
+}
 
 /**
  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
@@ -544,8 +808,7 @@ TEST(AsyncSSLSocketTest, SSLClientTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             1));
+  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -571,8 +834,8 @@ TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             10));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -597,8 +860,8 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             1, 10));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
   client->connect(true /* write before connect completes */);
   EventBaseAborter eba(&eventBase, 3000);
   eventBase.loop();
@@ -628,8 +891,8 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             10, 500));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -685,8 +948,7 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             2));
+  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -715,8 +977,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             2, 100));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -725,7 +987,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
       handshakeCallback.closeSocket();});
   // give time for the cache lookup to come back and find it closed
-  usleep(500000);
+  handshakeCallback.waitForHandshake();
 
   EXPECT_EQ(server.getAsyncCallbacks(), 1);
   EXPECT_EQ(server.getAsyncLookups(), 1);
@@ -803,6 +1065,7 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
   cursor.write<uint32_t>(0);
 
   SSL* ssl = ctx->createSSL();
+  SCOPE_EXIT { SSL_free(ssl); };
   AsyncSSLSocket::UniquePtr sock(
       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
   sock->enableClientHelloParsing();
@@ -842,6 +1105,7 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
   cursor.write<uint32_t>(0);
 
   SSL* ssl = ctx->createSSL();
+  SCOPE_EXIT { SSL_free(ssl); };
   AsyncSSLSocket::UniquePtr sock(
       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
   sock->enableClientHelloParsing();
@@ -888,6 +1152,7 @@ TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
   cursor.write<uint32_t>(0);
 
   SSL* ssl = ctx->createSSL();
+  SCOPE_EXIT { SSL_free(ssl); };
   AsyncSSLSocket::UniquePtr sock(
       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
   sock->enableClientHelloParsing();
@@ -938,9 +1203,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(!server.handshakeVerify_);
   EXPECT_TRUE(server.handshakeSuccess_);
   EXPECT_TRUE(!server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 /**
@@ -974,9 +1241,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(!client.handshakeSuccess_);
   EXPECT_TRUE(client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(!server.handshakeVerify_);
   EXPECT_TRUE(!server.handshakeSuccess_);
   EXPECT_TRUE(server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 /**
@@ -1012,9 +1281,11 @@ TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
   EXPECT_TRUE(!client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(!server.handshakeVerify_);
   EXPECT_TRUE(server.handshakeSuccess_);
   EXPECT_TRUE(!server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 /**
@@ -1055,9 +1326,11 @@ TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_FALSE(client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(server.handshakeVerify_);
   EXPECT_TRUE(server.handshakeSuccess_);
   EXPECT_FALSE(server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 /**
@@ -1089,9 +1362,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(!server.handshakeVerify_);
   EXPECT_TRUE(server.handshakeSuccess_);
   EXPECT_TRUE(!server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 /**
@@ -1124,9 +1399,11 @@ TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
   EXPECT_TRUE(!client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_TRUE(!client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(!server.handshakeVerify_);
   EXPECT_TRUE(server.handshakeSuccess_);
   EXPECT_TRUE(!server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 /**
@@ -1166,9 +1443,11 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
   EXPECT_TRUE(client.handshakeVerify_);
   EXPECT_TRUE(client.handshakeSuccess_);
   EXPECT_FALSE(client.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
   EXPECT_TRUE(server.handshakeVerify_);
   EXPECT_TRUE(server.handshakeSuccess_);
   EXPECT_FALSE(server.handshakeError_);
+  EXPECT_LE(0, server.handshakeTime.count());
 }
 
 
@@ -1205,6 +1484,49 @@ TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
   EXPECT_FALSE(server.handshakeVerify_);
   EXPECT_FALSE(server.handshakeSuccess_);
   EXPECT_TRUE(server.handshakeError_);
+  EXPECT_LE(0, client.handshakeTime.count());
+  EXPECT_LE(0, server.handshakeTime.count());
+}
+
+TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
+  auto cert = getFileAsBuf(testCert);
+  auto key = getFileAsBuf(testKey);
+
+  ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
+  BIO_write(certBio.get(), cert.data(), cert.size());
+  ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
+  BIO_write(keyBio.get(), key.data(), key.size());
+
+  // Create SSL structs from buffers to get properties
+  ssl::X509UniquePtr certStruct(
+      PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
+  ssl::EvpPkeyUniquePtr keyStruct(
+      PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
+  certBio = nullptr;
+  keyBio = nullptr;
+
+  auto origCommonName = getCommonName(certStruct.get());
+  auto origKeySize = EVP_PKEY_bits(keyStruct.get());
+  certStruct = nullptr;
+  keyStruct = nullptr;
+
+  auto ctx = std::make_shared<SSLContext>();
+  ctx->loadPrivateKeyFromBufferPEM(key);
+  ctx->loadCertificateFromBufferPEM(cert);
+  ctx->loadTrustedCertificates(testCA);
+
+  ssl::SSLUniquePtr ssl(ctx->createSSL());
+
+  auto newCert = SSL_get_certificate(ssl.get());
+  auto newKey = SSL_get_privatekey(ssl.get());
+
+  // Get properties from SSL struct
+  auto newCommonName = getCommonName(newCert);
+  auto newKeySize = EVP_PKEY_bits(newKey);
+
+  // Check that the key and cert have the expected properties
+  EXPECT_EQ(origCommonName, newCommonName);
+  EXPECT_EQ(origKeySize, newKeySize);
 }
 
 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
@@ -1224,8 +1546,349 @@ TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
   socket->setMinWriteSize(50000);
   EXPECT_EQ(50000, socket->getMinWriteSize());
 }
+
+class ReadCallbackTerminator : public ReadCallback {
+ public:
+  ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
+      : ReadCallback(wcb)
+      , base_(base) {}
+
+  // Do not write data back, terminate the loop.
+  void readDataAvailable(size_t len) noexcept override {
+    std::cerr << "readDataAvailable, len " << len << std::endl;
+
+    currentBuffer.length = len;
+
+    buffers.push_back(currentBuffer);
+    currentBuffer.reset();
+    state = STATE_SUCCEEDED;
+
+    socket_->setReadCB(nullptr);
+    base_->terminateLoopSoon();
+  }
+ private:
+  EventBase* base_;
+};
+
+
+/**
+ * Test a full unencrypted codepath
+ */
+TEST(AsyncSSLSocketTest, UnencryptedTest) {
+  EventBase base;
+
+  auto clientCtx = std::make_shared<folly::SSLContext>();
+  auto serverCtx = std::make_shared<folly::SSLContext>();
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+  auto client = AsyncSSLSocket::newSocket(
+                  clientCtx, &base, fds[0], false, true);
+  auto server = AsyncSSLSocket::newSocket(
+                  serverCtx, &base, fds[1], true, true);
+
+  ReadCallbackTerminator readCallback(&base, nullptr);
+  server->setReadCB(&readCallback);
+  readCallback.setSocket(server);
+
+  uint8_t buf[128];
+  memset(buf, 'a', sizeof(buf));
+  client->write(nullptr, buf, sizeof(buf));
+
+  // Check that bytes are unencrypted
+  char c;
+  EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
+  EXPECT_EQ('a', c);
+
+  EventBaseAborter eba(&base, 3000);
+  base.loop();
+
+  EXPECT_EQ(1, readCallback.buffers.size());
+  EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
+
+  server->setReadCB(&readCallback);
+
+  // Unencrypted
+  server->sslAccept(nullptr);
+  client->sslConn(nullptr);
+
+  // Do NOT wait for handshake, writing should be queued and happen after
+
+  client->write(nullptr, buf, sizeof(buf));
+
+  // Check that bytes are *not* unencrypted
+  char c2;
+  EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
+  EXPECT_NE('a', c2);
+
+
+  base.loop();
+
+  EXPECT_EQ(2, readCallback.buffers.size());
+  EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
+}
+
+TEST(AsyncSSLSocketTest, ConnResetErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[3] = {0x16, 0x03, 0x01};
+  socket->write(buf, sizeof(buf));
+  socket->closeWithReset();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(
+      handshakeCallback.errorString_.find("Network error"), std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[3] = {0x16, 0x03, 0x01};
+  socket->write(buf, sizeof(buf));
+  socket->close();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(
+      handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[256] = {0x16, 0x03};
+  memset(buf + 2, 'a', sizeof(buf) - 2);
+  socket->write(buf, sizeof(buf));
+  socket->close();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
+            std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
+            std::string::npos);
 }
 
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+  using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+  explicit MockAsyncTFOSSLSocket(
+      std::shared_ptr<folly::SSLContext> sslCtx,
+      EventBase* evb)
+      : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+  MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, false);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+  virtual void connectSuccess() noexcept override {
+    state = State::SUCCESS;
+  }
+
+  virtual void connectErr(const AsyncSocketException&) noexcept override {
+    state = State::ERROR;
+  }
+
+  enum class State { WAITING, SUCCESS, ERROR };
+
+  State state{State::WAITING};
+};
+
+template <class Cardinality>
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+    EventBase* evb,
+    const SocketAddress& address,
+    Cardinality cardinality) {
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+      new MockAsyncTFOSSLSocket(sslContext, evb));
+  socket->enableTFO();
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .Times(cardinality)
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+        sockaddr_storage addr;
+        auto len = address.getAddress(&addr);
+        return connect(fd, (const struct sockaddr*)&addr, len);
+      }));
+  return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+  evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+  evb.loop();
+
+  BlockingSocket sock(std::move(socket));
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  sock.write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  sock.close();
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+  // Start listening on a local port
+  ConnectTimeoutCallback acceptCallback;
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  EXPECT_THROW(
+      socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+  // Start listening on a local port
+  ConnectTimeoutCallback acceptCallback;
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+  ConnCallback ccb;
+  // Set a short timeout
+  socket->connect(&ccb, server.getAddress(), 1);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+
+#endif
+
+} // namespace
+
 ///////////////////////////////////////////////////////////////////////////
 // init_unit_test_suite
 ///////////////////////////////////////////////////////////////////////////