fix flaky ConnectTFOTimeout and ConnectTFOFallbackTimeout tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index 78623f77c92433814e0eb61645043b658798f874..42bb03ac135a7f7dfe71fda0be60a7884e02f6c0 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2016 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #include <signal.h>
 #include <pthread.h>
 
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
 #include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
 
 #include <gtest/gtest.h>
 #include <iostream>
 #include <list>
-#include <unistd.h>
 #include <fcntl.h>
-#include <poll.h>
 #include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
 
 namespace folly {
 
@@ -58,7 +58,7 @@ public:
       , exception(AsyncSocketException::UNKNOWN, "none") {}
 
   ~WriteCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void setSocket(
@@ -92,13 +92,12 @@ public:
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {
-public:
-  explicit ReadCallbackBase(WriteCallbackBase *wcb)
-      : wcb_(wcb)
-      , state(STATE_WAITING) {}
+ public:
+  explicit ReadCallbackBase(WriteCallbackBase* wcb)
+      : wcb_(wcb), state(STATE_WAITING) {}
 
   ~ReadCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void setSocket(
@@ -209,7 +208,7 @@ public:
     *lenReturn = 0;
   }
 
-  void readDataAvailable(size_t len) noexcept override {
+  void readDataAvailable(size_t /* len */) noexcept override {
     // This should never to called.
     FAIL();
   }
@@ -222,6 +221,27 @@ public:
   }
 };
 
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+  explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+  // Return nullptr buffer to trigger readError()
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = nullptr;
+    *lenReturn = 0;
+  }
+
+  void readDataAvailable(size_t /* len */) noexcept override {
+    // This should never to called.
+    FAIL();
+  }
+
+  void readEOF() noexcept override {
+    ReadCallbackBase::readEOF();
+    setState(STATE_SUCCEEDED);
+  }
+};
+
 class WriteErrorCallback : public ReadCallback {
 public:
   explicit WriteErrorCallback(WriteCallbackBase *wcb)
@@ -305,25 +325,34 @@ public:
 
   // Functions inherited from AsyncSSLSocketHandshakeCallback
   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+    std::lock_guard<std::mutex> g(mutex_);
+    cv_.notify_all();
     EXPECT_EQ(sock, socket_.get());
     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
     rcb_->setSocket(socket_);
     sock->setReadCB(rcb_);
     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
   }
-  void handshakeErr(
-    AsyncSSLSocket *sock,
-    const AsyncSocketException& ex) noexcept override {
+  void handshakeErr(AsyncSSLSocket* /* sock */,
+                    const AsyncSocketException& ex) noexcept override {
+    std::lock_guard<std::mutex> g(mutex_);
+    cv_.notify_all();
     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
     if (expect_ == EXPECT_ERROR) {
       // rcb will never be invoked
       rcb_->setState(STATE_SUCCEEDED);
     }
+    errorString_ = ex.what();
+  }
+
+  void waitForHandshake() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return state != STATE_WAITING; });
   }
 
   ~HandshakeCallback() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void closeSocket() {
@@ -331,10 +360,17 @@ public:
     state = STATE_SUCCEEDED;
   }
 
+  std::shared_ptr<AsyncSSLSocket> getSocket() {
+    return socket_;
+  }
+
   StateEnum state;
   std::shared_ptr<AsyncSSLSocket> socket_;
   ReadCallbackBase *rcb_;
   ExpectType expect_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  std::string errorString_;
 };
 
 class SSLServerAcceptCallbackBase:
@@ -344,7 +380,7 @@ public:
   state(STATE_WAITING), hcb_(hcb) {}
 
   ~SSLServerAcceptCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void acceptError(const std::exception& ex) noexcept override {
@@ -353,8 +389,8 @@ public:
     state = STATE_FAILED;
   }
 
-  void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
-    noexcept override{
+  void connectionAccepted(
+      int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
     printf("Connection accepted\n");
     std::shared_ptr<AsyncSSLSocket> sslSock;
     try {
@@ -551,13 +587,32 @@ public:
   }
 };
 
+class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
+ public:
+  ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
+    // We don't care if we get invoked or not.
+    // The client may time out and give up before connAccepted() is even
+    // called.
+    state = STATE_SUCCEEDED;
+  }
+
+  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+  void connAccepted(
+      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
+    std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
+
+    // Just wait a while before closing the socket, so the client
+    // will time out waiting for the handshake to complete.
+    s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
+  }
+};
 
 class TestSSLServer {
  protected:
   EventBase evb_;
   std::shared_ptr<folly::SSLContext> ctx_;
   SSLServerAcceptCallbackBase *acb_;
-  folly::AsyncServerSocket *socket_;
+  std::shared_ptr<folly::AsyncServerSocket> socket_;
   folly::SocketAddress address_;
   pthread_t thread_;
 
@@ -571,7 +626,9 @@ class TestSSLServer {
  public:
   // Create a TestSSLServer.
   // This immediately starts listening on the given port.
-  explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+  explicit TestSSLServer(
+      SSLServerAcceptCallbackBase* acb,
+      bool enableTFO = false);
 
   // Kill the thread.
   ~TestSSLServer() {
@@ -612,10 +669,10 @@ class TestSSLAsyncCacheServer : public TestSSLServer {
   static uint32_t asyncLookups_;
   static uint32_t lookupDelay_;
 
-  static SSL_SESSION *getSessionCallback(SSL *ssl,
-                                         unsigned char *sess_id,
-                                         int id_len,
-                                         int *copyflag) {
+  static SSL_SESSION* getSessionCallback(SSL* ssl,
+                                         unsigned char* /* sess_id */,
+                                         int /* id_len */,
+                                         intcopyflag) {
     *copyflag = 0;
     asyncCallbacks_++;
 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
@@ -804,10 +861,12 @@ class NpnClient :
 
   const unsigned char* nextProto;
   unsigned nextProtoLength;
+  SSLContext::NextProtocolType protocolType;
+
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
-    socket_->getSelectedNextProtocol(&nextProto,
-                                     &nextProtoLength);
+    socket_->getSelectedNextProtocol(
+        &nextProto, &nextProtoLength, &protocolType);
   }
   void handshakeErr(
     AsyncSSLSocket*,
@@ -838,21 +897,22 @@ class NpnServer :
 
   const unsigned char* nextProto;
   unsigned nextProtoLength;
+  SSLContext::NextProtocolType protocolType;
+
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
-    socket_->getSelectedNextProtocol(&nextProto,
-                                     &nextProtoLength);
+    socket_->getSelectedNextProtocol(
+        &nextProto, &nextProtoLength, &protocolType);
   }
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
     ADD_FAILURE() << "server handshake error: " << ex.what();
   }
-  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+  void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
     *lenReturn = 0;
   }
-  void readDataAvailable(size_t len) noexcept override {
-  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
   void readEOF() noexcept override {
     socket_->close();
   }
@@ -864,6 +924,48 @@ class NpnServer :
   AsyncSSLSocket::UniquePtr socket_;
 };
 
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+                            public AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+      : socket_(std::move(socket)) {
+    socket_->sslAccept(this);
+  }
+
+  ~RenegotiatingServer() {
+    socket_->setReadCB(nullptr);
+  }
+
+  void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+    LOG(INFO) << "Renegotiating server handshake success";
+    socket_->setReadCB(this);
+  }
+  void handshakeErr(
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *lenReturn = sizeof(buf);
+    *bufReturn = buf;
+  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
+  void readEOF() noexcept override {}
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "server got read error " << ex.what();
+    auto exPtr = dynamic_cast<const SSLException*>(&ex);
+    ASSERT_NE(nullptr, exPtr);
+    std::string exStr(ex.what());
+    SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+    ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+    renegotiationError_ = true;
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  unsigned char buf[128];
+  bool renegotiationError_{false};
+};
+
 #ifndef OPENSSL_NO_TLSEXT
 class SNIClient :
   private AsyncSSLSocket::HandshakeCB,
@@ -918,17 +1020,16 @@ class SNIServer :
   bool serverNameMatch;
 
  private:
-  void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
+  void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
     ADD_FAILURE() << "server handshake error: " << ex.what();
   }
-  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+  void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
     *lenReturn = 0;
   }
-  void readDataAvailable(size_t len) noexcept override {
-  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
   void readEOF() noexcept override {
     socket_->close();
   }
@@ -980,10 +1081,18 @@ class SSLClient : public AsyncSocket::ConnectCallback,
   uint32_t errors_;
   uint32_t writeAfterConnectErrors_;
 
+  // These settings test that we eventually drain the
+  // socket, even if the maxReadsPerEvent_ is hit during
+  // a event loop iteration.
+  static constexpr size_t kMaxReadsPerEvent = 2;
+  static constexpr size_t kMaxReadBufferSz =
+    sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
+
  public:
   SSLClient(EventBase *eventBase,
             const folly::SocketAddress& address,
-            uint32_t requests, uint32_t timeout = 0)
+            uint32_t requests,
+            uint32_t timeout = 0)
       : eventBase_(eventBase),
         session_(nullptr),
         requests_(requests),
@@ -1046,6 +1155,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
     }
 
     // write()
+    sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
     sslSocket_->write(this, buf_, sizeof(buf_));
     sslSocket_->setReadCB(this);
     memset(readbuf_, 'b', sizeof(readbuf_));
@@ -1063,10 +1173,8 @@ class SSLClient : public AsyncSocket::ConnectCallback,
     std::cerr << "client write success" << std::endl;
   }
 
-  void writeErr(
-    size_t bytesWritten,
-    const AsyncSocketException& ex)
-    noexcept override {
+  void writeErr(size_t /* bytesWritten */,
+                const AsyncSocketException& ex) noexcept override {
     std::cerr << "client writeError: " << ex.what() << std::endl;
     if (!sslSocket_) {
       writeAfterConnectErrors_++;
@@ -1075,7 +1183,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
 
   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
     *bufReturn = readbuf_ + bytesRead_;
-    *lenReturn = sizeof(readbuf_) - bytesRead_;
+    *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
   }
 
   void readEOF() noexcept override {
@@ -1090,7 +1198,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
   void readDataAvailable(size_t len) noexcept override {
     std::cerr << "client read data: " << len << std::endl;
     bytesRead_ += len;
-    if (len == sizeof(buf_)) {
+    if (bytesRead_ == sizeof(buf_)) {
       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
       sslSocket_->closeNow();
       sslSocket_.reset();
@@ -1118,9 +1226,14 @@ class SSLHandshakeBase :
     verifyResult_(verifyResult) {
   }
 
+  AsyncSSLSocket::UniquePtr moveSocket() && {
+    return std::move(socket_);
+  }
+
   bool handshakeVerify_;
   bool handshakeSuccess_;
   bool handshakeError_;
+  std::chrono::nanoseconds handshakeTime;
 
  protected:
   AsyncSSLSocket::UniquePtr socket_;
@@ -1128,10 +1241,9 @@ class SSLHandshakeBase :
   bool verifyResult_;
 
   // HandshakeCallback
-  bool handshakeVer(
-   AsyncSSLSocket* sock,
-   bool preverifyOk,
-   X509_STORE_CTX* ctx) noexcept override {
+  bool handshakeVer(AsyncSSLSocket* /* sock */,
+                    bool preverifyOk,
+                    X509_STORE_CTX* /* ctx */) noexcept override {
     handshakeVerify_ = true;
 
     EXPECT_EQ(preverifyResult_, preverifyOk);
@@ -1139,13 +1251,17 @@ class SSLHandshakeBase :
   }
 
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    LOG(INFO) << "Handshake success";
     handshakeSuccess_ = true;
+    handshakeTime = socket_->getHandshakeTime();
   }
 
   void handshakeErr(
-   AsyncSSLSocket*,
-   const AsyncSocketException& ex) noexcept override {
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "Handshake error " << ex.what();
     handshakeError_ = true;
+    handshakeTime = socket_->getHandshakeTime();
   }
 
   // WriteCallback