Adds writer test case for RCU
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index 683b2a6507acfa25a1a9c73cb8b5ad3cf013e490..7a05fbbc8ea5d995dd0751add4d059643fccdaa4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2012-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #pragma once
 
 #include <signal.h>
-#include <pthread.h>
 
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
+#include <folly/experimental/TestUtil.h>
 #include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
 #include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/io/async/test/TestSSLServer.h>
+#include <folly/portability/GTest.h>
+#include <folly/portability/PThread.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
 
-#include <gtest/gtest.h>
-#include <iostream>
-#include <list>
-#include <unistd.h>
 #include <fcntl.h>
-#include <poll.h>
 #include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
+#include <condition_variable>
+#include <iostream>
+#include <list>
+#include <memory>
 
 namespace folly {
 
-enum StateEnum {
-  STATE_WAITING,
-  STATE_SUCCEEDED,
-  STATE_FAILED
-};
-
 // The destructors of all callback classes assert that the state is
 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
 // are responsible for setting the succeeded state properly before the
 // destructors are called.
 
+class SendMsgParamsCallbackBase :
+      public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+  SendMsgParamsCallbackBase() {}
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) {
+    socket_ = socket;
+    oldCallback_ = socket_->getSendMsgParamsCB();
+    socket_->setSendMsgParamCB(this);
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                     override {
+    return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    oldCallback_->getAncillaryData(flags, data);
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    return oldCallback_->getAncillaryDataSize(flags);
+  }
+
+  std::shared_ptr<AsyncSSLSocket> socket_;
+  folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
+};
+
+class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
+ public:
+  SendMsgFlagsCallback() {}
+
+  void resetFlags(int flags) {
+    flags_ = flags;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    if (flags_) {
+      return flags_;
+    } else {
+      return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
+    }
+  }
+
+  int flags_{0};
+};
+
+class SendMsgDataCallback : public SendMsgFlagsCallback {
+ public:
+  SendMsgDataCallback() {}
+
+  void resetData(std::vector<char>&& data) {
+    ancillaryData_.swap(data);
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    if (ancillaryData_.size()) {
+      std::cerr << "getAncillaryData: copying data" << std::endl;
+      memcpy(data, ancillaryData_.data(), ancillaryData_.size());
+    } else {
+      oldCallback_->getAncillaryData(flags, data);
+    }
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (ancillaryData_.size()) {
+      std::cerr << "getAncillaryDataSize: returning size" << std::endl;
+      return ancillaryData_.size();
+    } else {
+      return oldCallback_->getAncillaryDataSize(flags);
+    }
+  }
+
+  std::vector<char> ancillaryData_;
+};
+
 class WriteCallbackBase :
 public AsyncTransportWrapper::WriteCallback {
-public:
-  WriteCallbackBase()
+ public:
+  explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
       : state(STATE_WAITING)
       , bytesWritten(0)
-      , exception(AsyncSocketException::UNKNOWN, "none") {}
+      , exception(AsyncSocketException::UNKNOWN, "none")
+      , mcb_(mcb) {}
 
-  ~WriteCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+  ~WriteCallbackBase() override {
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
-  void setSocket(
+  virtual void setSocket(
     const std::shared_ptr<AsyncSSLSocket> &socket) {
     socket_ = socket;
+    if (mcb_) {
+      mcb_->setSocket(socket);
+    }
   }
 
   void writeSuccess() noexcept override {
@@ -72,33 +152,140 @@ public:
   }
 
   void writeErr(
-    size_t bytesWritten,
+    size_t nBytesWritten,
     const AsyncSocketException& ex) noexcept override {
-    std::cerr << "writeError: bytesWritten " << bytesWritten
+    std::cerr << "writeError: bytesWritten " << nBytesWritten
          << ", exception " << ex.what() << std::endl;
 
     state = STATE_FAILED;
-    this->bytesWritten = bytesWritten;
+    this->bytesWritten = nBytesWritten;
     exception = ex;
     socket_->close();
-    socket_->detachEventBase();
   }
 
   std::shared_ptr<AsyncSSLSocket> socket_;
   StateEnum state;
   size_t bytesWritten;
   AsyncSocketException exception;
+  SendMsgParamsCallbackBase* mcb_;
+};
+
+class ExpectWriteErrorCallback :
+public WriteCallbackBase {
+ public:
+  explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+      : WriteCallbackBase(mcb) {}
+
+  ~ExpectWriteErrorCallback() override {
+    EXPECT_EQ(STATE_FAILED, state);
+    EXPECT_EQ(exception.type_,
+             AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
+    EXPECT_EQ(exception.errno_, 22);
+    // Suppress the assert in  ~WriteCallbackBase()
+    state = STATE_SUCCEEDED;
+  }
+};
+
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
+/* copied from include/uapi/linux/net_tstamp.h */
+/* SO_TIMESTAMPING gets an integer bit field comprised of these values */
+enum SOF_TIMESTAMPING {
+  SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
+  SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
+  SOF_TIMESTAMPING_OPT_ID = (1 << 7),
+  SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
+  SOF_TIMESTAMPING_TX_ACK = (1 << 9),
+  SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
+};
+
+class WriteCheckTimestampCallback :
+  public WriteCallbackBase {
+ public:
+  explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
+    : WriteCallbackBase(mcb) {}
+
+  ~WriteCheckTimestampCallback() override {
+    EXPECT_EQ(STATE_SUCCEEDED, state);
+    EXPECT_TRUE(gotTimestamp_);
+    EXPECT_TRUE(gotByteSeq_);
+  }
+
+  void setSocket(
+    const std::shared_ptr<AsyncSSLSocket> &socket) override {
+    WriteCallbackBase::setSocket(socket);
+
+    EXPECT_NE(socket_->getFd(), 0);
+    int flags = SOF_TIMESTAMPING_OPT_ID
+                | SOF_TIMESTAMPING_OPT_TSONLY
+                | SOF_TIMESTAMPING_SOFTWARE;
+    AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
+    int ret = tstampingOpt.apply(socket_->getFd(), flags);
+    EXPECT_EQ(ret, 0);
+  }
+
+  void checkForTimestampNotifications() noexcept {
+    int fd = socket_->getFd();
+    std::vector<char> ctrl(1024, 0);
+    unsigned char data;
+    struct msghdr msg;
+    iovec entry;
+
+    memset(&msg, 0, sizeof(msg));
+    entry.iov_base = &data;
+    entry.iov_len = sizeof(data);
+    msg.msg_iov = &entry;
+    msg.msg_iovlen = 1;
+    msg.msg_control = ctrl.data();
+    msg.msg_controllen = ctrl.size();
+
+    int ret;
+    while (true) {
+      ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
+      if (ret < 0) {
+        if (errno != EAGAIN) {
+          auto errnoCopy = errno;
+          std::cerr << "::recvmsg exited with code " << ret
+                    << ", errno: " << errnoCopy << std::endl;
+          AsyncSocketException ex(
+            AsyncSocketException::INTERNAL_ERROR,
+            "recvmsg() failed",
+            errnoCopy);
+          exception = ex;
+        }
+        return;
+      }
+
+      for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+           cmsg != nullptr && cmsg->cmsg_len != 0;
+           cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+        if (cmsg->cmsg_level == SOL_SOCKET &&
+            cmsg->cmsg_type == SCM_TIMESTAMPING) {
+          gotTimestamp_ = true;
+          continue;
+        }
+
+        if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
+            (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
+          gotByteSeq_ = true;
+          continue;
+        }
+      }
+    }
+  }
+
+  bool gotTimestamp_{false};
+  bool gotByteSeq_{false};
 };
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {
-public:
-  explicit ReadCallbackBase(WriteCallbackBase *wcb)
-      : wcb_(wcb)
-      , state(STATE_WAITING) {}
+ public:
+  explicit ReadCallbackBase(WriteCallbackBase* wcb)
+      : wcb_(wcb), state(STATE_WAITING) {}
 
-  ~ReadCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+  ~ReadCallbackBase() override {
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void setSocket(
@@ -118,14 +305,12 @@ public:
     std::cerr << "readError " << ex.what() << std::endl;
     state = STATE_FAILED;
     socket_->close();
-    socket_->detachEventBase();
   }
 
   void readEOF() noexcept override {
     std::cerr << "readEOF" << std::endl;
 
     socket_->close();
-    socket_->detachEventBase();
   }
 
   std::shared_ptr<AsyncSSLSocket> socket_;
@@ -134,12 +319,12 @@ public:
 };
 
 class ReadCallback : public ReadCallbackBase {
-public:
+ public:
   explicit ReadCallback(WriteCallbackBase *wcb)
       : ReadCallbackBase(wcb)
       , buffers() {}
 
-  ~ReadCallback() {
+  ~ReadCallback() override {
     for (std::vector<Buffer>::iterator it = buffers.begin();
          it != buffers.end();
          ++it) {
@@ -172,7 +357,7 @@ public:
   }
 
   class Buffer {
-  public:
+   public:
     Buffer() : buffer(nullptr), length(0) {}
     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
 
@@ -180,10 +365,10 @@ public:
       buffer = nullptr;
       length = 0;
     }
-    void allocate(size_t length) {
+    void allocate(size_t len) {
       assert(buffer == nullptr);
-      this->buffer = static_cast<char*>(malloc(length));
-      this->length = length;
+      this->buffer = static_cast<char*>(malloc(len));
+      this->length = len;
     }
     void free() {
       ::free(buffer);
@@ -199,7 +384,7 @@ public:
 };
 
 class ReadErrorCallback : public ReadCallbackBase {
-public:
+ public:
   explicit ReadErrorCallback(WriteCallbackBase *wcb)
       : ReadCallbackBase(wcb) {}
 
@@ -209,7 +394,7 @@ public:
     *lenReturn = 0;
   }
 
-  void readDataAvailable(size_t len) noexcept override {
+  void readDataAvailable(size_t /* len */) noexcept override {
     // This should never to called.
     FAIL();
   }
@@ -222,8 +407,29 @@ public:
   }
 };
 
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+  explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+  // Return nullptr buffer to trigger readError()
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = nullptr;
+    *lenReturn = 0;
+  }
+
+  void readDataAvailable(size_t /* len */) noexcept override {
+    // This should never to called.
+    FAIL();
+  }
+
+  void readEOF() noexcept override {
+    ReadCallbackBase::readEOF();
+    setState(STATE_SUCCEEDED);
+  }
+};
+
 class WriteErrorCallback : public ReadCallback {
-public:
+ public:
   explicit WriteErrorCallback(WriteCallbackBase *wcb)
       : ReadCallback(wcb) {}
 
@@ -238,7 +444,9 @@ public:
     wcb_->setSocket(socket_);
 
     // Write back the same data.
-    socket_->write(wcb_, currentBuffer.buffer, len);
+    folly::test::msvcSuppressAbortOnInvalidParams([&] {
+      socket_->write(wcb_, currentBuffer.buffer, len);
+    });
 
     if (wcb_->state == STATE_FAILED) {
       setState(STATE_SUCCEEDED);
@@ -257,22 +465,23 @@ public:
 };
 
 class EmptyReadCallback : public ReadCallback {
-public:
+ public:
   explicit EmptyReadCallback()
       : ReadCallback(nullptr) {}
 
   void readErr(const AsyncSocketException& ex) noexcept override {
     std::cerr << "readError " << ex.what() << std::endl;
     state = STATE_FAILED;
-    tcpSocket_->close();
-    tcpSocket_->detachEventBase();
+    if (tcpSocket_) {
+      tcpSocket_->close();
+    }
   }
 
   void readEOF() noexcept override {
     std::cerr << "readEOF" << std::endl;
-
-    tcpSocket_->close();
-    tcpSocket_->detachEventBase();
+    if (tcpSocket_) {
+      tcpSocket_->close();
+    }
     state = STATE_SUCCEEDED;
   }
 
@@ -281,7 +490,7 @@ public:
 
 class HandshakeCallback :
 public AsyncSSLSocket::HandshakeCB {
-public:
+ public:
   enum ExpectType {
     EXPECT_SUCCESS,
     EXPECT_ERROR
@@ -305,25 +514,34 @@ public:
 
   // Functions inherited from AsyncSSLSocketHandshakeCallback
   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+    std::lock_guard<std::mutex> g(mutex_);
+    cv_.notify_all();
     EXPECT_EQ(sock, socket_.get());
     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
     rcb_->setSocket(socket_);
     sock->setReadCB(rcb_);
     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
   }
-  void handshakeErr(
-    AsyncSSLSocket *sock,
-    const AsyncSocketException& ex) noexcept override {
+  void handshakeErr(AsyncSSLSocket* /* sock */,
+                    const AsyncSocketException& ex) noexcept override {
+    std::lock_guard<std::mutex> g(mutex_);
+    cv_.notify_all();
     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
     if (expect_ == EXPECT_ERROR) {
       // rcb will never be invoked
       rcb_->setState(STATE_SUCCEEDED);
     }
+    errorString_ = ex.what();
+  }
+
+  void waitForHandshake() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return state != STATE_WAITING; });
   }
 
-  ~HandshakeCallback() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+  ~HandshakeCallback() override {
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void closeSocket() {
@@ -331,58 +549,21 @@ public:
     state = STATE_SUCCEEDED;
   }
 
+  std::shared_ptr<AsyncSSLSocket> getSocket() {
+    return socket_;
+  }
+
   StateEnum state;
   std::shared_ptr<AsyncSSLSocket> socket_;
   ReadCallbackBase *rcb_;
   ExpectType expect_;
-};
-
-class SSLServerAcceptCallbackBase:
-public folly::AsyncServerSocket::AcceptCallback {
-public:
-  explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
-  state(STATE_WAITING), hcb_(hcb) {}
-
-  ~SSLServerAcceptCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
-  }
-
-  void acceptError(const std::exception& ex) noexcept override {
-    std::cerr << "SSLServerAcceptCallbackBase::acceptError "
-              << ex.what() << std::endl;
-    state = STATE_FAILED;
-  }
-
-  void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
-    noexcept override{
-    printf("Connection accepted\n");
-    std::shared_ptr<AsyncSSLSocket> sslSock;
-    try {
-      // Create a AsyncSSLSocket object with the fd. The socket should be
-      // added to the event base and in the state of accepting SSL connection.
-      sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
-    } catch (const std::exception &e) {
-      LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
-        "object with socket " << e.what() << fd;
-      ::close(fd);
-      acceptError(e);
-      return;
-    }
-
-    connAccepted(sslSock);
-  }
-
-  virtual void connAccepted(
-    const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
-
-  StateEnum state;
-  HandshakeCallback *hcb_;
-  std::shared_ptr<folly::SSLContext> ctx_;
-  folly::EventBase* base_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  std::string errorString_;
 };
 
 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
   uint32_t timeout_;
 
   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
@@ -390,7 +571,7 @@ public:
       SSLServerAcceptCallbackBase(hcb),
       timeout_(timeout) {}
 
-  virtual ~SSLServerAcceptCallback() {
+  ~SSLServerAcceptCallback() override {
     if (timeout_ > 0) {
       // if we set a timeout, we expect failure
       EXPECT_EQ(hcb_->state, STATE_FAILED);
@@ -398,7 +579,6 @@ public:
     }
   }
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -406,7 +586,7 @@ public:
     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
 
     hcb_->setSocket(sock);
-    sock->sslAccept(hcb_, timeout_);
+    sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
     EXPECT_EQ(sock->getSSLState(),
                       AsyncSSLSocket::STATE_ACCEPTING);
 
@@ -415,11 +595,10 @@ public:
 };
 
 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
-public:
+ public:
   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
       SSLServerAcceptCallback(hcb) {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -456,12 +635,11 @@ public:
 };
 
 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
-public:
+ public:
   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
                                              uint32_t timeout = 0):
     SSLServerAcceptCallback(hcb, timeout) {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -470,7 +648,7 @@ public:
     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
 
     hcb_->setSocket(sock);
-    sock->sslAccept(hcb_, timeout_);
+    sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
     ASSERT_TRUE((sock->getSSLState() ==
                  AsyncSSLSocket::STATE_ACCEPTING) ||
                 (sock->getSSLState() ==
@@ -482,11 +660,10 @@ public:
 
 
 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
   SSLServerAcceptCallbackBase(hcb)  {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -511,8 +688,6 @@ public:
     EXPECT_EQ(hcb_->state, STATE_FAILED);
     EXPECT_EQ(callback2.state, STATE_FAILED);
 
-    sock->detachEventBase();
-
     state = STATE_SUCCEEDED;
     hcb_->setState(STATE_SUCCEEDED);
     callback2.setState(STATE_SUCCEEDED);
@@ -520,11 +695,10 @@ public:
 };
 
 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
   SSLServerAcceptCallbackBase(hcb)  {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -551,41 +725,22 @@ public:
   }
 };
 
-
-class TestSSLServer {
- protected:
-  EventBase evb_;
-  std::shared_ptr<folly::SSLContext> ctx_;
-  SSLServerAcceptCallbackBase *acb_;
-  std::shared_ptr<folly::AsyncServerSocket> socket_;
-  folly::SocketAddress address_;
-  pthread_t thread_;
-
-  static void *Main(void *ctx) {
-    TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
-    self->evb_.loop();
-    std::cerr << "Server thread exited event loop" << std::endl;
-    return nullptr;
-  }
-
+class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
  public:
-  // Create a TestSSLServer.
-  // This immediately starts listening on the given port.
-  explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
-
-  // Kill the thread.
-  ~TestSSLServer() {
-    evb_.runInEventBaseThread([&](){
-      socket_->stopAccepting();
-    });
-    std::cerr << "Waiting for server thread to exit" << std::endl;
-    pthread_join(thread_, nullptr);
+  ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
+    // We don't care if we get invoked or not.
+    // The client may time out and give up before connAccepted() is even
+    // called.
+    state = STATE_SUCCEEDED;
   }
 
-  EventBase &getEventBase() { return evb_; }
+  void connAccepted(
+      const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
+    std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
 
-  const folly::SocketAddress& getAddress() const {
-    return address_;
+    // Just wait a while before closing the socket, so the client
+    // will time out waiting for the handshake to complete.
+    s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
   }
 };
 
@@ -595,8 +750,10 @@ class TestSSLAsyncCacheServer : public TestSSLServer {
         int lookupDelay = 100) :
       TestSSLServer(acb) {
     SSL_CTX *sslCtx = ctx_->getSSLCtx();
+#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
     SSL_CTX_sess_set_get_cb(sslCtx,
                             TestSSLAsyncCacheServer::getSessionCallback);
+#endif
     SSL_CTX_set_session_cache_mode(
       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
     asyncCallbacks_ = 0;
@@ -612,12 +769,13 @@ class TestSSLAsyncCacheServer : public TestSSLServer {
   static uint32_t asyncLookups_;
   static uint32_t lookupDelay_;
 
-  static SSL_SESSION *getSessionCallback(SSL *ssl,
-                                         unsigned char *sess_id,
-                                         int id_len,
-                                         int *copyflag) {
+  static SSL_SESSION* getSessionCallback(SSL* ssl,
+                                         unsigned char* /* sess_id */,
+                                         int /* id_len */,
+                                         intcopyflag) {
     *copyflag = 0;
     asyncCallbacks_++;
+    (void)ssl;
 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
     if (!SSL_want_sess_cache_lookup(ssl)) {
       // libssl.so mismatch
@@ -667,13 +825,13 @@ class BlockingWriteClient :
       bufLen_(2500),
       iovCount_(2000) {
     // Fill buf_
-    buf_.reset(new uint8_t[bufLen_]);
+    buf_ = std::make_unique<uint8_t[]>(bufLen_);
     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
       buf_[n] = n % 0xff;
     }
 
     // Initialize iov_
-    iov_.reset(new struct iovec[iovCount_]);
+    iov_ = std::make_unique<struct iovec[]>(iovCount_);
     for (uint32_t n = 0; n < iovCount_; ++n) {
       iov_[n].iov_base = buf_.get() + n;
       if (n & 0x1) {
@@ -683,7 +841,7 @@ class BlockingWriteClient :
       }
     }
 
-    socket_->sslConn(this, 100);
+    socket_->sslConn(this, std::chrono::milliseconds(100));
   }
 
   struct iovec* getIovec() const {
@@ -728,8 +886,8 @@ class BlockingWriteServer :
     : socket_(std::move(socket)),
       bufSize_(2500 * 2000),
       bytesRead_(0) {
-    buf_.reset(new uint8_t[bufSize_]);
-    socket_->sslAccept(this, 100);
+    buf_ = std::make_unique<uint8_t[]>(bufSize_);
+    socket_->sslAccept(this, std::chrono::milliseconds(100));
   }
 
   void checkBuffer(struct iovec* iov, uint32_t count) const {
@@ -805,6 +963,7 @@ class NpnClient :
   const unsigned char* nextProto;
   unsigned nextProtoLength;
   SSLContext::NextProtocolType protocolType;
+  folly::Optional<AsyncSocketException> except;
 
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
@@ -814,7 +973,7 @@ class NpnClient :
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
-    ADD_FAILURE() << "client handshake error: " << ex.what();
+    except = ex;
   }
   void writeSuccess() noexcept override {
     socket_->close();
@@ -841,6 +1000,7 @@ class NpnServer :
   const unsigned char* nextProto;
   unsigned nextProtoLength;
   SSLContext::NextProtocolType protocolType;
+  folly::Optional<AsyncSocketException> except;
 
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
@@ -850,13 +1010,12 @@ class NpnServer :
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
-    ADD_FAILURE() << "server handshake error: " << ex.what();
+    except = ex;
   }
-  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+  void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
     *lenReturn = 0;
   }
-  void readDataAvailable(size_t len) noexcept override {
-  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
   void readEOF() noexcept override {
     socket_->close();
   }
@@ -868,6 +1027,48 @@ class NpnServer :
   AsyncSSLSocket::UniquePtr socket_;
 };
 
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+                            public AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+      : socket_(std::move(socket)) {
+    socket_->sslAccept(this);
+  }
+
+  ~RenegotiatingServer() override {
+    socket_->setReadCB(nullptr);
+  }
+
+  void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+    LOG(INFO) << "Renegotiating server handshake success";
+    socket_->setReadCB(this);
+  }
+  void handshakeErr(
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *lenReturn = sizeof(buf);
+    *bufReturn = buf;
+  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
+  void readEOF() noexcept override {}
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "server got read error " << ex.what();
+    auto exPtr = dynamic_cast<const SSLException*>(&ex);
+    ASSERT_NE(nullptr, exPtr);
+    std::string exStr(ex.what());
+    SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+    ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+    renegotiationError_ = true;
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  unsigned char buf[128];
+  bool renegotiationError_{false};
+};
+
 #ifndef OPENSSL_NO_TLSEXT
 class SNIClient :
   private AsyncSSLSocket::HandshakeCB,
@@ -922,17 +1123,16 @@ class SNIServer :
   bool serverNameMatch;
 
  private:
-  void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
+  void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
     ADD_FAILURE() << "server handshake error: " << ex.what();
   }
-  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+  void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
     *lenReturn = 0;
   }
-  void readDataAvailable(size_t len) noexcept override {
-  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
   void readEOF() noexcept override {
     socket_->close();
   }
@@ -988,8 +1188,9 @@ class SSLClient : public AsyncSocket::ConnectCallback,
   // socket, even if the maxReadsPerEvent_ is hit during
   // a event loop iteration.
   static constexpr size_t kMaxReadsPerEvent = 2;
+  // 2 event loop iterations
   static constexpr size_t kMaxReadBufferSz =
-    sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
+    sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
 
  public:
   SSLClient(EventBase *eventBase,
@@ -1012,7 +1213,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
     memset(buf_, 'a', sizeof(buf_));
   }
 
-  ~SSLClient() {
+  ~SSLClient() override {
     if (session_) {
       SSL_SESSION_free(session_);
     }
@@ -1076,10 +1277,8 @@ class SSLClient : public AsyncSocket::ConnectCallback,
     std::cerr << "client write success" << std::endl;
   }
 
-  void writeErr(
-    size_t bytesWritten,
-    const AsyncSocketException& ex)
-    noexcept override {
+  void writeErr(size_t /* bytesWritten */,
+                const AsyncSocketException& ex) noexcept override {
     std::cerr << "client writeError: " << ex.what() << std::endl;
     if (!sslSocket_) {
       writeAfterConnectErrors_++;
@@ -1131,6 +1330,10 @@ class SSLHandshakeBase :
     verifyResult_(verifyResult) {
   }
 
+  AsyncSSLSocket::UniquePtr moveSocket() && {
+    return std::move(socket_);
+  }
+
   bool handshakeVerify_;
   bool handshakeSuccess_;
   bool handshakeError_;
@@ -1142,10 +1345,9 @@ class SSLHandshakeBase :
   bool verifyResult_;
 
   // HandshakeCallback
-  bool handshakeVer(
-   AsyncSSLSocket* sock,
-   bool preverifyOk,
-   X509_STORE_CTX* ctx) noexcept override {
+  bool handshakeVer(AsyncSSLSocket* /* sock */,
+                    bool preverifyOk,
+                    X509_STORE_CTX* /* ctx */) noexcept override {
     handshakeVerify_ = true;
 
     EXPECT_EQ(preverifyResult_, preverifyOk);
@@ -1153,13 +1355,15 @@ class SSLHandshakeBase :
   }
 
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    LOG(INFO) << "Handshake success";
     handshakeSuccess_ = true;
     handshakeTime = socket_->getHandshakeTime();
   }
 
   void handshakeErr(
-   AsyncSSLSocket*,
-   const AsyncSocketException& ex) noexcept override {
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "Handshake error " << ex.what();
     handshakeError_ = true;
     handshakeTime = socket_->getHandshakeTime();
   }
@@ -1184,7 +1388,7 @@ class SSLHandshakeClient : public SSLHandshakeBase {
    bool preverifyResult,
    bool verifyResult) :
     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslConn(this, 0);
+    socket_->sslConn(this, std::chrono::milliseconds::zero());
   }
 };
 
@@ -1195,8 +1399,10 @@ class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
    bool preverifyResult,
    bool verifyResult) :
     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslConn(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+    socket_->sslConn(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   }
 };
 
@@ -1207,8 +1413,10 @@ class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
    bool preverifyResult,
    bool verifyResult) :
     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslConn(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
+    socket_->sslConn(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
   }
 };
 
@@ -1219,7 +1427,7 @@ class SSLHandshakeServer : public SSLHandshakeBase {
       bool preverifyResult,
       bool verifyResult)
     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslAccept(this, 0);
+    socket_->sslAccept(this, std::chrono::milliseconds::zero());
   }
 };
 
@@ -1231,7 +1439,7 @@ class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
       bool verifyResult)
       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
     socket_->enableClientHelloParsing();
-    socket_->sslAccept(this, 0);
+    socket_->sslAccept(this, std::chrono::milliseconds::zero());
   }
 
   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
@@ -1254,8 +1462,10 @@ class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
       bool preverifyResult,
       bool verifyResult)
     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslAccept(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+    socket_->sslAccept(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   }
 };
 
@@ -1266,8 +1476,10 @@ class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
       bool preverifyResult,
       bool verifyResult)
     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslAccept(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+    socket_->sslAccept(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
   }
 };
 
@@ -1290,4 +1502,4 @@ class EventBaseAborter : public AsyncTimeout {
   EventBase* eventBase_;
 };
 
-}
+} // namespace folly