Modernize use of std::make_unique
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index 8ce1644580d2ce2e3c0dfdc109c7219fff754609..4bb8a9b43525a37d0f80d95d20b4f357f365003b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #include <boost/noncopyable.hpp>
 #include <errno.h>
 #include <fcntl.h>
-#include <openssl/err.h>
-#include <openssl/asn1.h>
-#include <openssl/ssl.h>
 #include <sys/types.h>
 #include <chrono>
+#include <memory>
 
 #include <folly/Bits.h>
+#include <folly/Format.h>
 #include <folly/SocketAddress.h>
 #include <folly/SpinLock.h>
-#include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
-#include <folly/portability/Unistd.h>
+#include <folly/io/IOBuf.h>
+#include <folly/portability/OpenSSL.h>
 
 using folly::SocketAddress;
 using folly::SSLContext;
@@ -54,8 +53,11 @@ using folly::AsyncSocketException;
 using folly::AsyncSSLSocket;
 using folly::Optional;
 using folly::SSLContext;
+// For OpenSSL portability API
+using namespace folly::ssl;
 using folly::ssl::OpenSSLUtils;
 
+
 // We have one single dummy SSL context so that we can implement attach
 // and detach methods in a thread safe fashion without modifying opnessl.
 static SSLContext *dummyCtx = nullptr;
@@ -103,14 +105,15 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
 
       timeoutLeft = timeout_ - (curTime - startTime_);
       if (timeoutLeft <= 0) {
-        AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
-                                "SSL connect timed out");
+        AsyncSocketException ex(
+            AsyncSocketException::TIMED_OUT,
+            folly::sformat("SSL connect timed out after {}ms", timeout_));
         fail(ex);
         delete this;
         return;
       }
     }
-    sslSocket_->sslConn(this, timeoutLeft);
+    sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft));
   }
 
   void connectErr(const AsyncSocketException& ex) noexcept override {
@@ -173,15 +176,22 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
 
 }
 
-BIO_METHOD sslWriteBioMethod;
+// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky
+// thing is because we will be setting this BIO_METHOD* inside BIOs owned by
+// various SSL objects which may get callbacks even during teardown. We may
+// eventually try to fix this
+static BIO_METHOD* getSSLBioMethod() {
+  static auto const instance = OpenSSLUtils::newSocketBioMethod().release();
+  return instance;
+}
 
-void* initsslWriteBioMethod(void) {
-  memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
+void* initsslBioMethod() {
+  auto sslBioMethod = getSSLBioMethod();
   // override the bwrite method for MSG_EOR support
-  OpenSSLUtils::setCustomBioWriteMethod(
-      &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
+  OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite);
+  OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead);
 
-  // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
+  // Note that the sslBioMethod.type and sslBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
   // then have specific handlings. The sslWriteBioWrite should be compatible
   // with the one in openssl.
@@ -191,7 +201,7 @@ void* initsslWriteBioMethod(void) {
   return nullptr;
 }
 
-} // anonymous namespace
+} // namespace
 
 namespace folly {
 
@@ -213,14 +223,39 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
 /**
  * Create a server/client AsyncSSLSocket
  */
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
-                               EventBase* evb, int fd, bool server,
-                               bool deferSecurityNegotiation) :
-    AsyncSocket(evb, fd),
-    server_(server),
-    ctx_(ctx),
-    handshakeTimeout_(this, evb),
-    connectionTimeout_(this, evb) {
+AsyncSSLSocket::AsyncSSLSocket(
+    const shared_ptr<SSLContext>& ctx,
+    EventBase* evb,
+    int fd,
+    bool server,
+    bool deferSecurityNegotiation)
+    : AsyncSocket(evb, fd),
+      server_(server),
+      ctx_(ctx),
+      handshakeTimeout_(this, evb),
+      connectionTimeout_(this, evb) {
+  noTransparentTls_ = true;
+  init();
+  if (server) {
+    SSL_CTX_set_info_callback(
+        ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
+  }
+  if (deferSecurityNegotiation) {
+    sslState_ = STATE_UNENCRYPTED;
+  }
+}
+
+AsyncSSLSocket::AsyncSSLSocket(
+    const shared_ptr<SSLContext>& ctx,
+    AsyncSocket::UniquePtr oldAsyncSocket,
+    bool server,
+    bool deferSecurityNegotiation)
+    : AsyncSocket(std::move(oldAsyncSocket)),
+      server_(server),
+      ctx_(ctx),
+      handshakeTimeout_(this, AsyncSocket::getEventBase()),
+      connectionTimeout_(this, AsyncSocket::getEventBase()) {
+  noTransparentTls_ = true;
   init();
   if (server) {
     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
@@ -231,7 +266,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
   }
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 /**
  * Create a client AsyncSSLSocket and allow tlsext_hostname
  * to be sent in Client Hello.
@@ -248,14 +283,16 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
  * Create a client AsyncSSLSocket from an already connected fd
  * and allow tlsext_hostname to be sent in Client Hello.
  */
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
-                                 EventBase* evb, int fd,
-                               const std::string& serverName,
-                               bool deferSecurityNegotiation) :
-    AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
+AsyncSSLSocket::AsyncSSLSocket(
+    const shared_ptr<SSLContext>& ctx,
+    EventBase* evb,
+    int fd,
+    const std::string& serverName,
+    bool deferSecurityNegotiation)
+    : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
   tlsextHostname_ = serverName;
 }
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
 AsyncSSLSocket::~AsyncSSLSocket() {
   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
@@ -267,8 +304,8 @@ AsyncSSLSocket::~AsyncSSLSocket() {
 void AsyncSSLSocket::init() {
   // Do this here to ensure we initialize this once before any use of
   // AsyncSSLSocket instances and not as part of library load.
-  static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
-  (void)sslWriteBioMethodInitializer;
+  static const auto sslBioMethodInitializer = initsslBioMethod();
+  (void)sslBioMethodInitializer;
 
   setup_SSL_CTX(ctx_->getSSLCtx());
 }
@@ -353,13 +390,9 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
   return "";
 }
 
-bool AsyncSSLSocket::isEorTrackingEnabled() const {
-  return trackEor_;
-}
-
 void AsyncSSLSocket::setEorTracking(bool track) {
-  if (trackEor_ != track) {
-    trackEor_ = track;
+  if (isEorTrackingEnabled() != track) {
+    AsyncSocket::setEorTracking(track);
     appEorByteNo_ = 0;
     minEorRawByteNo_ = 0;
   }
@@ -375,7 +408,7 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
     return 0;
   }
   BIO* next = BIO_next(b);
-  while (next != NULL) {
+  while (next != nullptr) {
     b = next;
     next = BIO_next(b);
   }
@@ -411,29 +444,28 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
     callback->handshakeErr(this, ex);
   }
 
-  // Check the socket state not the ssl state here.
-  if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
-    failHandshake(__func__, ex);
-  }
+  failHandshake(__func__, ex);
 }
 
-void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
-      const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
+void AsyncSSLSocket::sslAccept(
+    HandshakeCB* callback,
+    std::chrono::milliseconds timeout,
+    const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (!server_ || (sslState_ != STATE_UNINIT &&
-                   sslState_ != STATE_UNENCRYPTED) ||
+  if (!server_ ||
+      (sslState_ != STATE_UNINIT && sslState_ != STATE_UNENCRYPTED) ||
       handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
   // Cache local and remote socket addresses to keep them available
   // after socket file descriptor is closed.
-  if (cacheAddrOnFailure_ && -1 != getFd()) {
-    cacheLocalPeerAddr();
+  if (cacheAddrOnFailure_) {
+    cacheAddresses();
   }
 
   handshakeStartTime_ = std::chrono::steady_clock::now();
@@ -443,15 +475,16 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
   sslState_ = STATE_ACCEPTING;
   handshakeCallback_ = callback;
 
-  if (timeout > 0) {
+  if (timeout > std::chrono::milliseconds::zero()) {
     handshakeTimeout_.scheduleTimeout(timeout);
   }
 
   /* register for a read operation (waiting for CLIENT HELLO) */
   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
+
+  checkForImmediateRead();
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
 void AsyncSSLSocket::attachSSLContext(
   const std::shared_ptr<SSLContext>& ctx) {
 
@@ -464,24 +497,52 @@ void AsyncSSLSocket::attachSSLContext(
   DCHECK(ctx->getSSLCtx());
   ctx_ = ctx;
 
+  // It's possible this could be attached before ssl_ is set up
+  if (!ssl_) {
+    return;
+  }
+
   // In order to call attachSSLContext, detachSSLContext must have been
-  // previously called which sets the socket's context to the dummy
-  // context. Thus we must acquire this lock.
+  // previously called.
+  // We need to update the initial_ctx if necessary
+  // The 'initial_ctx' inside an SSL* points to the context that it was created
+  // with, which is also where session callbacks and servername callbacks
+  // happen.
+  // When we switch to a different SSL_CTX, we want to update the initial_ctx as
+  // well so that any callbacks don't go to a different object
+  // NOTE: this will only work if we have access to ssl_ internals, so it may
+  // not work on
+  // OpenSSL version >= 1.1.0
+  auto sslCtx = ctx->getSSLCtx();
+  OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx);
+  // Detach sets the socket's context to the dummy context. Thus we must acquire
+  // this lock.
   SpinLockGuard guard(dummyCtxLock);
-  SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
+  SSL_set_SSL_CTX(ssl_, sslCtx);
 }
 
 void AsyncSSLSocket::detachSSLContext() {
   DCHECK(ctx_);
   ctx_.reset();
-  // We aren't using the initial_ctx for now, and it can introduce race
-  // conditions in the destructor of the SSL object.
-#ifndef OPENSSL_NO_TLSEXT
-  if (ssl_->initial_ctx) {
-    SSL_CTX_free(ssl_->initial_ctx);
-    ssl_->initial_ctx = nullptr;
+  // It's possible for this to be called before ssl_ has been
+  // set up
+  if (!ssl_) {
+    return;
   }
-#endif
+  // The 'initial_ctx' inside an SSL* points to the context that it was created
+  // with, which is also where session callbacks and servername callbacks
+  // happen.
+  // Detach the initial_ctx as well.  It will be reattached in attachSSLContext
+  // it is used for session info.
+  // NOTE: this will only work if we have access to ssl_ internals, so it may
+  // not work on
+  // OpenSSL version >= 1.1.0
+  SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_);
+  if (initialCtx) {
+    SSL_CTX_free(initialCtx);
+    OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr);
+  }
+
   SpinLockGuard guard(dummyCtxLock);
   if (nullptr == dummyCtx) {
     // We need to lazily initialize the dummy context so we don't
@@ -494,9 +555,8 @@ void AsyncSSLSocket::detachSSLContext() {
   // would not be thread safe.
   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
 }
-#endif
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 void AsyncSSLSocket::switchServerSSLContext(
   const std::shared_ptr<SSLContext>& handshakeCtx) {
   CHECK(server_);
@@ -527,22 +587,20 @@ bool AsyncSSLSocket::isServerNameMatch() const {
     return false;
   }
 
-  if(!ss->tlsext_hostname) {
-    return false;
-  }
-  return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
+  auto tlsextHostname = SSL_SESSION_get0_hostname(ss);
+  return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname));
 }
 
 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
   tlsextHostname_ = std::move(serverName);
 }
 
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
-void AsyncSSLSocket::timeoutExpired() noexcept {
+void AsyncSSLSocket::timeoutExpired(
+    std::chrono::milliseconds timeout) noexcept {
   if (state_ == StateEnum::ESTABLISHED &&
-      (sslState_ == STATE_CACHE_LOOKUP ||
-       sslState_ == STATE_ASYNC_PENDING)) {
+      (sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING)) {
     sslState_ = STATE_ERROR;
     // We are expecting a callback in restartSSLAccept.  The cache lookup
     // and rsa-call necessarily have pointers to this ssl socket, so delay
@@ -557,9 +615,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept {
     assert(state_ == StateEnum::ESTABLISHED &&
            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
     DestructorGuard dg(this);
-    AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
-                           (sslState_ == STATE_CONNECTING) ?
-                           "SSL connect timed out" : "SSL accept timed out");
+    AsyncSocketException ex(
+        AsyncSocketException::TIMED_OUT,
+        folly::sformat(
+            "SSL {} timed out after {}ms",
+            (sslState_ == STATE_CONNECTING) ? "connect" : "accept",
+            timeout.count()));
     failHandshake(__func__, ex);
   }
 }
@@ -606,31 +667,43 @@ void AsyncSSLSocket::invokeHandshakeCB() {
   }
 }
 
-void AsyncSSLSocket::cacheLocalPeerAddr() {
-  SocketAddress address;
-  try {
-    getLocalAddress(&address);
-    getPeerAddress(&address);
-  } catch (const std::system_error& e) {
-    // The handle can be still valid while the connection is already closed.
-    if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
-      throw;
-    }
-  }
+void AsyncSSLSocket::connect(
+    ConnectCallback* callback,
+    const folly::SocketAddress& address,
+    int timeout,
+    const OptionMap& options,
+    const folly::SocketAddress& bindAddr) noexcept {
+  auto timeoutChrono = std::chrono::milliseconds(timeout);
+  connect(callback, address, timeoutChrono, timeoutChrono, options, bindAddr);
 }
 
-void AsyncSSLSocket::connect(ConnectCallback* callback,
-                              const folly::SocketAddress& address,
-                              int timeout,
-                              const OptionMap &options,
-                              const folly::SocketAddress& bindAddr)
-                              noexcept {
+void AsyncSSLSocket::connect(
+    ConnectCallback* callback,
+    const folly::SocketAddress& address,
+    std::chrono::milliseconds connectTimeout,
+    std::chrono::milliseconds totalConnectTimeout,
+    const OptionMap& options,
+    const folly::SocketAddress& bindAddr) noexcept {
   assert(!server_);
   assert(state_ == StateEnum::UNINIT);
-  assert(sslState_ == STATE_UNINIT);
-  AsyncSSLSocketConnector *connector =
-    new AsyncSSLSocketConnector(this, callback, timeout);
-  AsyncSocket::connect(connector, address, timeout, options, bindAddr);
+  assert(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
+  noTransparentTls_ = true;
+  totalConnectTimeout_ = totalConnectTimeout;
+  if (sslState_ != STATE_UNENCRYPTED) {
+    callback = new AsyncSSLSocketConnector(
+        this, callback, int(totalConnectTimeout.count()));
+  }
+  AsyncSocket::connect(
+      callback, address, int(connectTimeout.count()), options, bindAddr);
+}
+
+bool AsyncSSLSocket::needsPeerVerification() const {
+  if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
+    return ctx_->needsPeerVerification();
+  }
+  return (
+      verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
+      verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
 }
 
 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
@@ -650,27 +723,29 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
 }
 
 bool AsyncSSLSocket::setupSSLBio() {
-  auto wb = BIO_new(&sslWriteBioMethod);
+  auto sslBio = BIO_new(getSSLBioMethod());
 
-  if (!wb) {
+  if (!sslBio) {
     return false;
   }
 
-  OpenSSLUtils::setBioAppData(wb, this);
-  OpenSSLUtils::setBioFd(wb, fd_, BIO_NOCLOSE);
-  SSL_set_bio(ssl_, wb, wb);
+  OpenSSLUtils::setBioAppData(sslBio, this);
+  OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE);
+  SSL_set_bio(ssl_, sslBio, sslBio);
   return true;
 }
 
-void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
-        const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
+void AsyncSSLSocket::sslConn(
+    HandshakeCB* callback,
+    std::chrono::milliseconds timeout,
+    const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // Cache local and remote socket addresses to keep them available
   // after socket file descriptor is closed.
-  if (cacheAddrOnFailure_ && -1 != getFd()) {
-    cacheLocalPeerAddr();
+  if (cacheAddrOnFailure_) {
+    cacheAddresses();
   }
 
   verifyPeer_ = verifyPeer;
@@ -711,7 +786,7 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
     SSL_SESSION_free(sslSession_);
     sslSession_ = nullptr;
   }
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   if (tlsextHostname_.size()) {
     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
   }
@@ -729,9 +804,8 @@ void AsyncSSLSocket::startSSLConnect() {
   handshakeStartTime_ = std::chrono::steady_clock::now();
   // Make end time at least >= start time.
   handshakeEndTime_ = handshakeStartTime_;
-  if (handshakeConnectTimeout_ > 0) {
-    handshakeTimeout_.scheduleTimeout(
-        std::chrono::milliseconds(handshakeConnectTimeout_));
+  if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) {
+    handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
   }
   handleConnect();
 }
@@ -752,7 +826,8 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
   sslSession_ = session;
   if (!takeOwnership && session != nullptr) {
     // Increment the reference count
-    CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+    // This API exists in BoringSSL and OpenSSL 1.1.0
+    SSL_SESSION_up_ref(session);
   }
 }
 
@@ -772,7 +847,7 @@ bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
     SSLContext::NextProtocolType* protoType) const {
   *protoName = nullptr;
   *protoLen = 0;
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
   if (*protoLen > 0) {
     if (protoType) {
@@ -836,7 +911,7 @@ int AsyncSSLSocket::getSSLVersion() const {
 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
   if (cert) {
-    int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
+    int nid = X509_get_signature_nid(cert);
     return OBJ_nid2ln(nid);
   }
   return nullptr;
@@ -890,14 +965,14 @@ bool AsyncSSLSocket::willBlock(int ret,
     // The timeout (if set) keeps running here
     return true;
 #endif
-  } else if (0
+  } else if ((0
 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
 #endif
 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
 #endif
-      ) {
+      )) {
     // Our custom openssl function has kicked off an async request to do
     // rsa/ecdsa private key operation.  When that call returns, a callback will
     // be invoked that will re-call handleAccept.
@@ -934,6 +1009,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept {
   // the socket to become readable again.
   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
     AsyncSocket::handleRead();
+  } else {
+    AsyncSocket::checkForImmediateRead();
   }
 }
 
@@ -956,8 +1033,8 @@ AsyncSSLSocket::restartSSLAccept()
   }
   if (sslState_ == STATE_ERROR) {
     // go straight to fail if timeout expired during lookup
-    AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
-                           "SSL accept timed out");
+    AsyncSocketException ex(
+        AsyncSocketException::TIMED_OUT, "SSL accept timed out");
     failHandshake(__func__, ex);
     return;
   }
@@ -1250,7 +1327,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                << "): client intitiated SSL renegotiation not permitted";
     return ReadResult(
         READ_ERROR,
-        folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
   }
   if (bytes <= 0) {
     int error = SSL_get_error(ssl_, bytes);
@@ -1270,12 +1347,12 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                  << "): unsupported SSL renegotiation during read";
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+          std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
     } else {
       if (zero_return(error, bytes)) {
         return ReadResult(bytes);
       }
-      long errError = ERR_get_error();
+      auto errError = ERR_get_error();
       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
               << "state=" << state_ << ", "
               << "sslState=" << sslState_ << ", "
@@ -1287,7 +1364,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
               << "reason: " << ERR_reason_error_string(errError);
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(error, errError, bytes, errno));
+          std::make_unique<SSLException>(error, errError, bytes, errno));
     }
   } else {
     appBytesReceived_ += bytes;
@@ -1330,7 +1407,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
                << "unsupported SSL renegotiation during write";
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
   } else {
     if (zero_return(error, rc)) {
       return WriteResult(0);
@@ -1343,7 +1420,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
             << ", reason: " << ERR_reason_error_string(errError);
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(error, errError, rc, errno));
+        std::make_unique<SSLException>(error, errError, rc, errno));
   }
 }
 
@@ -1364,7 +1441,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
                << "TODO: AsyncSSLSocket currently does not support calling "
                << "write() before the handshake has fully completed";
     return WriteResult(
-        WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
+        WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
   }
 
   // Declare a buffer used to hold small write requests.  It could point to a
@@ -1425,8 +1502,13 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
         uint32_t nextIndex = i + buffersStolen + 1;
         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
                                              minWriteSize_ - len);
-        memcpy(combinedBuf + len, vec[nextIndex].iov_base,
-               bytesStolenFromNextBuffer);
+        if (bytesStolenFromNextBuffer > 0) {
+          assert(vec[nextIndex].iov_base != nullptr);
+          ::memcpy(
+              combinedBuf + len,
+              vec[nextIndex].iov_base,
+              bytesStolenFromNextBuffer);
+        }
         len += bytesStolenFromNextBuffer;
         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
           // couldn't steal the whole buffer
@@ -1494,7 +1576,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
 
 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
                                       bool eor) {
-  if (eor && trackEor_) {
+  if (eor && isEorTrackingEnabled()) {
     if (appEorByteNo_) {
       // cannot track for more than one app byte EOR
       CHECK(appEorByteNo_ == appBytesWritten_ + n);
@@ -1542,11 +1624,10 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   struct msghdr msg;
   struct iovec iov;
-  int flags = 0;
   AsyncSSLSocket* tsslSock;
 
   iov.iov_base = const_cast<char*>(in);
-  iov.iov_len = inl;
+  iov.iov_len = size_t(inl);
   memset(&msg, 0, sizeof(msg));
   msg.msg_iov = &iov;
   msg.msg_iovlen = 1;
@@ -1557,23 +1638,29 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
   CHECK(tsslSock);
 
-  if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
+  WriteFlags flags = WriteFlags::NONE;
+  if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
-    flags = MSG_EOR;
+    flags |= WriteFlags::EOR;
   }
 
-#ifdef MSG_NOSIGNAL
-  flags |= MSG_NOSIGNAL;
-#endif
-
-#ifdef MSG_MORE
   if (tsslSock->corkCurrentWrite_) {
-    flags |= MSG_MORE;
+    flags |= WriteFlags::CORK;
+  }
+
+  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(
+      flags, false /*zeroCopyEnabled*/);
+  msg.msg_controllen =
+      tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
+  CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+           msg.msg_controllen);
+  if (msg.msg_controllen != 0) {
+    msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+    tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
   }
-#endif
 
   auto result = tsslSock->sendSocketMessage(
-      OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
+      OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
   BIO_clear_retry_flags(b);
   if (!result.exception && result.writeReturn <= 0) {
     if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
@@ -1583,6 +1670,37 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   return int(result.writeReturn);
 }
 
+int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
+  if (!out) {
+    return 0;
+  }
+  BIO_clear_retry_flags(b);
+
+  auto appData = OpenSSLUtils::getBioAppData(b);
+  CHECK(appData);
+  auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+
+  if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
+    VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
+            << ", reading pre-received data";
+
+    Cursor cursor(sslSock->preReceivedData_.get());
+    auto len = cursor.pullAtMost(out, outl);
+
+    IOBufQueue queue;
+    queue.append(std::move(sslSock->preReceivedData_));
+    queue.trimStart(len);
+    sslSock->preReceivedData_ = queue.move();
+    return static_cast<int>(len);
+  } else {
+    auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
+    if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
+      BIO_set_retry_read(b);
+    }
+    return result;
+  }
+}
+
 int AsyncSSLSocket::sslVerifyCallback(
     int preverifyOk,
     X509_STORE_CTX* x509Ctx) {
@@ -1599,7 +1717,7 @@ int AsyncSSLSocket::sslVerifyCallback(
 
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
-    clientHelloInfo_.reset(new ssl::ClientHelloInfo());
+    clientHelloInfo_ = std::make_unique<ssl::ClientHelloInfo>();
 }
 
 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
@@ -1710,7 +1828,7 @@ void AsyncSSLSocket::clientHelloParsingCallback(int written,
         }
       }
     }
-  } catch (std::out_of_range& e) {
+  } catch (std::out_of_range&) {
     // we'll use what we found and cleanup below.
     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
@@ -1818,6 +1936,14 @@ std::string AsyncSSLSocket::getSSLAlertsReceived() const {
   return ret;
 }
 
+void AsyncSSLSocket::setSSLCertVerificationAlert(std::string alert) {
+  sslVerificationAlert_ = std::move(alert);
+}
+
+std::string AsyncSSLSocket::getSSLCertVerificationAlert() const {
+  return sslVerificationAlert_;
+}
+
 void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const {
   char ciphersBuffer[1024];
   ciphersBuffer[0] = '\0';