apply clang-tidy modernize-use-override
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index 3b99a4c4a05121cbfdddd39336877d57b3ed5c8f..1379bf90bc91fe37c0285cda33e7522882b85871 100644 (file)
@@ -22,9 +22,6 @@
 #include <boost/noncopyable.hpp>
 #include <errno.h>
 #include <fcntl.h>
-#include <openssl/err.h>
-#include <openssl/asn1.h>
-#include <openssl/ssl.h>
 #include <sys/types.h>
 #include <chrono>
 
@@ -34,7 +31,6 @@
 #include <folly/io/Cursor.h>
 #include <folly/io/IOBuf.h>
 #include <folly/portability/OpenSSL.h>
-#include <folly/portability/Unistd.h>
 
 using folly::SocketAddress;
 using folly::SSLContext;
@@ -59,6 +55,7 @@ using folly::SSLContext;
 using namespace folly::ssl;
 using folly::ssl::OpenSSLUtils;
 
+
 // We have one single dummy SSL context so that we can implement attach
 // and detach methods in a thread safe fashion without modifying opnessl.
 static SSLContext *dummyCtx = nullptr;
@@ -177,14 +174,20 @@ void setup_SSL_CTX(SSL_CTX *ctx) {
 
 }
 
-BIO_METHOD sslBioMethod;
+// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky
+// thing is because we will be setting this BIO_METHOD* inside BIOs owned by
+// various SSL objects which may get callbacks even during teardown. We may
+// eventually try to fix this
+static BIO_METHOD* getSSLBioMethod() {
+  static auto const instance = OpenSSLUtils::newSocketBioMethod().release();
+  return instance;
+}
 
 void* initsslBioMethod(void) {
-  memcpy(&sslBioMethod, BIO_s_socket(), sizeof(sslBioMethod));
+  auto sslBioMethod = getSSLBioMethod();
   // override the bwrite method for MSG_EOR support
-  OpenSSLUtils::setCustomBioWriteMethod(
-      &sslBioMethod, AsyncSSLSocket::bioWrite);
-  OpenSSLUtils::setCustomBioReadMethod(&sslBioMethod, AsyncSSLSocket::bioRead);
+  OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite);
+  OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead);
 
   // Note that the sslBioMethod.type and sslBioMethod.name are not
   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
@@ -248,8 +251,8 @@ AsyncSSLSocket::AsyncSSLSocket(
     : AsyncSocket(std::move(oldAsyncSocket)),
       server_(server),
       ctx_(ctx),
-      handshakeTimeout_(this, oldAsyncSocket->getEventBase()),
-      connectionTimeout_(this, oldAsyncSocket->getEventBase()) {
+      handshakeTimeout_(this, AsyncSocket::getEventBase()),
+      connectionTimeout_(this, AsyncSocket::getEventBase()) {
   noTransparentTls_ = true;
   init();
   if (server) {
@@ -403,7 +406,7 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
     return 0;
   }
   BIO* next = BIO_next(b);
-  while (next != NULL) {
+  while (next != nullptr) {
     b = next;
     next = BIO_next(b);
   }
@@ -480,7 +483,6 @@ void AsyncSSLSocket::sslAccept(
   checkForImmediateRead();
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
 void AsyncSSLSocket::attachSSLContext(
   const std::shared_ptr<SSLContext>& ctx) {
 
@@ -503,10 +505,16 @@ void AsyncSSLSocket::attachSSLContext(
   // We need to update the initial_ctx if necessary
   auto sslCtx = ctx->getSSLCtx();
   SSL_CTX_up_ref(sslCtx);
-#ifndef OPENSSL_NO_TLSEXT
-  // note that detachSSLContext has already freed ssl_->initial_ctx
-  ssl_->initial_ctx = sslCtx;
-#endif
+
+  // The 'initial_ctx' inside an SSL* points to the context that it was created
+  // with, which is also where session callbacks and servername callbacks
+  // happen.
+  // When we switch to a different SSL_CTX, we want to update the initial_ctx as
+  // well so that any callbacks don't go to a different object
+  // NOTE: this will only work if we have access to ssl_ internals, so it may
+  // not work on
+  // OpenSSL version >= 1.1.0
+  OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx);
   // Detach sets the socket's context to the dummy context. Thus we must acquire
   // this lock.
   SpinLockGuard guard(dummyCtxLock);
@@ -521,14 +529,20 @@ void AsyncSSLSocket::detachSSLContext() {
   if (!ssl_) {
     return;
   }
-// Detach the initial_ctx as well.  Internally w/ OPENSSL_NO_TLSEXT
-// it is used for session info.  It will be reattached in attachSSLContext
-#ifndef OPENSSL_NO_TLSEXT
-  if (ssl_->initial_ctx) {
-    SSL_CTX_free(ssl_->initial_ctx);
-    ssl_->initial_ctx = nullptr;
+  // The 'initial_ctx' inside an SSL* points to the context that it was created
+  // with, which is also where session callbacks and servername callbacks
+  // happen.
+  // Detach the initial_ctx as well.  It will be reattached in attachSSLContext
+  // it is used for session info.
+  // NOTE: this will only work if we have access to ssl_ internals, so it may
+  // not work on
+  // OpenSSL version >= 1.1.0
+  SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_);
+  if (initialCtx) {
+    SSL_CTX_free(initialCtx);
+    OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr);
   }
-#endif
+
   SpinLockGuard guard(dummyCtxLock);
   if (nullptr == dummyCtx) {
     // We need to lazily initialize the dummy context so we don't
@@ -541,7 +555,6 @@ void AsyncSSLSocket::detachSSLContext() {
   // would not be thread safe.
   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
 }
-#endif
 
 #if FOLLY_OPENSSL_HAS_SNI
 void AsyncSSLSocket::switchServerSSLContext(
@@ -574,10 +587,8 @@ bool AsyncSSLSocket::isServerNameMatch() const {
     return false;
   }
 
-  if(!ss->tlsext_hostname) {
-    return false;
-  }
-  return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
+  auto tlsextHostname = SSL_SESSION_get0_hostname(ss);
+  return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname));
 }
 
 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
@@ -691,10 +702,10 @@ void AsyncSSLSocket::connect(
   assert(sslState_ == STATE_UNINIT);
   noTransparentTls_ = true;
   totalConnectTimeout_ = totalConnectTimeout;
-  AsyncSSLSocketConnector* connector =
-      new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count());
+  AsyncSSLSocketConnector* connector = new AsyncSSLSocketConnector(
+      this, callback, int(totalConnectTimeout.count()));
   AsyncSocket::connect(
-      connector, address, connectTimeout.count(), options, bindAddr);
+      connector, address, int(connectTimeout.count()), options, bindAddr);
 }
 
 bool AsyncSSLSocket::needsPeerVerification() const {
@@ -723,7 +734,7 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
 }
 
 bool AsyncSSLSocket::setupSSLBio() {
-  auto sslBio = BIO_new(&sslBioMethod);
+  auto sslBio = BIO_new(getSSLBioMethod());
 
   if (!sslBio) {
     return false;
@@ -911,7 +922,7 @@ int AsyncSSLSocket::getSSLVersion() const {
 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
   if (cert) {
-    int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
+    int nid = X509_get_signature_nid(cert);
     return OBJ_nid2ln(nid);
   }
   return nullptr;
@@ -965,14 +976,14 @@ bool AsyncSSLSocket::willBlock(int ret,
     // The timeout (if set) keeps running here
     return true;
 #endif
-  } else if (0
+  } else if ((0
 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
 #endif
 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
 #endif
-      ) {
+      )) {
     // Our custom openssl function has kicked off an async request to do
     // rsa/ecdsa private key operation.  When that call returns, a callback will
     // be invoked that will re-call handleAccept.
@@ -1327,7 +1338,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                << "): client intitiated SSL renegotiation not permitted";
     return ReadResult(
         READ_ERROR,
-        folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
   }
   if (bytes <= 0) {
     int error = SSL_get_error(ssl_, bytes);
@@ -1347,7 +1358,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                  << "): unsupported SSL renegotiation during read";
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+          std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
     } else {
       if (zero_return(error, bytes)) {
         return ReadResult(bytes);
@@ -1364,7 +1375,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
               << "reason: " << ERR_reason_error_string(errError);
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(error, errError, bytes, errno));
+          std::make_unique<SSLException>(error, errError, bytes, errno));
     }
   } else {
     appBytesReceived_ += bytes;
@@ -1407,7 +1418,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
                << "unsupported SSL renegotiation during write";
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
   } else {
     if (zero_return(error, rc)) {
       return WriteResult(0);
@@ -1420,7 +1431,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
             << ", reason: " << ERR_reason_error_string(errError);
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(error, errError, rc, errno));
+        std::make_unique<SSLException>(error, errError, rc, errno));
   }
 }
 
@@ -1441,7 +1452,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
                << "TODO: AsyncSSLSocket currently does not support calling "
                << "write() before the handshake has fully completed";
     return WriteResult(
-        WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
+        WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
   }
 
   // Declare a buffer used to hold small write requests.  It could point to a
@@ -1624,7 +1635,6 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   struct msghdr msg;
   struct iovec iov;
-  int flags = 0;
   AsyncSSLSocket* tsslSock;
 
   iov.iov_base = const_cast<char*>(in);
@@ -1639,23 +1649,28 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
   CHECK(tsslSock);
 
+  WriteFlags flags = WriteFlags::NONE;
   if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
-    flags = MSG_EOR;
+    flags |= WriteFlags::EOR;
   }
 
-#ifdef MSG_NOSIGNAL
-  flags |= MSG_NOSIGNAL;
-#endif
-
-#ifdef MSG_MORE
   if (tsslSock->corkCurrentWrite_) {
-    flags |= MSG_MORE;
+    flags |= WriteFlags::CORK;
+  }
+
+  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags);
+  msg.msg_controllen =
+      tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
+  CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+           msg.msg_controllen);
+  if (msg.msg_controllen != 0) {
+    msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+    tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
   }
-#endif
 
   auto result = tsslSock->sendSocketMessage(
-      OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
+      OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
   BIO_clear_retry_flags(b);
   if (!result.exception && result.writeReturn <= 0) {
     if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
@@ -1686,9 +1701,9 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
     queue.append(std::move(sslSock->preReceivedData_));
     queue.trimStart(len);
     sslSock->preReceivedData_ = queue.move();
-    return len;
+    return static_cast<int>(len);
   } else {
-    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
     if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
       BIO_set_retry_read(b);
     }