Enable EOR flag configuration for folly::AsyncSocket.
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index fefd9d018a7ed8d1cc285072ff51c2c95b9597a2..c19a93263eaf6f5085cab95086f5689d66154425 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@
 #include <boost/noncopyable.hpp>
 #include <errno.h>
 #include <fcntl.h>
-#include <openssl/err.h>
 #include <openssl/asn1.h>
+#include <openssl/err.h>
 #include <openssl/ssl.h>
 #include <sys/types.h>
 #include <chrono>
@@ -31,8 +31,9 @@
 #include <folly/Bits.h>
 #include <folly/SocketAddress.h>
 #include <folly/SpinLock.h>
-#include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
+#include <folly/io/IOBuf.h>
+#include <folly/portability/OpenSSL.h>
 #include <folly/portability/Unistd.h>
 
 using folly::SocketAddress;
@@ -54,6 +55,8 @@ using folly::AsyncSocketException;
 using folly::AsyncSSLSocket;
 using folly::Optional;
 using folly::SSLContext;
+// For OpenSSL portability API
+using namespace folly::ssl;
 using folly::ssl::OpenSSLUtils;
 
 // We have one single dummy SSL context so that we can implement attach
@@ -110,7 +113,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
         return;
       }
     }
-    sslSocket_->sslConn(this, timeoutLeft);
+    sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft));
   }
 
   void connectErr(const AsyncSocketException& ex) noexcept override {
@@ -173,15 +176,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
 
 }
 
-BIO_METHOD sslWriteBioMethod;
+BIO_METHOD sslBioMethod;
 
-void* initsslWriteBioMethod(void) {
-  memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
+void* initsslBioMethod(void) {
+  memcpy(&sslBioMethod, BIO_s_socket(), sizeof(sslBioMethod));
   // override the bwrite method for MSG_EOR support
   OpenSSLUtils::setCustomBioWriteMethod(
-      &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
+      &sslBioMethod, AsyncSSLSocket::bioWrite);
+  OpenSSLUtils::setCustomBioReadMethod(&sslBioMethod, AsyncSSLSocket::bioRead);
 
-  // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
+  // Note that the sslBioMethod.type and sslBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
   // then have specific handlings. The sslWriteBioWrite should be compatible
   // with the one in openssl.
@@ -221,6 +225,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
     ctx_(ctx),
     handshakeTimeout_(this, evb),
     connectionTimeout_(this, evb) {
+  noTransparentTls_ = true;
   init();
   if (server) {
     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
@@ -231,7 +236,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
   }
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 /**
  * Create a client AsyncSSLSocket and allow tlsext_hostname
  * to be sent in Client Hello.
@@ -255,7 +260,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
   tlsextHostname_ = serverName;
 }
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
 AsyncSSLSocket::~AsyncSSLSocket() {
   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
@@ -267,8 +272,8 @@ AsyncSSLSocket::~AsyncSSLSocket() {
 void AsyncSSLSocket::init() {
   // Do this here to ensure we initialize this once before any use of
   // AsyncSSLSocket instances and not as part of library load.
-  static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
-  (void)sslWriteBioMethodInitializer;
+  static const auto sslBioMethodInitializer = initsslBioMethod();
+  (void)sslBioMethodInitializer;
 
   setup_SSL_CTX(ctx_->getSSLCtx());
 }
@@ -353,13 +358,9 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
   return "";
 }
 
-bool AsyncSSLSocket::isEorTrackingEnabled() const {
-  return trackEor_;
-}
-
 void AsyncSSLSocket::setEorTracking(bool track) {
-  if (trackEor_ != track) {
-    trackEor_ = track;
+  if (isEorTrackingEnabled() != track) {
+    AsyncSocket::setEorTracking(track);
     appEorByteNo_ = 0;
     minEorRawByteNo_ = 0;
   }
@@ -411,14 +412,13 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
     callback->handshakeErr(this, ex);
   }
 
-  // Check the socket state not the ssl state here.
-  if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
-    failHandshake(__func__, ex);
-  }
+  failHandshake(__func__, ex);
 }
 
-void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
-      const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
+void AsyncSSLSocket::sslAccept(
+    HandshakeCB* callback,
+    std::chrono::milliseconds timeout,
+    const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
   assert(eventBase_->isInEventBaseThread());
   verifyPeer_ = verifyPeer;
@@ -443,12 +443,16 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
   sslState_ = STATE_ACCEPTING;
   handshakeCallback_ = callback;
 
-  if (timeout > 0) {
+  if (timeout > std::chrono::milliseconds::zero()) {
     handshakeTimeout_.scheduleTimeout(timeout);
   }
 
   /* register for a read operation (waiting for CLIENT HELLO) */
   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
+
+  if (preReceivedData_) {
+    handleRead();
+  }
 }
 
 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
@@ -473,8 +477,8 @@ void AsyncSSLSocket::attachSSLContext(
   // previously called.
   // We need to update the initial_ctx if necessary
   auto sslCtx = ctx->getSSLCtx();
+  SSL_CTX_up_ref(sslCtx);
 #ifndef OPENSSL_NO_TLSEXT
-  CRYPTO_add(&sslCtx->references, 1, CRYPTO_LOCK_SSL_CTX);
   // note that detachSSLContext has already freed ssl_->initial_ctx
   ssl_->initial_ctx = sslCtx;
 #endif
@@ -514,7 +518,7 @@ void AsyncSSLSocket::detachSSLContext() {
 }
 #endif
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 void AsyncSSLSocket::switchServerSSLContext(
   const std::shared_ptr<SSLContext>& handshakeCtx) {
   CHECK(server_);
@@ -555,7 +559,7 @@ void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
   tlsextHostname_ = std::move(serverName);
 }
 
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
 void AsyncSSLSocket::timeoutExpired() noexcept {
   if (state_ == StateEnum::ESTABLISHED &&
@@ -646,6 +650,7 @@ void AsyncSSLSocket::connect(ConnectCallback* callback,
   assert(!server_);
   assert(state_ == StateEnum::UNINIT);
   assert(sslState_ == STATE_UNINIT);
+  noTransparentTls_ = true;
   AsyncSSLSocketConnector *connector =
     new AsyncSSLSocketConnector(this, callback, timeout);
   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
@@ -668,20 +673,22 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
 }
 
 bool AsyncSSLSocket::setupSSLBio() {
-  auto wb = BIO_new(&sslWriteBioMethod);
+  auto sslBio = BIO_new(&sslBioMethod);
 
-  if (!wb) {
+  if (!sslBio) {
     return false;
   }
 
-  OpenSSLUtils::setBioAppData(wb, this);
-  OpenSSLUtils::setBioFd(wb, fd_, BIO_NOCLOSE);
-  SSL_set_bio(ssl_, wb, wb);
+  OpenSSLUtils::setBioAppData(sslBio, this);
+  OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE);
+  SSL_set_bio(ssl_, sslBio, sslBio);
   return true;
 }
 
-void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
-        const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
+void AsyncSSLSocket::sslConn(
+    HandshakeCB* callback,
+    std::chrono::milliseconds timeout,
+    const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
   assert(eventBase_->isInEventBaseThread());
 
@@ -729,7 +736,7 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
     SSL_SESSION_free(sslSession_);
     sslSession_ = nullptr;
   }
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   if (tlsextHostname_.size()) {
     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
   }
@@ -747,9 +754,8 @@ void AsyncSSLSocket::startSSLConnect() {
   handshakeStartTime_ = std::chrono::steady_clock::now();
   // Make end time at least >= start time.
   handshakeEndTime_ = handshakeStartTime_;
-  if (handshakeConnectTimeout_ > 0) {
-    handshakeTimeout_.scheduleTimeout(
-        std::chrono::milliseconds(handshakeConnectTimeout_));
+  if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) {
+    handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
   }
   handleConnect();
 }
@@ -770,7 +776,8 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
   sslSession_ = session;
   if (!takeOwnership && session != nullptr) {
     // Increment the reference count
-    CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+    // This API exists in BoringSSL and OpenSSL 1.1.0
+    SSL_SESSION_up_ref(session);
   }
 }
 
@@ -790,7 +797,7 @@ bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
     SSLContext::NextProtocolType* protoType) const {
   *protoName = nullptr;
   *protoLen = 0;
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
   if (*protoLen > 0) {
     if (protoType) {
@@ -1021,6 +1028,7 @@ AsyncSSLSocket::handleAccept() noexcept {
     SSL_set_msg_callback_arg(ssl_, this);
   }
 
+  clearOpenSSLErrors();
   int ret = SSL_accept(ssl_);
   if (ret <= 0) {
     int sslError;
@@ -1070,6 +1078,18 @@ AsyncSSLSocket::handleAccept() noexcept {
   AsyncSocket::handleInitialReadWrite();
 }
 
+void AsyncSSLSocket::clearOpenSSLErrors() {
+  // Normally clearing out the error before calling into an openssl method
+  // is a bad idea. However there might be other code that we don't control
+  // calling into openssl in the same thread, which doesn't use openssl
+  // correctly. We want to safe-guard ourselves from that code.
+  // However touching the ERR stack each and every time has a cost of taking
+  // a lock, so we only do this when we've opted in.
+  if (clearOpenSSLErrors_) {
+    ERR_clear_error();
+  }
+}
+
 void
 AsyncSSLSocket::handleConnect() noexcept {
   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
@@ -1085,6 +1105,7 @@ AsyncSSLSocket::handleConnect() noexcept {
       sslState_ == STATE_CONNECTING);
   assert(ssl_);
 
+  clearOpenSSLErrors();
   auto originalState = state_;
   int ret = SSL_connect(ssl_);
   if (ret <= 0) {
@@ -1252,6 +1273,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
     return AsyncSocket::performRead(buf, buflen, offset);
   }
 
+  clearOpenSSLErrors();
   int bytes = 0;
   if (!isBufferMovable_) {
     bytes = SSL_read(ssl_, *buf, int(*buflen));
@@ -1512,7 +1534,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
 
 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
                                       bool eor) {
-  if (eor && trackEor_) {
+  if (eor && isEorTrackingEnabled()) {
     if (appEorByteNo_) {
       // cannot track for more than one app byte EOR
       CHECK(appEorByteNo_ == appBytesWritten_ + n);
@@ -1575,7 +1597,7 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
   CHECK(tsslSock);
 
-  if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
+  if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
     flags = MSG_EOR;
   }
@@ -1601,6 +1623,37 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   return int(result.writeReturn);
 }
 
+int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
+  if (!out) {
+    return 0;
+  }
+  BIO_clear_retry_flags(b);
+
+  auto appData = OpenSSLUtils::getBioAppData(b);
+  CHECK(appData);
+  auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+
+  if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
+    VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
+            << ", reading pre-received data";
+
+    Cursor cursor(sslSock->preReceivedData_.get());
+    auto len = cursor.pullAtMost(out, outl);
+
+    IOBufQueue queue;
+    queue.append(std::move(sslSock->preReceivedData_));
+    queue.trimStart(len);
+    sslSock->preReceivedData_ = queue.move();
+    return len;
+  } else {
+    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
+      BIO_set_retry_read(b);
+    }
+    return result;
+  }
+}
+
 int AsyncSSLSocket::sslVerifyCallback(
     int preverifyOk,
     X509_STORE_CTX* x509Ctx) {
@@ -1615,6 +1668,12 @@ int AsyncSSLSocket::sslVerifyCallback(
     preverifyOk;
 }
 
+void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
+  CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
+  CHECK(!preReceivedData_);
+  preReceivedData_ = std::move(data);
+}
+
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
     clientHelloInfo_.reset(new ssl::ClientHelloInfo());