fix flaky ConnectTFOTimeout and ConnectTFOFallbackTimeout tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
index 5bb0a82c1a804d63df9c10f97ac67d6ea47820e4..cdacacada0a75d22c407c03b005446edcf351c63 100644 (file)
  */
 #include <folly/io/async/test/AsyncSSLSocketTest.h>
 
-#include <signal.h>
 #include <pthread.h>
+#include <signal.h>
 
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/SocketAddress.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
 
 #include <folly/io/async/test/BlockingSocket.h>
 
-#include <fstream>
+#include <fcntl.h>
+#include <folly/io/Cursor.h>
 #include <gtest/gtest.h>
+#include <openssl/bio.h>
+#include <sys/types.h>
+#include <fstream>
 #include <iostream>
 #include <list>
 #include <set>
-#include <unistd.h>
-#include <fcntl.h>
-#include <openssl/bio.h>
-#include <poll.h>
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
-#include <folly/io/Cursor.h>
+
+#include <gmock/gmock.h>
 
 using std::string;
 using std::vector;
@@ -45,6 +45,8 @@ using std::cerr;
 using std::endl;
 using std::list;
 
+using namespace testing;
+
 namespace folly {
 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
@@ -57,10 +59,7 @@ const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
 constexpr size_t SSLClient::kMaxReadBufferSz;
 constexpr size_t SSLClient::kMaxReadsPerEvent;
 
-inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
-using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
-
-TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
+TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
     : ctx_(new folly::SSLContext),
       acb_(acb),
       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
@@ -72,7 +71,13 @@ TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
   acb_->ctx_ = ctx_;
   acb_->base_ = &evb_;
 
-  //set up the listening socket
+  // Enable TFO
+  if (enableTFO) {
+    LOG(INFO) << "server TFO enabled";
+    socket_->setTFOEnabled(true, 1000);
+  }
+
+  // set up the listening socket
   socket_->bind(0);
   socket_->getAddress(&address_);
   socket_->listen(100);
@@ -204,13 +209,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
   cerr << "ConnectWriteReadClose test completed" << endl;
 }
 
+/**
+ * Test reading after server close.
+ */
+TEST(AsyncSSLSocketTest, ReadAfterClose) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadEOFCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  auto socket =
+      std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
+  socket->open();
+
+  // This should trigger an EOF on the client.
+  auto evb = handshakeCallback.getSocket()->getEventBase();
+  evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
+  std::array<uint8_t, 128> readbuf;
+  auto bytesRead = socket->read(readbuf.data(), readbuf.size());
+  EXPECT_EQ(0, bytesRead);
+}
+
+/**
+ * Test bad renegotiation
+ */
+TEST(AsyncSSLSocketTest, Renegotiate) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+  std::array<int, 2> fds;
+  getfds(fds.data());
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SSLHandshakeClient client(std::move(clientSock), true, true);
+  RenegotiatingServer server(std::move(serverSock));
+
+  while (!client.handshakeSuccess_ && !client.handshakeError_) {
+    eventBase.loopOnce();
+  }
+
+  ASSERT_TRUE(client.handshakeSuccess_);
+
+  auto sslSock = std::move(client).moveSocket();
+  sslSock->detachEventBase();
+  // This is nasty, however we don't want to add support for
+  // renegotiation in AsyncSSLSocket.
+  SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
+
+  auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
+
+  std::thread t([&]() { eventBase.loopForever(); });
+
+  // Trigger the renegotiation.
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  try {
+    socket->write(buf.data(), buf.size());
+  } catch (AsyncSocketException& e) {
+    LOG(INFO) << "client got error " << e.what();
+  }
+  eventBase.terminateLoopSoon();
+  t.join();
+
+  eventBase.loop();
+  ASSERT_TRUE(server.renegotiationError_);
+}
+
 /**
  * Negative test for handshakeError().
  */
 TEST(AsyncSSLSocketTest, HandshakeError) {
   // Start listening on a local port
   WriteCallbackBase writeCallback;
-  ReadCallback readCallback(&writeCallback);
+  WriteErrorCallback readCallback(&writeCallback);
   HandshakeCallback handshakeCallback(&readCallback);
   HandshakeErrorCallback acceptCallback(&handshakeCallback);
   TestSSLServer server(&acceptCallback);
@@ -404,6 +485,10 @@ class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
   std::unique_ptr<NpnServer> server;
 };
 
+class NextProtocolTLSExtTest : public NextProtocolTest {
+  // For extended TLS protos
+};
+
 class NextProtocolNPNOnlyTest : public NextProtocolTest {
   // For mismatching protos
 };
@@ -449,15 +534,26 @@ TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
 }
 
-TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
+// Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
+// will fail on 1.0.2 before that.
+TEST_P(NextProtocolTest, NpnTestNoOverlap) {
   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
                                         GetParam().second);
 
   connect();
 
-  expectProtocol("blub");
-  expectProtocolType();
+  if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
+      GetParam().second == SSLContext::NextProtocolType::ALPN) {
+    // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
+    // mismatch should result in a fatal alert, but this is OpenSSL's current
+    // behavior and we want to know if it changes.
+    expectNoProtocol();
+  } else {
+    expectProtocol("blub");
+    expectProtocolType(
+        {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
+  }
 }
 
 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
@@ -513,20 +609,32 @@ TEST_P(NextProtocolTest, RandomizedNpnTest) {
 INSTANTIATE_TEST_CASE_P(
     AsyncSSLSocketTest,
     NextProtocolTest,
-    ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
-                                           SSLContext::NextProtocolType::NPN),
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
-                      NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
-                                           SSLContext::NextProtocolType::ALPN),
-#endif
-                      NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
-                                           SSLContext::NextProtocolType::ANY),
+    ::testing::Values(
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::NPN,
+            SSLContext::NextProtocolType::NPN),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::NPN,
+            SSLContext::NextProtocolType::ANY),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ANY,
+            SSLContext::NextProtocolType::ANY)));
+
 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
-                      NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
-                                           SSLContext::NextProtocolType::ANY),
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolTLSExtTest,
+    ::testing::Values(
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ALPN,
+            SSLContext::NextProtocolType::ALPN),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ALPN,
+            SSLContext::NextProtocolType::ANY),
+        NextProtocolTypePair(
+            SSLContext::NextProtocolType::ANY,
+            SSLContext::NextProtocolType::ALPN)));
 #endif
-                      NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
-                                           SSLContext::NextProtocolType::ANY)));
 
 INSTANTIATE_TEST_CASE_P(
     AsyncSSLSocketTest,
@@ -879,7 +987,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
       handshakeCallback.closeSocket();});
   // give time for the cache lookup to come back and find it closed
-  usleep(500000);
+  handshakeCallback.waitForHandshake();
 
   EXPECT_EQ(server.getAsyncCallbacks(), 1);
   EXPECT_EQ(server.getAsyncLookups(), 1);
@@ -1384,15 +1492,15 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
   auto cert = getFileAsBuf(testCert);
   auto key = getFileAsBuf(testKey);
 
-  std::unique_ptr<BIO, BIO_deleter> certBio(BIO_new(BIO_s_mem()));
+  ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
   BIO_write(certBio.get(), cert.data(), cert.size());
-  std::unique_ptr<BIO, BIO_deleter> keyBio(BIO_new(BIO_s_mem()));
+  ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
   BIO_write(keyBio.get(), key.data(), key.size());
 
   // Create SSL structs from buffers to get properties
-  X509_UniquePtr certStruct(
+  ssl::X509UniquePtr certStruct(
       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
-  EVP_PKEY_UniquePtr keyStruct(
+  ssl::EvpPkeyUniquePtr keyStruct(
       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
   certBio = nullptr;
   keyBio = nullptr;
@@ -1407,7 +1515,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
   ctx->loadCertificateFromBufferPEM(cert);
   ctx->loadTrustedCertificates(testCA);
 
-  SSL_UniquePtr ssl(ctx->createSSL());
+  ssl::SSLUniquePtr ssl(ctx->createSSL());
 
   auto newCert = SSL_get_certificate(ssl.get());
   auto newKey = SSL_get_privatekey(ssl.get());
@@ -1520,6 +1628,265 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) {
   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
 }
 
+TEST(AsyncSSLSocketTest, ConnResetErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[3] = {0x16, 0x03, 0x01};
+  socket->write(buf, sizeof(buf));
+  socket->closeWithReset();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(
+      handshakeCallback.errorString_.find("Network error"), std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[3] = {0x16, 0x03, 0x01};
+  socket->write(buf, sizeof(buf));
+  socket->close();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(
+      handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[256] = {0x16, 0x03};
+  memset(buf + 2, 'a', sizeof(buf) - 2);
+  socket->write(buf, sizeof(buf));
+  socket->close();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
+            std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
+            std::string::npos);
+}
+
+#if FOLLY_ALLOW_TFO
+
+class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
+ public:
+  using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
+
+  explicit MockAsyncTFOSSLSocket(
+      std::shared_ptr<folly::SSLContext> sslCtx,
+      EventBase* evb)
+      : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
+
+  MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
+};
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+/**
+ * Test connecting to, writing to, reading from, and closing the
+ * connection to the SSL server with TFO.
+ */
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, false);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  socket->open();
+
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  socket->write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  socket->close();
+}
+
+class ConnCallback : public AsyncSocket::ConnectCallback {
+ public:
+  virtual void connectSuccess() noexcept override {
+    state = State::SUCCESS;
+  }
+
+  virtual void connectErr(const AsyncSocketException&) noexcept override {
+    state = State::ERROR;
+  }
+
+  enum class State { WAITING, SUCCESS, ERROR };
+
+  State state{State::WAITING};
+};
+
+template <class Cardinality>
+MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
+    EventBase* evb,
+    const SocketAddress& address,
+    Cardinality cardinality) {
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket = MockAsyncTFOSSLSocket::UniquePtr(
+      new MockAsyncTFOSSLSocket(sslContext, evb));
+  socket->enableTFO();
+
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .Times(cardinality)
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+        sockaddr_storage addr;
+        auto len = address.getAddress(&addr);
+        return connect(fd, (const struct sockaddr*)&addr, len);
+      }));
+  return socket;
+}
+
+TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
+
+  evb.runInEventBaseThread([&] { socket->detachEventBase(); });
+  evb.loop();
+
+  BlockingSocket sock(std::move(socket));
+  // write()
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  sock.write(buf.data(), buf.size());
+
+  // read()
+  std::array<uint8_t, 128> readbuf;
+  uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
+  EXPECT_EQ(bytesRead, 128);
+  EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
+
+  // close()
+  sock.close();
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
+  // Start listening on a local port
+  ConnectTimeoutCallback acceptCallback;
+  TestSSLServer server(&acceptCallback, true);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+
+  // connect
+  auto socket =
+      std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
+  socket->enableTFO();
+  EXPECT_THROW(
+      socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
+}
+
+TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
+  // Start listening on a local port
+  ConnectTimeoutCallback acceptCallback;
+  TestSSLServer server(&acceptCallback, true);
+
+  EventBase evb;
+
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
+  ConnCallback ccb;
+  // Set a short timeout
+  socket->connect(&ccb, server.getAddress(), 1);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+}
+
+#endif
+
 } // namespace
 
 ///////////////////////////////////////////////////////////////////////////