Enable EOR flag configuration for folly::AsyncSocket.
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index a61e761770d493d6df17fb29a2f2301f68b2e64b..c19a93263eaf6f5085cab95086f5689d66154425 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@
 #include <boost/noncopyable.hpp>
 #include <errno.h>
 #include <fcntl.h>
-#include <openssl/err.h>
 #include <openssl/asn1.h>
+#include <openssl/err.h>
 #include <openssl/ssl.h>
 #include <sys/types.h>
 #include <chrono>
@@ -176,15 +176,16 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
 
 }
 
-BIO_METHOD sslWriteBioMethod;
+BIO_METHOD sslBioMethod;
 
-void* initsslWriteBioMethod(void) {
-  memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
+void* initsslBioMethod(void) {
+  memcpy(&sslBioMethod, BIO_s_socket(), sizeof(sslBioMethod));
   // override the bwrite method for MSG_EOR support
   OpenSSLUtils::setCustomBioWriteMethod(
-      &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
+      &sslBioMethod, AsyncSSLSocket::bioWrite);
+  OpenSSLUtils::setCustomBioReadMethod(&sslBioMethod, AsyncSSLSocket::bioRead);
 
-  // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
+  // Note that the sslBioMethod.type and sslBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
   // then have specific handlings. The sslWriteBioWrite should be compatible
   // with the one in openssl.
@@ -224,6 +225,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
     ctx_(ctx),
     handshakeTimeout_(this, evb),
     connectionTimeout_(this, evb) {
+  noTransparentTls_ = true;
   init();
   if (server) {
     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
@@ -234,7 +236,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
   }
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 /**
  * Create a client AsyncSSLSocket and allow tlsext_hostname
  * to be sent in Client Hello.
@@ -258,7 +260,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
   tlsextHostname_ = serverName;
 }
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
 AsyncSSLSocket::~AsyncSSLSocket() {
   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
@@ -270,8 +272,8 @@ AsyncSSLSocket::~AsyncSSLSocket() {
 void AsyncSSLSocket::init() {
   // Do this here to ensure we initialize this once before any use of
   // AsyncSSLSocket instances and not as part of library load.
-  static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
-  (void)sslWriteBioMethodInitializer;
+  static const auto sslBioMethodInitializer = initsslBioMethod();
+  (void)sslBioMethodInitializer;
 
   setup_SSL_CTX(ctx_->getSSLCtx());
 }
@@ -356,13 +358,9 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
   return "";
 }
 
-bool AsyncSSLSocket::isEorTrackingEnabled() const {
-  return trackEor_;
-}
-
 void AsyncSSLSocket::setEorTracking(bool track) {
-  if (trackEor_ != track) {
-    trackEor_ = track;
+  if (isEorTrackingEnabled() != track) {
+    AsyncSocket::setEorTracking(track);
     appEorByteNo_ = 0;
     minEorRawByteNo_ = 0;
   }
@@ -414,10 +412,7 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
     callback->handshakeErr(this, ex);
   }
 
-  // Check the socket state not the ssl state here.
-  if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
-    failHandshake(__func__, ex);
-  }
+  failHandshake(__func__, ex);
 }
 
 void AsyncSSLSocket::sslAccept(
@@ -454,6 +449,10 @@ void AsyncSSLSocket::sslAccept(
 
   /* register for a read operation (waiting for CLIENT HELLO) */
   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
+
+  if (preReceivedData_) {
+    handleRead();
+  }
 }
 
 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
@@ -519,7 +518,7 @@ void AsyncSSLSocket::detachSSLContext() {
 }
 #endif
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 void AsyncSSLSocket::switchServerSSLContext(
   const std::shared_ptr<SSLContext>& handshakeCtx) {
   CHECK(server_);
@@ -560,7 +559,7 @@ void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
   tlsextHostname_ = std::move(serverName);
 }
 
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
 void AsyncSSLSocket::timeoutExpired() noexcept {
   if (state_ == StateEnum::ESTABLISHED &&
@@ -651,6 +650,7 @@ void AsyncSSLSocket::connect(ConnectCallback* callback,
   assert(!server_);
   assert(state_ == StateEnum::UNINIT);
   assert(sslState_ == STATE_UNINIT);
+  noTransparentTls_ = true;
   AsyncSSLSocketConnector *connector =
     new AsyncSSLSocketConnector(this, callback, timeout);
   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
@@ -673,15 +673,15 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
 }
 
 bool AsyncSSLSocket::setupSSLBio() {
-  auto wb = BIO_new(&sslWriteBioMethod);
+  auto sslBio = BIO_new(&sslBioMethod);
 
-  if (!wb) {
+  if (!sslBio) {
     return false;
   }
 
-  OpenSSLUtils::setBioAppData(wb, this);
-  OpenSSLUtils::setBioFd(wb, fd_, BIO_NOCLOSE);
-  SSL_set_bio(ssl_, wb, wb);
+  OpenSSLUtils::setBioAppData(sslBio, this);
+  OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE);
+  SSL_set_bio(ssl_, sslBio, sslBio);
   return true;
 }
 
@@ -736,7 +736,7 @@ void AsyncSSLSocket::sslConn(
     SSL_SESSION_free(sslSession_);
     sslSession_ = nullptr;
   }
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   if (tlsextHostname_.size()) {
     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
   }
@@ -797,7 +797,7 @@ bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
     SSLContext::NextProtocolType* protoType) const {
   *protoName = nullptr;
   *protoLen = 0;
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_ALPN
   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
   if (*protoLen > 0) {
     if (protoType) {
@@ -1028,6 +1028,7 @@ AsyncSSLSocket::handleAccept() noexcept {
     SSL_set_msg_callback_arg(ssl_, this);
   }
 
+  clearOpenSSLErrors();
   int ret = SSL_accept(ssl_);
   if (ret <= 0) {
     int sslError;
@@ -1077,6 +1078,18 @@ AsyncSSLSocket::handleAccept() noexcept {
   AsyncSocket::handleInitialReadWrite();
 }
 
+void AsyncSSLSocket::clearOpenSSLErrors() {
+  // Normally clearing out the error before calling into an openssl method
+  // is a bad idea. However there might be other code that we don't control
+  // calling into openssl in the same thread, which doesn't use openssl
+  // correctly. We want to safe-guard ourselves from that code.
+  // However touching the ERR stack each and every time has a cost of taking
+  // a lock, so we only do this when we've opted in.
+  if (clearOpenSSLErrors_) {
+    ERR_clear_error();
+  }
+}
+
 void
 AsyncSSLSocket::handleConnect() noexcept {
   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
@@ -1092,6 +1105,7 @@ AsyncSSLSocket::handleConnect() noexcept {
       sslState_ == STATE_CONNECTING);
   assert(ssl_);
 
+  clearOpenSSLErrors();
   auto originalState = state_;
   int ret = SSL_connect(ssl_);
   if (ret <= 0) {
@@ -1259,6 +1273,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
     return AsyncSocket::performRead(buf, buflen, offset);
   }
 
+  clearOpenSSLErrors();
   int bytes = 0;
   if (!isBufferMovable_) {
     bytes = SSL_read(ssl_, *buf, int(*buflen));
@@ -1519,7 +1534,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
 
 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
                                       bool eor) {
-  if (eor && trackEor_) {
+  if (eor && isEorTrackingEnabled()) {
     if (appEorByteNo_) {
       // cannot track for more than one app byte EOR
       CHECK(appEorByteNo_ == appBytesWritten_ + n);
@@ -1582,7 +1597,7 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
   CHECK(tsslSock);
 
-  if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
+  if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
     flags = MSG_EOR;
   }
@@ -1608,6 +1623,37 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   return int(result.writeReturn);
 }
 
+int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
+  if (!out) {
+    return 0;
+  }
+  BIO_clear_retry_flags(b);
+
+  auto appData = OpenSSLUtils::getBioAppData(b);
+  CHECK(appData);
+  auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+
+  if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
+    VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
+            << ", reading pre-received data";
+
+    Cursor cursor(sslSock->preReceivedData_.get());
+    auto len = cursor.pullAtMost(out, outl);
+
+    IOBufQueue queue;
+    queue.append(std::move(sslSock->preReceivedData_));
+    queue.trimStart(len);
+    sslSock->preReceivedData_ = queue.move();
+    return len;
+  } else {
+    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
+      BIO_set_retry_read(b);
+    }
+    return result;
+  }
+}
+
 int AsyncSSLSocket::sslVerifyCallback(
     int preverifyOk,
     X509_STORE_CTX* x509Ctx) {
@@ -1622,6 +1668,12 @@ int AsyncSSLSocket::sslVerifyCallback(
     preverifyOk;
 }
 
+void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
+  CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
+  CHECK(!preReceivedData_);
+  preReceivedData_ = std::move(data);
+}
+
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
     clientHelloInfo_.reset(new ssl::ClientHelloInfo());