apply clang-tidy modernize-use-override
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index 26090b7ba770dfa7bd762f46f19aa440a1769a00..1379bf90bc91fe37c0285cda33e7522882b85871 100644 (file)
@@ -22,9 +22,6 @@
 #include <boost/noncopyable.hpp>
 #include <errno.h>
 #include <fcntl.h>
-#include <openssl/err.h>
-#include <openssl/asn1.h>
-#include <openssl/ssl.h>
 #include <sys/types.h>
 #include <chrono>
 
@@ -177,14 +174,20 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
 
 }
 
-BIO_METHOD sslBioMethod;
+// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky
+// thing is because we will be setting this BIO_METHOD* inside BIOs owned by
+// various SSL objects which may get callbacks even during teardown. We may
+// eventually try to fix this
+static BIO_METHOD* getSSLBioMethod() {
+  static auto const instance = OpenSSLUtils::newSocketBioMethod().release();
+  return instance;
+}
 
 void* initsslBioMethod(void) {
-  memcpy(&sslBioMethod, BIO_s_socket(), sizeof(sslBioMethod));
+  auto sslBioMethod = getSSLBioMethod();
   // override the bwrite method for MSG_EOR support
-  OpenSSLUtils::setCustomBioWriteMethod(
-      &sslBioMethod, AsyncSSLSocket::bioWrite);
-  OpenSSLUtils::setCustomBioReadMethod(&sslBioMethod, AsyncSSLSocket::bioRead);
+  OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite);
+  OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead);
 
   // Note that the sslBioMethod.type and sslBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
@@ -248,8 +251,8 @@ AsyncSSLSocket::AsyncSSLSocket(
     : AsyncSocket(std::move(oldAsyncSocket)),
       server_(server),
       ctx_(ctx),
-      handshakeTimeout_(this, oldAsyncSocket->getEventBase()),
-      connectionTimeout_(this, oldAsyncSocket->getEventBase()) {
+      handshakeTimeout_(this, AsyncSocket::getEventBase()),
+      connectionTimeout_(this, AsyncSocket::getEventBase()) {
   noTransparentTls_ = true;
   init();
   if (server) {
@@ -403,7 +406,7 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
     return 0;
   }
   BIO* next = BIO_next(b);
-  while (next != NULL) {
+  while (next != nullptr) {
     b = next;
     next = BIO_next(b);
   }
@@ -480,7 +483,6 @@ void AsyncSSLSocket::sslAccept(
   checkForImmediateRead();
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
 void AsyncSSLSocket::attachSSLContext(
   const std::shared_ptr<SSLContext>& ctx) {
 
@@ -503,10 +505,16 @@ void AsyncSSLSocket::attachSSLContext(
   // We need to update the initial_ctx if necessary
   auto sslCtx = ctx->getSSLCtx();
   SSL_CTX_up_ref(sslCtx);
-#ifndef OPENSSL_NO_TLSEXT
-  // note that detachSSLContext has already freed ssl_->initial_ctx
-  ssl_->initial_ctx = sslCtx;
-#endif
+
+  // The 'initial_ctx' inside an SSL* points to the context that it was created
+  // with, which is also where session callbacks and servername callbacks
+  // happen.
+  // When we switch to a different SSL_CTX, we want to update the initial_ctx as
+  // well so that any callbacks don't go to a different object
+  // NOTE: this will only work if we have access to ssl_ internals, so it may
+  // not work on
+  // OpenSSL version >= 1.1.0
+  OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx);
   // Detach sets the socket's context to the dummy context. Thus we must acquire
   // this lock.
   SpinLockGuard guard(dummyCtxLock);
@@ -521,14 +529,20 @@ void AsyncSSLSocket::detachSSLContext() {
   if (!ssl_) {
     return;
   }
-// Detach the initial_ctx as well.  Internally w/ OPENSSL_NO_TLSEXT
-// it is used for session info.  It will be reattached in attachSSLContext
-#ifndef OPENSSL_NO_TLSEXT
-  if (ssl_->initial_ctx) {
-    SSL_CTX_free(ssl_->initial_ctx);
-    ssl_->initial_ctx = nullptr;
+  // The 'initial_ctx' inside an SSL* points to the context that it was created
+  // with, which is also where session callbacks and servername callbacks
+  // happen.
+  // Detach the initial_ctx as well.  It will be reattached in attachSSLContext
+  // it is used for session info.
+  // NOTE: this will only work if we have access to ssl_ internals, so it may
+  // not work on
+  // OpenSSL version >= 1.1.0
+  SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_);
+  if (initialCtx) {
+    SSL_CTX_free(initialCtx);
+    OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr);
   }
-#endif
+
   SpinLockGuard guard(dummyCtxLock);
   if (nullptr == dummyCtx) {
     // We need to lazily initialize the dummy context so we don't
@@ -541,7 +555,6 @@ void AsyncSSLSocket::detachSSLContext() {
   // would not be thread safe.
   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
 }
-#endif
 
 #if FOLLY_OPENSSL_HAS_SNI
 void AsyncSSLSocket::switchServerSSLContext(
@@ -574,10 +587,8 @@ bool AsyncSSLSocket::isServerNameMatch() const {
     return false;
   }
 
-  if(!ss->tlsext_hostname) {
-    return false;
-  }
-  return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
+  auto tlsextHostname = SSL_SESSION_get0_hostname(ss);
+  return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname));
 }
 
 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
@@ -691,10 +702,10 @@ void AsyncSSLSocket::connect(
   assert(sslState_ == STATE_UNINIT);
   noTransparentTls_ = true;
   totalConnectTimeout_ = totalConnectTimeout;
-  AsyncSSLSocketConnector* connector =
-      new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count());
+  AsyncSSLSocketConnector* connector = new AsyncSSLSocketConnector(
+      this, callback, int(totalConnectTimeout.count()));
   AsyncSocket::connect(
-      connector, address, connectTimeout.count(), options, bindAddr);
+      connector, address, int(connectTimeout.count()), options, bindAddr);
 }
 
 bool AsyncSSLSocket::needsPeerVerification() const {
@@ -723,7 +734,7 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
 }
 
 bool AsyncSSLSocket::setupSSLBio() {
-  auto sslBio = BIO_new(&sslBioMethod);
+  auto sslBio = BIO_new(getSSLBioMethod());
 
   if (!sslBio) {
     return false;
@@ -911,7 +922,7 @@ int AsyncSSLSocket::getSSLVersion() const {
 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
   if (cert) {
-    int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
+    int nid = X509_get_signature_nid(cert);
     return OBJ_nid2ln(nid);
   }
   return nullptr;
@@ -965,14 +976,14 @@ bool AsyncSSLSocket::willBlock(int ret,
     // The timeout (if set) keeps running here
     return true;
 #endif
-  } else if (0
+  } else if ((0
 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
 #endif
 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
 #endif
-      ) {
+      )) {
     // Our custom openssl function has kicked off an async request to do
     // rsa/ecdsa private key operation.  When that call returns, a callback will
     // be invoked that will re-call handleAccept.
@@ -1327,7 +1338,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                << "): client intitiated SSL renegotiation not permitted";
     return ReadResult(
         READ_ERROR,
-        folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
   }
   if (bytes <= 0) {
     int error = SSL_get_error(ssl_, bytes);
@@ -1347,7 +1358,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                  << "): unsupported SSL renegotiation during read";
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+          std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
     } else {
       if (zero_return(error, bytes)) {
         return ReadResult(bytes);
@@ -1364,7 +1375,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
               << "reason: " << ERR_reason_error_string(errError);
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(error, errError, bytes, errno));
+          std::make_unique<SSLException>(error, errError, bytes, errno));
     }
   } else {
     appBytesReceived_ += bytes;
@@ -1407,7 +1418,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
                << "unsupported SSL renegotiation during write";
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
   } else {
     if (zero_return(error, rc)) {
       return WriteResult(0);
@@ -1420,7 +1431,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
             << ", reason: " << ERR_reason_error_string(errError);
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(error, errError, rc, errno));
+        std::make_unique<SSLException>(error, errError, rc, errno));
   }
 }
 
@@ -1441,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
                << "TODO: AsyncSSLSocket currently does not support calling "
                << "write() before the handshake has fully completed";
     return WriteResult(
-        WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
+        WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
   }
 
   // Declare a buffer used to hold small write requests.  It could point to a
@@ -1690,9 +1701,9 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
     queue.append(std::move(sslSock->preReceivedData_));
     queue.trimStart(len);
     sslSock->preReceivedData_ = queue.move();
-    return len;
+    return static_cast<int>(len);
   } else {
-    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
     if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
       BIO_set_retry_read(b);
     }