Always override write bio method
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index a6e3307cb832ddea24297f0292e09ba3540f381a..4fa6b6fef26b57f6cbb62cb1c083a07f40fa1e3b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2016 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #include <folly/io/async/AsyncSSLSocket.h>
 
 #include <folly/io/async/EventBase.h>
+#include <folly/portability/Sockets.h>
 
 #include <boost/noncopyable.hpp>
 #include <errno.h>
 #include <fcntl.h>
-#include <netinet/in.h>
-#include <netinet/tcp.h>
 #include <openssl/err.h>
 #include <openssl/asn1.h>
 #include <openssl/ssl.h>
 #include <sys/types.h>
-#include <sys/socket.h>
-#include <unistd.h>
 #include <chrono>
 
 #include <folly/Bits.h>
@@ -36,6 +33,7 @@
 #include <folly/SpinLock.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
+#include <folly/portability/Unistd.h>
 
 using folly::SocketAddress;
 using folly::SSLContext;
@@ -55,16 +53,13 @@ using folly::AsyncSocket;
 using folly::AsyncSocketException;
 using folly::AsyncSSLSocket;
 using folly::Optional;
+using folly::SSLContext;
 
 // We have one single dummy SSL context so that we can implement attach
 // and detach methods in a thread safe fashion without modifying opnessl.
 static SSLContext *dummyCtx = nullptr;
 static SpinLock dummyCtxLock;
 
-// Numbers chosen as to not collide with functions in ssl.h
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
-
 // If given min write size is less than this, buffer will be allocated on
 // stack, otherwise it is allocated on heap
 const size_t MAX_STACK_BUF_SIZE = 2048;
@@ -84,8 +79,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
   int64_t startTime_;
 
  protected:
-  virtual ~AsyncSSLSocketConnector() {
-  }
+  ~AsyncSSLSocketConnector() override {}
 
  public:
   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
@@ -98,7 +92,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
   }
 
-  virtual void connectSuccess() noexcept {
+  void connectSuccess() noexcept override {
     VLOG(7) << "client socket connected";
 
     int64_t timeoutLeft = 0;
@@ -118,13 +112,13 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
     sslSocket_->sslConn(this, timeoutLeft);
   }
 
-  virtual void connectErr(const AsyncSocketException& ex) noexcept {
+  void connectErr(const AsyncSocketException& ex) noexcept override {
     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
     fail(ex);
     delete this;
   }
 
-  virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
+  void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
     VLOG(7) << "client handshake success";
     if (callback_) {
       callback_->connectSuccess();
@@ -132,8 +126,8 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
     delete this;
   }
 
-  virtual void handshakeErr(AsyncSSLSocket *socket,
-                              const AsyncSocketException& ex) noexcept {
+  void handshakeErr(AsyncSSLSocket* /* socket */,
+                    const AsyncSocketException& ex) noexcept override {
     LOG(ERROR) << "client handshakeErr: " << ex.what();
     fail(ex);
     delete this;
@@ -220,57 +214,68 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
                    SSL_MODE_ENABLE_PARTIAL_WRITE
                    );
 #endif
+// SSL_CTX_set_mode is a Macro
+#ifdef SSL_MODE_WRITE_IOVEC
+  SSL_CTX_set_mode(ctx,
+                   SSL_CTX_get_mode(ctx)
+                   | SSL_MODE_WRITE_IOVEC);
+#endif
+
 }
 
-BIO_METHOD eorAwareBioMethod;
+BIO_METHOD sslWriteBioMethod;
 
-__attribute__((__constructor__))
-void initEorBioMethod(void) {
-  memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
+void* initsslWriteBioMethod(void) {
+  memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
   // override the bwrite method for MSG_EOR support
-  eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
+  sslWriteBioMethod.bwrite = AsyncSSLSocket::bioWrite;
 
-  // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
+  // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
-  // then have specific handlings. The eorAwareBioWrite should be compatible
+  // then have specific handlings. The sslWriteBioWrite should be compatible
   // with the one in openssl.
+
+  // Return something here to enable AsyncSSLSocket to call this method using
+  // a function-scoped static.
+  return nullptr;
 }
 
 } // anonymous namespace
 
 namespace folly {
 
-SSLException::SSLException(int sslError, int errno_copy):
-    AsyncSocketException(
-      AsyncSocketException::SSL_ERROR,
-      ERR_error_string(sslError, msg_),
-      sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
-
 /**
  * Create a client AsyncSSLSocket
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
-                                 EventBase* evb) :
+                               EventBase* evb, bool deferSecurityNegotiation) :
     AsyncSocket(evb),
     ctx_(ctx),
     handshakeTimeout_(this, evb) {
-  setup_SSL_CTX(ctx_->getSSLCtx());
+  init();
+  if (deferSecurityNegotiation) {
+    sslState_ = STATE_UNENCRYPTED;
+  }
 }
 
 /**
  * Create a server/client AsyncSSLSocket
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
-                                 EventBase* evb, int fd, bool server) :
+                               EventBase* evb, int fd, bool server,
+                               bool deferSecurityNegotiation) :
     AsyncSocket(evb, fd),
     server_(server),
     ctx_(ctx),
     handshakeTimeout_(this, evb) {
-  setup_SSL_CTX(ctx_->getSSLCtx());
+  init();
   if (server) {
     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
                               AsyncSSLSocket::sslInfoCallback);
   }
+  if (deferSecurityNegotiation) {
+    sslState_ = STATE_UNENCRYPTED;
+  }
 }
 
 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
@@ -280,12 +285,10 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
                                  EventBase* evb,
-                                 const std::string& serverName) :
-    AsyncSocket(evb),
-    ctx_(ctx),
-    handshakeTimeout_(this, evb),
-    tlsextHostname_(serverName) {
-  setup_SSL_CTX(ctx_->getSSLCtx());
+                               const std::string& serverName,
+                               bool deferSecurityNegotiation) :
+    AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
+  tlsextHostname_ = serverName;
 }
 
 /**
@@ -294,12 +297,10 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
  */
 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
                                  EventBase* evb, int fd,
-                                 const std::string& serverName) :
-    AsyncSocket(evb, fd),
-    ctx_(ctx),
-    handshakeTimeout_(this, evb),
-    tlsextHostname_(serverName) {
-  setup_SSL_CTX(ctx_->getSSLCtx());
+                               const std::string& serverName,
+                               bool deferSecurityNegotiation) :
+    AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
+  tlsextHostname_ = serverName;
 }
 #endif
 
@@ -310,6 +311,15 @@ AsyncSSLSocket::~AsyncSSLSocket() {
           << sslState_ << ", events=" << eventFlags_ << ")";
 }
 
+void AsyncSSLSocket::init() {
+  // Do this here to ensure we initialize this once before any use of
+  // AsyncSSLSocket instances and not as part of library load.
+  static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
+  (void)sslWriteBioMethodInitializer;
+
+  setup_SSL_CTX(ctx_->getSSLCtx());
+}
+
 void AsyncSSLSocket::closeNow() {
   // Close the SSL connection.
   if (ssl_ != nullptr && fd_ != -1) {
@@ -335,13 +345,10 @@ void AsyncSSLSocket::closeNow() {
 
   DestructorGuard dg(this);
 
-  if (handshakeCallback_) {
-    AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
-                           "SSL connection closed locally");
-    HandshakeCB* callback = handshakeCallback_;
-    handshakeCallback_ = nullptr;
-    callback->handshakeErr(this, ex);
-  }
+  invokeHandshakeErr(
+      AsyncSocketException(
+        AsyncSocketException::END_OF_FILE,
+        "SSL connection closed locally"));
 
   if (ssl_ != nullptr) {
     SSL_free(ssl_);
@@ -370,7 +377,7 @@ void AsyncSSLSocket::shutdownWriteNow() {
 bool AsyncSSLSocket::good() const {
   return (AsyncSocket::good() &&
           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
-           sslState_ == STATE_ESTABLISHED));
+           sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
 }
 
 // The TAsyncTransport definition of 'good' states that the transport is
@@ -384,34 +391,24 @@ bool AsyncSSLSocket::connecting() const {
                                      sslState_ == STATE_CONNECTING))));
 }
 
+std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
+  const unsigned char* protoName = nullptr;
+  unsigned protoLength;
+  if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
+    return std::string(reinterpret_cast<const char*>(protoName), protoLength);
+  }
+  return "";
+}
+
 bool AsyncSSLSocket::isEorTrackingEnabled() const {
-  const BIO *wb = SSL_get_wbio(ssl_);
-  return wb && wb->method == &eorAwareBioMethod;
+  return trackEor_;
 }
 
 void AsyncSSLSocket::setEorTracking(bool track) {
-  BIO *wb = SSL_get_wbio(ssl_);
-  if (!wb) {
-    throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
-                              "setting EOR tracking without an initialized "
-                              "BIO");
-  }
-
-  if (track) {
-    if (wb->method != &eorAwareBioMethod) {
-      // only do this if we didn't
-      wb->method = &eorAwareBioMethod;
-      BIO_set_app_data(wb, this);
-      appEorByteNo_ = 0;
-      minEorRawByteNo_ = 0;
-    }
-  } else if (wb->method == &eorAwareBioMethod) {
-    wb->method = BIO_s_socket();
-    BIO_set_app_data(wb, nullptr);
+  if (trackEor_ != track) {
+    trackEor_ = track;
     appEorByteNo_ = 0;
     minEorRawByteNo_ = 0;
-  } else {
-    CHECK(wb->method == BIO_s_socket());
   }
 }
 
@@ -447,6 +444,7 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
                          "sslAccept() called with socket in invalid state");
 
+  handshakeEndTime_ = std::chrono::steady_clock::now();
   if (callback) {
     callback->handshakeErr(this, ex);
   }
@@ -464,10 +462,22 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+  if (!server_ || (sslState_ != STATE_UNINIT &&
+                   sslState_ != STATE_UNENCRYPTED) ||
+      handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
+  // Cache local and remote socket addresses to keep them available
+  // after socket file descriptor is closed.
+  if (cacheAddrOnFailure_ && -1 != getFd()) {
+    cacheLocalPeerAddr();
+  }
+
+  handshakeStartTime_ = std::chrono::steady_clock::now();
+  // Make end time at least >= start time.
+  handshakeEndTime_ = handshakeStartTime_;
+
   sslState_ = STATE_ACCEPTING;
   handshakeCallback_ = callback;
 
@@ -555,7 +565,10 @@ bool AsyncSSLSocket::isServerNameMatch() const {
     return false;
   }
 
-  return (ss->tlsext_hostname ? true : false);
+  if(!ss->tlsext_hostname) {
+    return false;
+  }
+  return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
 }
 
 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
@@ -567,7 +580,7 @@ void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
 void AsyncSSLSocket::timeoutExpired() noexcept {
   if (state_ == StateEnum::ESTABLISHED &&
       (sslState_ == STATE_CACHE_LOOKUP ||
-       sslState_ == STATE_RSA_ASYNC_PENDING)) {
+       sslState_ == STATE_ASYNC_PENDING)) {
     sslState_ = STATE_ERROR;
     // We are expecting a callback in restartSSLAccept.  The cache lookup
     // and rsa-call necessarily have pointers to this ssl socket, so delay
@@ -583,18 +596,10 @@ void AsyncSSLSocket::timeoutExpired() noexcept {
   }
 }
 
-int AsyncSSLSocket::sslExDataIndex_ = -1;
-std::mutex AsyncSSLSocket::mutex_;
-
 int AsyncSSLSocket::getSSLExDataIndex() {
-  if (sslExDataIndex_ < 0) {
-    std::lock_guard<std::mutex> g(mutex_);
-    if (sslExDataIndex_ < 0) {
-      sslExDataIndex_ = SSL_get_ex_new_index(0,
-          (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
-    }
-  }
-  return sslExDataIndex_;
+  static auto index = SSL_get_ex_new_index(
+      0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
+  return index;
 }
 
 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
@@ -602,23 +607,27 @@ AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
       getSSLExDataIndex()));
 }
 
-void AsyncSSLSocket::failHandshake(const char* fn,
-                                    const AsyncSocketException& ex) {
+void AsyncSSLSocket::failHandshake(const char* /* fn */,
+                                   const AsyncSocketException& ex) {
   startFail();
-
   if (handshakeTimeout_.isScheduled()) {
     handshakeTimeout_.cancelTimeout();
   }
+  invokeHandshakeErr(ex);
+  finishFail();
+}
+
+void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
+  handshakeEndTime_ = std::chrono::steady_clock::now();
   if (handshakeCallback_ != nullptr) {
     HandshakeCB* callback = handshakeCallback_;
     handshakeCallback_ = nullptr;
     callback->handshakeErr(this, ex);
   }
-
-  finishFail();
 }
 
 void AsyncSSLSocket::invokeHandshakeCB() {
+  handshakeEndTime_ = std::chrono::steady_clock::now();
   if (handshakeTimeout_.isScheduled()) {
     handshakeTimeout_.cancelTimeout();
   }
@@ -629,6 +638,19 @@ void AsyncSSLSocket::invokeHandshakeCB() {
   }
 }
 
+void AsyncSSLSocket::cacheLocalPeerAddr() {
+  SocketAddress address;
+  try {
+    getLocalAddress(&address);
+    getPeerAddress(&address);
+  } catch (const std::system_error& e) {
+    // The handle can be still valid while the connection is already closed.
+    if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
+      throw;
+    }
+  }
+}
+
 void AsyncSSLSocket::connect(ConnectCallback* callback,
                               const folly::SocketAddress& address,
                               int timeout,
@@ -659,18 +681,43 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
   }
 }
 
+bool AsyncSSLSocket::setupSSLBio() {
+  auto wb = BIO_new(&sslWriteBioMethod);
+
+  if (!wb) {
+    return false;
+  }
+
+  BIO_set_app_data(wb, this);
+  BIO_set_fd(wb, fd_, BIO_NOCLOSE);
+  SSL_set_bio(ssl_, wb, wb);
+  return true;
+}
+
 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
   assert(eventBase_->isInEventBaseThread());
 
+  // Cache local and remote socket addresses to keep them available
+  // after socket file descriptor is closed.
+  if (cacheAddrOnFailure_ && -1 != getFd()) {
+    cacheLocalPeerAddr();
+  }
+
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+  if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
+                  STATE_UNENCRYPTED) ||
+      handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
+  handshakeStartTime_ = std::chrono::steady_clock::now();
+  // Make end time at least >= start time.
+  handshakeEndTime_ = handshakeStartTime_;
+
   sslState_ = STATE_CONNECTING;
   handshakeCallback_ = callback;
 
@@ -685,9 +732,15 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
     return failHandshake(__func__, ex);
   }
 
+  if (!setupSSLBio()) {
+    sslState_ = STATE_ERROR;
+    AsyncSocketException ex(
+        AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
+    return failHandshake(__func__, ex);
+  }
+
   applyVerificationOptions(ssl_);
 
-  SSL_set_fd(ssl_, fd_);
   if (sslSession_ != nullptr) {
     SSL_set_session(ssl_, sslSession_);
     SSL_SESSION_free(sslSession_);
@@ -716,6 +769,10 @@ SSL_SESSION *AsyncSSLSocket::getSSLSession() {
   return sslSession_;
 }
 
+const SSL* AsyncSSLSocket::getSSL() const {
+  return ssl_;
+}
+
 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
   sslSession_ = session;
   if (!takeOwnership && session != nullptr) {
@@ -724,23 +781,39 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
   }
 }
 
-void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
-    unsigned* protoLen) const {
-  if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
+void AsyncSSLSocket::getSelectedNextProtocol(
+    const unsigned char** protoName,
+    unsigned* protoLen,
+    SSLContext::NextProtocolType* protoType) const {
+  if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
                               "NPN not supported");
   }
 }
 
 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
-  const unsigned char** protoName,
-  unsigned* protoLen) const {
+    const unsigned char** protoName,
+    unsigned* protoLen,
+    SSLContext::NextProtocolType* protoType) const {
   *protoName = nullptr;
   *protoLen = 0;
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_get0_alpn_selected(ssl_, protoName, protoLen);
+  if (*protoLen > 0) {
+    if (protoType) {
+      *protoType = SSLContext::NextProtocolType::ALPN;
+    }
+    return true;
+  }
+#endif
 #ifdef OPENSSL_NPN_NEGOTIATED
   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
+  if (protoType) {
+    *protoType = SSLContext::NextProtocolType::NPN;
+  }
   return true;
 #else
+  (void)protoType;
   return false;
 #endif
 }
@@ -756,28 +829,44 @@ const char *AsyncSSLSocket::getNegotiatedCipherName() const {
   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
 }
 
+/* static */
+const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
+  if (ssl == nullptr) {
+    return nullptr;
+  }
+#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
+  return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+#else
+  return nullptr;
+#endif
+}
+
 const char *AsyncSSLSocket::getSSLServerName() const {
 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
-  return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
-        : nullptr;
+  return getSSLServerNameFromSSL(ssl_);
 #else
   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
-                            "SNI not supported");
+                             "SNI not supported");
 #endif
 }
 
 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
-  try {
-    return getSSLServerName();
-  } catch (AsyncSocketException& ex) {
-    return nullptr;
-  }
+  return getSSLServerNameFromSSL(ssl_);
 }
 
 int AsyncSSLSocket::getSSLVersion() const {
   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
 }
 
+const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
+  X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
+  if (cert) {
+    int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
+    return OBJ_nid2ln(nid);
+  }
+  return nullptr;
+}
+
 int AsyncSSLSocket::getSSLCertSize() const {
   int certSize = 0;
   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
@@ -789,8 +878,11 @@ int AsyncSSLSocket::getSSLCertSize() const {
   return certSize;
 }
 
-bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
-  int error = *errorOut = SSL_get_error(ssl_, ret);
+bool AsyncSSLSocket::willBlock(int ret,
+                               int* sslErrorOut,
+                               unsigned long* errErrorOut) noexcept {
+  *errErrorOut = 0;
+  int error = *sslErrorOut = SSL_get_error(ssl_, ret);
   if (error == SSL_ERROR_WANT_READ) {
     // Register for read event if not already.
     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
@@ -819,12 +911,18 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
     // The timeout (if set) keeps running here
     return true;
 #endif
+  } else if (0
 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
-  } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
+      || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
+#endif
+#ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
+      || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
+#endif
+      ) {
     // Our custom openssl function has kicked off an async request to do
-    // modular exponentiation.  When that call returns, a callback will
+    // rsa/ecdsa private key operation.  When that call returns, a callback will
     // be invoked that will re-call handleAccept.
-    sslState_ = STATE_RSA_ASYNC_PENDING;
+    sslState_ = STATE_ASYNC_PENDING;
 
     // Unregister for all events while blocked here
     updateEventRegistration(
@@ -834,11 +932,8 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
 
     // The timeout (if set) keeps running here
     return true;
-#endif
   } else {
-    // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
-    // in the log
-    long lastError = ERR_get_error();
+    unsigned long lastError = *errErrorOut = ERR_get_error();
     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
             << "state=" << state_ << ", "
             << "sslState=" << sslState_ << ", "
@@ -850,17 +945,6 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
             << "func: " << ERR_func_error_string(lastError) << ", "
             << "reason: " << ERR_reason_error_string(lastError);
-    if (error != SSL_ERROR_SYSCALL) {
-      if (error == SSL_ERROR_SSL) {
-        *errorOut = lastError;
-      }
-      if ((unsigned long)lastError < 0x8000) {
-        errno = ENOSYS;
-      } else {
-        errno = lastError;
-      }
-    }
-    ERR_clear_error();
     return false;
   }
 }
@@ -883,7 +967,7 @@ AsyncSSLSocket::restartSSLAccept()
   DestructorGuard dg(this);
   assert(
     sslState_ == STATE_CACHE_LOOKUP ||
-    sslState_ == STATE_RSA_ASYNC_PENDING ||
+    sslState_ == STATE_ASYNC_PENDING ||
     sslState_ == STATE_ERROR ||
     sslState_ == STATE_CLOSED
   );
@@ -923,26 +1007,34 @@ AsyncSSLSocket::handleAccept() noexcept {
                  << ", fd=" << fd_ << "): " << e.what();
       return failHandshake(__func__, ex);
     }
-    SSL_set_fd(ssl_, fd_);
+
+    if (!setupSSLBio()) {
+      sslState_ = STATE_ERROR;
+      AsyncSocketException ex(
+          AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
+      return failHandshake(__func__, ex);
+    }
+
     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
 
     applyVerificationOptions(ssl_);
   }
 
   if (server_ && parseClientHello_) {
-    SSL_set_msg_callback_arg(ssl_, this);
     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
+    SSL_set_msg_callback_arg(ssl_, this);
   }
 
-  errno = 0;
   int ret = SSL_accept(ssl_);
   if (ret <= 0) {
-    int error;
-    if (willBlock(ret, &error)) {
+    int sslError;
+    unsigned long errError;
+    int errnoCopy = errno;
+    if (willBlock(ret, &sslError, &errError)) {
       return;
     } else {
       sslState_ = STATE_ERROR;
-      SSLException ex(error, errno);
+      SSLException ex(sslError, errError, ret, errnoCopy);
       return failHandshake(__func__, ex);
     }
   }
@@ -996,15 +1088,16 @@ AsyncSSLSocket::handleConnect() noexcept {
          sslState_ == STATE_CONNECTING);
   assert(ssl_);
 
-  errno = 0;
   int ret = SSL_connect(ssl_);
   if (ret <= 0) {
-    int error;
-    if (willBlock(ret, &error)) {
+    int sslError;
+    unsigned long errError;
+    int errnoCopy = errno;
+    if (willBlock(ret, &sslError, &errError)) {
       return;
     } else {
       sslState_ = STATE_ERROR;
-      SSLException ex(error, errno);
+      SSLException ex(sslError, errError, ret, errnoCopy);
       return failHandshake(__func__, ex);
     }
   }
@@ -1016,7 +1109,8 @@ AsyncSSLSocket::handleConnect() noexcept {
   // STATE_CONNECTING.
   sslState_ = STATE_ESTABLISHED;
 
-  VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
+  VLOG(3) << "AsyncSSLSocket " << this << ": "
+          << "fd " << fd_ << " successfully connected; "
           << "state=" << int(state_) << ", sslState=" << sslState_
           << ", events=" << eventFlags_;
 
@@ -1044,6 +1138,34 @@ AsyncSSLSocket::handleConnect() noexcept {
   AsyncSocket::handleInitialReadWrite();
 }
 
+void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+  // turn on the buffer movable in openssl
+  if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
+      callback != nullptr && callback->isBufferMovable()) {
+    SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
+    isBufferMovable_ = true;
+  }
+#endif
+
+  AsyncSocket::setReadCB(callback);
+}
+
+void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
+  bufferMovableEnabled_ = enabled;
+}
+
+void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+  CHECK(readCallback_);
+  if (isBufferMovable_) {
+    *buf = nullptr;
+    *buflen = 0;
+  } else {
+    // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
+    readCallback_->getReadBuffer(buf, buflen);
+  }
+}
+
 void
 AsyncSSLSocket::handleRead() noexcept {
   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
@@ -1069,42 +1191,57 @@ AsyncSSLSocket::handleRead() noexcept {
   AsyncSocket::handleRead();
 }
 
-ssize_t
-AsyncSSLSocket::performRead(void* buf, size_t buflen) {
-  errno = 0;
-  ssize_t bytes = SSL_read(ssl_, buf, buflen);
+AsyncSocket::ReadResult
+AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
+  VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
+          << ", buflen=" << *buflen;
+
+  if (sslState_ == STATE_UNENCRYPTED) {
+    return AsyncSocket::performRead(buf, buflen, offset);
+  }
+
+  ssize_t bytes = 0;
+  if (!isBufferMovable_) {
+    bytes = SSL_read(ssl_, *buf, *buflen);
+  }
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+  else {
+    bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
+  }
+#endif
+
   if (server_ && renegotiateAttempted_) {
     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
-               << ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): "
-               << "client intitiated SSL renegotiation not permitted";
-    // We pack our own SSLerr here with a dummy function
-    errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
-                     SSL_CLIENT_RENEGOTIATION_ATTEMPT);
-    ERR_clear_error();
-    return READ_ERROR;
+               << ", sslstate=" << sslState_ << ", events=" << eventFlags_
+               << "): client intitiated SSL renegotiation not permitted";
+    return ReadResult(
+        READ_ERROR,
+        folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
   }
   if (bytes <= 0) {
     int error = SSL_get_error(ssl_, bytes);
     if (error == SSL_ERROR_WANT_READ) {
       // The caller will register for read event if not already.
-      return READ_BLOCKING;
+      if (errno == EWOULDBLOCK || errno == EAGAIN) {
+        return ReadResult(READ_BLOCKING);
+      } else {
+        return ReadResult(READ_ERROR);
+      }
     } else if (error == SSL_ERROR_WANT_WRITE) {
       // TODO: Even though we are attempting to read data, SSL_read() may
       // need to write data if renegotiation is being performed.  We currently
       // don't support this and just fail the read.
       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
-                 << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
-                 << "unsupported SSL renegotiation during read",
-      errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
-                       SSL_INVALID_RENEGOTIATION);
-      ERR_clear_error();
-      return READ_ERROR;
+                 << ", sslState=" << sslState_ << ", events=" << eventFlags_
+                 << "): unsupported SSL renegotiation during read";
+      return ReadResult(
+          READ_ERROR,
+          folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
     } else {
-      // TODO: Fix this code so that it can return a proper error message
-      // to the callback, rather than relying on AsyncSocket code which
-      // can't handle SSL errors.
-      long lastError = ERR_get_error();
-
+      if (zero_return(error, bytes)) {
+        return ReadResult(bytes);
+      }
+      long errError = ERR_get_error();
       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
               << "state=" << state_ << ", "
               << "sslState=" << sslState_ << ", "
@@ -1112,24 +1249,15 @@ AsyncSSLSocket::performRead(void* buf, size_t buflen) {
               << "bytes: " << bytes << ", "
               << "error: " << error << ", "
               << "errno: " << errno << ", "
-              << "func: " << ERR_func_error_string(lastError) << ", "
-              << "reason: " << ERR_reason_error_string(lastError);
-      ERR_clear_error();
-      if (zero_return(error, bytes)) {
-        return bytes;
-      }
-      if (error != SSL_ERROR_SYSCALL) {
-        if ((unsigned long)lastError < 0x8000) {
-          errno = ENOSYS;
-        } else {
-          errno = lastError;
-        }
-      }
-      return READ_ERROR;
+              << "func: " << ERR_func_error_string(errError) << ", "
+              << "reason: " << ERR_reason_error_string(errError);
+      return ReadResult(
+          READ_ERROR,
+          folly::make_unique<SSLException>(error, errError, bytes, errno));
     }
   } else {
     appBytesReceived_ += bytes;
-    return bytes;
+    return ReadResult(bytes);
   }
 }
 
@@ -1157,19 +1285,52 @@ void AsyncSSLSocket::handleWrite() noexcept {
   AsyncSocket::handleWrite();
 }
 
-ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
-                                      uint32_t count,
-                                      WriteFlags flags,
-                                      uint32_t* countWritten,
-                                      uint32_t* partialWritten) {
+AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
+  if (error == SSL_ERROR_WANT_READ) {
+    // Even though we are attempting to write data, SSL_write() may
+    // need to read data if renegotiation is being performed.  We currently
+    // don't support this and just fail the write.
+    LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
+               << ", sslState=" << sslState_ << ", events=" << eventFlags_
+               << "): "
+               << "unsupported SSL renegotiation during write";
+    return WriteResult(
+        WRITE_ERROR,
+        folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+  } else {
+    if (zero_return(error, rc)) {
+      return WriteResult(0);
+    }
+    auto errError = ERR_get_error();
+    VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
+            << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
+            << "SSL error: " << error << ", errno: " << errno
+            << ", func: " << ERR_func_error_string(errError)
+            << ", reason: " << ERR_reason_error_string(errError);
+    return WriteResult(
+        WRITE_ERROR,
+        folly::make_unique<SSLException>(error, errError, rc, errno));
+  }
+}
+
+AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
+    const iovec* vec,
+    uint32_t count,
+    WriteFlags flags,
+    uint32_t* countWritten,
+    uint32_t* partialWritten) {
+  if (sslState_ == STATE_UNENCRYPTED) {
+    return AsyncSocket::performWrite(
+      vec, count, flags, countWritten, partialWritten);
+  }
   if (sslState_ != STATE_ESTABLISHED) {
     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
-               << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
+               << ", sslState=" << sslState_
+               << ", events=" << eventFlags_ << "): "
                << "TODO: AsyncSSLSocket currently does not support calling "
                << "write() before the handshake has fully completed";
-      errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
-                       SSL_EARLY_WRITE);
-      return -1;
+    return WriteResult(
+        WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
   }
 
   bool cork = isSet(flags, WriteFlags::CORK);
@@ -1203,7 +1364,6 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
     buf = ((const char*)v->iov_base) + offset;
 
     ssize_t bytes;
-    errno = 0;
     uint32_t buffersStolen = 0;
     if ((len < minWriteSize_) && ((i + 1) < count)) {
       // Combine this buffer with part or all of the next buffers in
@@ -1257,41 +1417,12 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
       if (error == SSL_ERROR_WANT_WRITE) {
         // The caller will register for write event if not already.
         *partialWritten = offset;
-        return totalWritten;
-      } else if (error == SSL_ERROR_WANT_READ) {
-        // TODO: Even though we are attempting to write data, SSL_write() may
-        // need to read data if renegotiation is being performed.  We currently
-        // don't support this and just fail the write.
-        LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
-                   << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
-                   << "unsupported SSL renegotiation during write",
-        errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
-                         SSL_INVALID_RENEGOTIATION);
-        ERR_clear_error();
-        return -1;
-      } else {
-        // TODO: Fix this code so that it can return a proper error message
-        // to the callback, rather than relying on AsyncSocket code which
-        // can't handle SSL errors.
-        long lastError = ERR_get_error();
-        VLOG(3) <<
-          "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
-                << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
-                << "SSL error: " << error << ", errno: " << errno
-                << ", func: " << ERR_func_error_string(lastError)
-                << ", reason: " << ERR_reason_error_string(lastError);
-        if (error != SSL_ERROR_SYSCALL) {
-          if ((unsigned long)lastError < 0x8000) {
-            errno = ENOSYS;
-          } else {
-            errno = lastError;
-          }
-        }
-        ERR_clear_error();
-        if (!zero_return(error, bytes)) {
-          return -1;
-        } // else fall through to below to correctly record totalWritten
+        return WriteResult(totalWritten);
       }
+      auto writeResult = interpretSSLError(bytes, error);
+      if (writeResult.writeReturn < 0) {
+        return writeResult;
+      } // else fall through to below to correctly record totalWritten
     }
 
     totalWritten += bytes;
@@ -1312,16 +1443,16 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
         v = &(vec[++i]);
       }
       *partialWritten = bytes;
-      return totalWritten;
+      return WriteResult(totalWritten);
     }
   }
 
-  return totalWritten;
+  return WriteResult(totalWritten);
 }
 
 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
                                       bool eor) {
-  if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
+  if (eor && trackEor_) {
     if (appEorByteNo_) {
       // cannot track for more than one app byte EOR
       CHECK(appEorByteNo_ == appBytesWritten_ + n);
@@ -1351,43 +1482,52 @@ int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
   return n;
 }
 
-void
-AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
+void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
     sslSocket->renegotiateAttempted_ = true;
   }
+  if (where & SSL_CB_READ_ALERT) {
+    const char* type = SSL_alert_type_string(ret);
+    if (type) {
+      const char* desc = SSL_alert_desc_string(ret);
+      sslSocket->alertsReceived_.emplace_back(
+          *type, StringPiece(desc, std::strlen(desc)));
+    }
+  }
 }
 
-int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
+int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   int ret;
   struct msghdr msg;
   struct iovec iov;
   int flags = 0;
-  AsyncSSLSocket *tsslSock;
+  AsyncSSLSockettsslSock;
 
-  iov.iov_base = const_cast<char *>(in);
+  iov.iov_base = const_cast<char*>(in);
   iov.iov_len = inl;
   memset(&msg, 0, sizeof(msg));
   msg.msg_iov = &iov;
   msg.msg_iovlen = 1;
 
-  tsslSock =
-    reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
-  if (tsslSock &&
-      tsslSock->minEorRawByteNo_ &&
+  auto appData = BIO_get_app_data(b);
+  CHECK(appData);
+
+  tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+  CHECK(tsslSock);
+
+  if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
     flags = MSG_EOR;
   }
 
-  errno = 0;
-  ret = sendmsg(b->num, &msg, flags);
+  ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
   BIO_clear_retry_flags(b);
   if (ret <= 0) {
     if (BIO_sock_should_retry(ret))
       BIO_set_retry_write(b);
   }
-  return(ret);
+  return ret;
 }
 
 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
@@ -1405,7 +1545,7 @@ int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
 
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
-    clientHelloInfo_.reset(new ClientHelloInfo());
+    clientHelloInfo_.reset(new ssl::ClientHelloInfo());
 }
 
 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
@@ -1414,17 +1554,19 @@ void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
   clientHelloInfo_->clientHelloBuf_.clear();
 }
 
-void
-AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
-    int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
-{
+void AsyncSSLSocket::clientHelloParsingCallback(int written,
+                                                int /* version */,
+                                                int contentType,
+                                                const void* buf,
+                                                size_t len,
+                                                SSL* ssl,
+                                                void* arg) {
   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
   if (written != 0) {
     sock->resetClientHelloParsing(ssl);
     return;
   }
   if (contentType != SSL3_RT_HANDSHAKE) {
-    sock->resetClientHelloParsing(ssl);
     return;
   }
   if (len == 0) {
@@ -1480,13 +1622,30 @@ AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
     if (cursor.totalLength() > 0) {
       uint16_t extensionsLength = cursor.readBE<uint16_t>();
       while (extensionsLength) {
+        ssl::TLSExtension extensionType =
+            static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
         sock->clientHelloInfo_->
-          clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
+          clientHelloExtensions_.push_back(extensionType);
         extensionsLength -= 2;
         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
         extensionsLength -= 2;
-        cursor.skip(extensionDataLength);
-        extensionsLength -= extensionDataLength;
+
+        if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
+          cursor.skip(2);
+          extensionDataLength -= 2;
+          while (extensionDataLength) {
+            ssl::HashAlgorithm hashAlg =
+                static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
+            ssl::SignatureAlgorithm sigAlg =
+                static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
+            extensionDataLength -= 2;
+            sock->clientHelloInfo_->
+              clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
+          }
+        } else {
+          cursor.skip(extensionDataLength);
+          extensionsLength -= extensionDataLength;
+        }
       }
     }
   } catch (std::out_of_range& e) {