fix flaky ConnectTFOTimeout and ConnectTFOFallbackTimeout tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index 5c0fe43704543d46a7c9da598bb13b2df4f8ebbc..42bb03ac135a7f7dfe71fda0be60a7884e02f6c0 100644 (file)
 #include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/EventBase.h>
 #include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/portability/Sockets.h>
 #include <folly/portability/Unistd.h>
 
 #include <gtest/gtest.h>
 #include <iostream>
 #include <list>
 #include <fcntl.h>
-#include <poll.h>
 #include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
 
 namespace folly {
 
@@ -99,7 +97,7 @@ public AsyncTransportWrapper::ReadCallback {
       : wcb_(wcb), state(STATE_WAITING) {}
 
   ~ReadCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void setSocket(
@@ -354,7 +352,7 @@ public:
   }
 
   ~HandshakeCallback() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void closeSocket() {
@@ -382,7 +380,7 @@ public:
   state(STATE_WAITING), hcb_(hcb) {}
 
   ~SSLServerAcceptCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void acceptError(const std::exception& ex) noexcept override {
@@ -589,6 +587,25 @@ public:
   }
 };
 
+class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
+ public:
+  ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
+    // We don't care if we get invoked or not.
+    // The client may time out and give up before connAccepted() is even
+    // called.
+    state = STATE_SUCCEEDED;
+  }
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
+    std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
+
+    // Just wait a while before closing the socket, so the client
+    // will time out waiting for the handshake to complete.
+    s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
+  }
+};
 
 class TestSSLServer {
  protected:
@@ -609,7 +626,9 @@ class TestSSLServer {
  public:
   // Create a TestSSLServer.
   // This immediately starts listening on the given port.
-  explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+  explicit TestSSLServer(
+      SSLServerAcceptCallbackBase* acb,
+      bool enableTFO = false);
 
   // Kill the thread.
   ~TestSSLServer() {