fix flaky ConnectTFOTimeout and ConnectTFOFallbackTimeout tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
index 27805c7b3cfb8d8c717916a4fa11d11bad81fcdb..cdacacada0a75d22c407c03b005446edcf351c63 100644 (file)
@@ -1794,9 +1794,11 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
   State state{State::WAITING};
 };
 
+template <class Cardinality>
 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
     EventBase* evb,
-    const SocketAddress& address) {
+    const SocketAddress& address,
+    Cardinality cardinality) {
   // Set up SSL context.
   auto sslContext = std::make_shared<SSLContext>();
 
@@ -1806,6 +1808,7 @@ MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
   socket->enableTFO();
 
   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .Times(cardinality)
       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
         sockaddr_storage addr;
         auto len = address.getAddress(&addr);
@@ -1824,7 +1827,7 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
 
   EventBase evb;
 
-  auto socket = setupSocketWithFallback(&evb, server.getAddress());
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
   ConnCallback ccb;
   socket->connect(&ccb, server.getAddress(), 30);
 
@@ -1852,10 +1855,7 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
 
 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
   // Start listening on a local port
-  WriteCallbackBase writeCallback;
-  ReadErrorCallback readCallback(&writeCallback);
-  HandshakeCallback handshakeCallback(&readCallback);
-  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  ConnectTimeoutCallback acceptCallback;
   TestSSLServer server(&acceptCallback, true);
 
   // Set up SSL context.
@@ -1871,15 +1871,12 @@ TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
 
 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
   // Start listening on a local port
-  WriteCallbackBase writeCallback;
-  ReadErrorCallback readCallback(&writeCallback);
-  HandshakeCallback handshakeCallback(&readCallback);
-  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  ConnectTimeoutCallback acceptCallback;
   TestSSLServer server(&acceptCallback, true);
 
   EventBase evb;
 
-  auto socket = setupSocketWithFallback(&evb, server.getAddress());
+  auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
   ConnCallback ccb;
   // Set a short timeout
   socket->connect(&ccb, server.getAddress(), 1);