X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.cpp;h=8c7ab10c3076d9104768936cf93aa5e2f36fb49f;hb=c5b9338ec192ed46907905d173b65d158a038842;hp=242e0b26bad7665b8c12289b7f393dddf1684359;hpb=7c8f000f43d6f89406e939ddc59da687ba040b78;p=folly.git diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 242e0b26..8c7ab10c 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2016 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,18 +22,15 @@ #include #include #include -#include -#include -#include #include #include #include #include #include -#include #include -#include +#include +#include using folly::SocketAddress; using folly::SSLContext; @@ -54,8 +51,11 @@ using folly::AsyncSocketException; using folly::AsyncSSLSocket; using folly::Optional; using folly::SSLContext; +// For OpenSSL portability API +using namespace folly::ssl; using folly::ssl::OpenSSLUtils; + // We have one single dummy SSL context so that we can implement attach // and detach methods in a thread safe fashion without modifying opnessl. static SSLContext *dummyCtx = nullptr; @@ -103,14 +103,15 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, timeoutLeft = timeout_ - (curTime - startTime_); if (timeoutLeft <= 0) { - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "SSL connect timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, + folly::sformat("SSL connect timed out after {}ms", timeout_)); fail(ex); delete this; return; } } - sslSocket_->sslConn(this, timeoutLeft); + sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft)); } void connectErr(const AsyncSocketException& ex) noexcept override { @@ -173,15 +174,22 @@ void setup_SSL_CTX(SSL_CTX *ctx) { } -BIO_METHOD sslWriteBioMethod; +// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky +// thing is because we will be setting this BIO_METHOD* inside BIOs owned by +// various SSL objects which may get callbacks even during teardown. We may +// eventually try to fix this +static BIO_METHOD* getSSLBioMethod() { + static auto const instance = OpenSSLUtils::newSocketBioMethod().release(); + return instance; +} -void* initsslWriteBioMethod(void) { - memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod)); +void* initsslBioMethod(void) { + auto sslBioMethod = getSSLBioMethod(); // override the bwrite method for MSG_EOR support - OpenSSLUtils::setCustomBioWriteMethod( - &sslWriteBioMethod, AsyncSSLSocket::bioWrite); + OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite); + OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead); - // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not + // Note that the sslBioMethod.type and sslBioMethod.name are not // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and // then have specific handlings. The sslWriteBioWrite should be compatible // with the one in openssl. @@ -213,14 +221,39 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, /** * Create a server/client AsyncSSLSocket */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, bool server, - bool deferSecurityNegotiation) : - AsyncSocket(evb, fd), - server_(server), - ctx_(ctx), - handshakeTimeout_(this, evb), - connectionTimeout_(this, evb) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(evb, fd), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { + noTransparentTls_ = true; + init(); + if (server) { + SSL_CTX_set_info_callback( + ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback); + } + if (deferSecurityNegotiation) { + sslState_ = STATE_UNENCRYPTED; + } +} + +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(std::move(oldAsyncSocket)), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, AsyncSocket::getEventBase()), + connectionTimeout_(this, AsyncSocket::getEventBase()) { + noTransparentTls_ = true; init(); if (server) { SSL_CTX_set_info_callback(ctx_->getSSLCtx(), @@ -231,7 +264,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, } } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Create a client AsyncSSLSocket and allow tlsext_hostname * to be sent in Client Hello. @@ -248,14 +281,16 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, * Create a client AsyncSSLSocket from an already connected fd * and allow tlsext_hostname to be sent in Client Hello. */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, - const std::string& serverName, - bool deferSecurityNegotiation) : - AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation) + : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { tlsextHostname_ = serverName; } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI AsyncSSLSocket::~AsyncSSLSocket() { VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this @@ -267,8 +302,8 @@ AsyncSSLSocket::~AsyncSSLSocket() { void AsyncSSLSocket::init() { // Do this here to ensure we initialize this once before any use of // AsyncSSLSocket instances and not as part of library load. - static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod(); - (void)sslWriteBioMethodInitializer; + static const auto sslBioMethodInitializer = initsslBioMethod(); + (void)sslBioMethodInitializer; setup_SSL_CTX(ctx_->getSSLCtx()); } @@ -353,13 +388,9 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept { return ""; } -bool AsyncSSLSocket::isEorTrackingEnabled() const { - return trackEor_; -} - void AsyncSSLSocket::setEorTracking(bool track) { - if (trackEor_ != track) { - trackEor_ = track; + if (isEorTrackingEnabled() != track) { + AsyncSocket::setEorTracking(track); appEorByteNo_ = 0; minEorRawByteNo_ = 0; } @@ -411,14 +442,13 @@ void AsyncSSLSocket::invalidState(HandshakeCB* callback) { callback->handshakeErr(this, ex); } - // Check the socket state not the ssl state here. - if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) { - failHandshake(__func__, ex); - } + failHandshake(__func__, ex); } -void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, - const SSLContext::SSLVerifyPeerEnum& verifyPeer) { +void AsyncSSLSocket::sslAccept( + HandshakeCB* callback, + std::chrono::milliseconds timeout, + const SSLContext::SSLVerifyPeerEnum& verifyPeer) { DestructorGuard dg(this); assert(eventBase_->isInEventBaseThread()); verifyPeer_ = verifyPeer; @@ -443,15 +473,16 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, sslState_ = STATE_ACCEPTING; handshakeCallback_ = callback; - if (timeout > 0) { + if (timeout > std::chrono::milliseconds::zero()) { handshakeTimeout_.scheduleTimeout(timeout); } /* register for a read operation (waiting for CLIENT HELLO) */ updateEventRegistration(EventHandler::READ, EventHandler::WRITE); + + checkForImmediateRead(); } -#if OPENSSL_VERSION_NUMBER >= 0x009080bfL void AsyncSSLSocket::attachSSLContext( const std::shared_ptr& ctx) { @@ -464,24 +495,54 @@ void AsyncSSLSocket::attachSSLContext( DCHECK(ctx->getSSLCtx()); ctx_ = ctx; + // It's possible this could be attached before ssl_ is set up + if (!ssl_) { + return; + } + // In order to call attachSSLContext, detachSSLContext must have been - // previously called which sets the socket's context to the dummy - // context. Thus we must acquire this lock. + // previously called. + // We need to update the initial_ctx if necessary + auto sslCtx = ctx->getSSLCtx(); + SSL_CTX_up_ref(sslCtx); + + // The 'initial_ctx' inside an SSL* points to the context that it was created + // with, which is also where session callbacks and servername callbacks + // happen. + // When we switch to a different SSL_CTX, we want to update the initial_ctx as + // well so that any callbacks don't go to a different object + // NOTE: this will only work if we have access to ssl_ internals, so it may + // not work on + // OpenSSL version >= 1.1.0 + OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx); + // Detach sets the socket's context to the dummy context. Thus we must acquire + // this lock. SpinLockGuard guard(dummyCtxLock); - SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx()); + SSL_set_SSL_CTX(ssl_, sslCtx); } void AsyncSSLSocket::detachSSLContext() { DCHECK(ctx_); ctx_.reset(); - // We aren't using the initial_ctx for now, and it can introduce race - // conditions in the destructor of the SSL object. -#ifndef OPENSSL_NO_TLSEXT - if (ssl_->initial_ctx) { - SSL_CTX_free(ssl_->initial_ctx); - ssl_->initial_ctx = nullptr; + // It's possible for this to be called before ssl_ has been + // set up + if (!ssl_) { + return; } -#endif + // The 'initial_ctx' inside an SSL* points to the context that it was created + // with, which is also where session callbacks and servername callbacks + // happen. + // Detach the initial_ctx as well. It will be reattached in attachSSLContext + // it is used for session info. + // NOTE: this will only work if we have access to ssl_ internals, so it may + // not work on + // OpenSSL version >= 1.1.0 + SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_); + if (initialCtx) { + SSL_CTX_free(initialCtx); + OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr); + } + SpinLockGuard guard(dummyCtxLock); if (nullptr == dummyCtx) { // We need to lazily initialize the dummy context so we don't @@ -494,9 +555,8 @@ void AsyncSSLSocket::detachSSLContext() { // would not be thread safe. SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx()); } -#endif -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI void AsyncSSLSocket::switchServerSSLContext( const std::shared_ptr& handshakeCtx) { CHECK(server_); @@ -527,22 +587,20 @@ bool AsyncSSLSocket::isServerNameMatch() const { return false; } - if(!ss->tlsext_hostname) { - return false; - } - return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true); + auto tlsextHostname = SSL_SESSION_get0_hostname(ss); + return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname)); } void AsyncSSLSocket::setServerName(std::string serverName) noexcept { tlsextHostname_ = std::move(serverName); } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI -void AsyncSSLSocket::timeoutExpired() noexcept { +void AsyncSSLSocket::timeoutExpired( + std::chrono::milliseconds timeout) noexcept { if (state_ == StateEnum::ESTABLISHED && - (sslState_ == STATE_CACHE_LOOKUP || - sslState_ == STATE_ASYNC_PENDING)) { + (sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING)) { sslState_ = STATE_ERROR; // We are expecting a callback in restartSSLAccept. The cache lookup // and rsa-call necessarily have pointers to this ssl socket, so delay @@ -557,9 +615,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept { assert(state_ == StateEnum::ESTABLISHED && (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING)); DestructorGuard dg(this); - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - (sslState_ == STATE_CONNECTING) ? - "SSL connect timed out" : "SSL accept timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, + folly::sformat( + "SSL {} timed out after {}ms", + (sslState_ == STATE_CONNECTING) ? "connect" : "accept", + timeout.count())); failHandshake(__func__, ex); } } @@ -619,18 +680,41 @@ void AsyncSSLSocket::cacheLocalPeerAddr() { } } -void AsyncSSLSocket::connect(ConnectCallback* callback, - const folly::SocketAddress& address, - int timeout, - const OptionMap &options, - const folly::SocketAddress& bindAddr) - noexcept { +void AsyncSSLSocket::connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + int timeout, + const OptionMap& options, + const folly::SocketAddress& bindAddr) noexcept { + auto timeoutChrono = std::chrono::milliseconds(timeout); + connect(callback, address, timeoutChrono, timeoutChrono, options, bindAddr); +} + +void AsyncSSLSocket::connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + std::chrono::milliseconds connectTimeout, + std::chrono::milliseconds totalConnectTimeout, + const OptionMap& options, + const folly::SocketAddress& bindAddr) noexcept { assert(!server_); assert(state_ == StateEnum::UNINIT); assert(sslState_ == STATE_UNINIT); - AsyncSSLSocketConnector *connector = - new AsyncSSLSocketConnector(this, callback, timeout); - AsyncSocket::connect(connector, address, timeout, options, bindAddr); + noTransparentTls_ = true; + totalConnectTimeout_ = totalConnectTimeout; + AsyncSSLSocketConnector* connector = + new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count()); + AsyncSocket::connect( + connector, address, connectTimeout.count(), options, bindAddr); +} + +bool AsyncSSLSocket::needsPeerVerification() const { + if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) { + return ctx_->needsPeerVerification(); + } + return ( + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY || + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); } void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { @@ -650,20 +734,22 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { } bool AsyncSSLSocket::setupSSLBio() { - auto wb = BIO_new(&sslWriteBioMethod); + auto sslBio = BIO_new(getSSLBioMethod()); - if (!wb) { + if (!sslBio) { return false; } - OpenSSLUtils::setBioAppData(wb, this); - OpenSSLUtils::setBioFd(wb, fd_, BIO_NOCLOSE); - SSL_set_bio(ssl_, wb, wb); + OpenSSLUtils::setBioAppData(sslBio, this); + OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE); + SSL_set_bio(ssl_, sslBio, sslBio); return true; } -void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, - const SSLContext::SSLVerifyPeerEnum& verifyPeer) { +void AsyncSSLSocket::sslConn( + HandshakeCB* callback, + std::chrono::milliseconds timeout, + const SSLContext::SSLVerifyPeerEnum& verifyPeer) { DestructorGuard dg(this); assert(eventBase_->isInEventBaseThread()); @@ -711,7 +797,7 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, SSL_SESSION_free(sslSession_); sslSession_ = nullptr; } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI if (tlsextHostname_.size()) { SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str()); } @@ -729,7 +815,7 @@ void AsyncSSLSocket::startSSLConnect() { handshakeStartTime_ = std::chrono::steady_clock::now(); // Make end time at least >= start time. handshakeEndTime_ = handshakeStartTime_; - if (handshakeConnectTimeout_ > 0) { + if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) { handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_); } handleConnect(); @@ -751,7 +837,8 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) { sslSession_ = session; if (!takeOwnership && session != nullptr) { // Increment the reference count - CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + // This API exists in BoringSSL and OpenSSL 1.1.0 + SSL_SESSION_up_ref(session); } } @@ -771,7 +858,7 @@ bool AsyncSSLSocket::getSelectedNextProtocolNoThrow( SSLContext::NextProtocolType* protoType) const { *protoName = nullptr; *protoLen = 0; -#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_ALPN SSL_get0_alpn_selected(ssl_, protoName, protoLen); if (*protoLen > 0) { if (protoType) { @@ -835,7 +922,7 @@ int AsyncSSLSocket::getSSLVersion() const { const char *AsyncSSLSocket::getSSLCertSigAlgName() const { X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; if (cert) { - int nid = OBJ_obj2nid(cert->sig_alg->algorithm); + int nid = X509_get_signature_nid(cert); return OBJ_nid2ln(nid); } return nullptr; @@ -933,6 +1020,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept { // the socket to become readable again. if (ssl_ != nullptr && SSL_pending(ssl_) > 0) { AsyncSocket::handleRead(); + } else { + AsyncSocket::checkForImmediateRead(); } } @@ -955,8 +1044,8 @@ AsyncSSLSocket::restartSSLAccept() } if (sslState_ == STATE_ERROR) { // go straight to fail if timeout expired during lookup - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "SSL accept timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, "SSL accept timed out"); failHandshake(__func__, ex); return; } @@ -1159,9 +1248,8 @@ void AsyncSSLSocket::scheduleConnectTimeout() { assert(connectCallback_ == nullptr); // We use a different connect timeout here than the handshake timeout, so // that we can disambiguate the 2 timers. - int timeout = connectTimeout_.count(); - if (timeout > 0) { - if (!connectionTimeout_.scheduleTimeout(timeout)) { + if (connectTimeout_.count() > 0) { + if (!connectionTimeout_.scheduleTimeout(connectTimeout_)) { throw AsyncSocketException( AsyncSocketException::INTERNAL_ERROR, withAddr("failed to schedule AsyncSSLSocket connect timeout")); @@ -1234,9 +1322,9 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { return AsyncSocket::performRead(buf, buflen, offset); } - ssize_t bytes = 0; + int bytes = 0; if (!isBufferMovable_) { - bytes = SSL_read(ssl_, *buf, *buflen); + bytes = SSL_read(ssl_, *buf, int(*buflen)); } #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP else { @@ -1275,7 +1363,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { if (zero_return(error, bytes)) { return ReadResult(bytes); } - long errError = ERR_get_error(); + auto errError = ERR_get_error(); VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " << "state=" << state_ << ", " << "sslState=" << sslState_ << ", " @@ -1425,8 +1513,13 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( uint32_t nextIndex = i + buffersStolen + 1; bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len, minWriteSize_ - len); - memcpy(combinedBuf + len, vec[nextIndex].iov_base, - bytesStolenFromNextBuffer); + if (bytesStolenFromNextBuffer > 0) { + assert(vec[nextIndex].iov_base != nullptr); + ::memcpy( + combinedBuf + len, + vec[nextIndex].iov_base, + bytesStolenFromNextBuffer); + } len += bytesStolenFromNextBuffer; if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) { // couldn't steal the whole buffer @@ -1451,17 +1544,17 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( bytes = eorAwareSSLWrite( ssl_, sslWriteBuf, - len, + int(len), (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count)); if (bytes <= 0) { - int error = SSL_get_error(ssl_, bytes); + int error = SSL_get_error(ssl_, int(bytes)); if (error == SSL_ERROR_WANT_WRITE) { // The caller will register for write event if not already. - *partialWritten = offset; + *partialWritten = uint32_t(offset); return WriteResult(totalWritten); } - auto writeResult = interpretSSLError(bytes, error); + auto writeResult = interpretSSLError(int(bytes), error); if (writeResult.writeReturn < 0) { return writeResult; } // else fall through to below to correctly record totalWritten @@ -1484,7 +1577,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( (*countWritten)++; v = &(vec[++i]); } - *partialWritten = bytes; + *partialWritten = uint32_t(bytes); return WriteResult(totalWritten); } } @@ -1494,7 +1587,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, bool eor) { - if (eor && trackEor_) { + if (eor && isEorTrackingEnabled()) { if (appEorByteNo_) { // cannot track for more than one app byte EOR CHECK(appEorByteNo_ == appBytesWritten_ + n); @@ -1542,11 +1635,10 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { struct msghdr msg; struct iovec iov; - int flags = 0; AsyncSSLSocket* tsslSock; iov.iov_base = const_cast(in); - iov.iov_len = inl; + iov.iov_len = size_t(inl); memset(&msg, 0, sizeof(msg)); msg.msg_iov = &iov; msg.msg_iovlen = 1; @@ -1557,30 +1649,66 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { tsslSock = reinterpret_cast(appData); CHECK(tsslSock); - if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ && + WriteFlags flags = WriteFlags::NONE; + if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ && tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { - flags = MSG_EOR; + flags |= WriteFlags::EOR; } -#ifdef MSG_NOSIGNAL - flags |= MSG_NOSIGNAL; -#endif - -#ifdef MSG_MORE if (tsslSock->corkCurrentWrite_) { - flags |= MSG_MORE; + flags |= WriteFlags::CORK; + } + + int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags); + msg.msg_controllen = + tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags); + CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize, + msg.msg_controllen); + if (msg.msg_controllen != 0) { + msg.msg_control = reinterpret_cast(alloca(msg.msg_controllen)); + tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control); } -#endif auto result = tsslSock->sendSocketMessage( - OpenSSLUtils::getBioFd(b, nullptr), &msg, flags); + OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags); BIO_clear_retry_flags(b); if (!result.exception && result.writeReturn <= 0) { - if (OpenSSLUtils::getBioShouldRetryWrite(result.writeReturn)) { + if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) { BIO_set_retry_write(b); } } - return result.writeReturn; + return int(result.writeReturn); +} + +int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) { + if (!out) { + return 0; + } + BIO_clear_retry_flags(b); + + auto appData = OpenSSLUtils::getBioAppData(b); + CHECK(appData); + auto sslSock = reinterpret_cast(appData); + + if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) { + VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock + << ", reading pre-received data"; + + Cursor cursor(sslSock->preReceivedData_.get()); + auto len = cursor.pullAtMost(out, outl); + + IOBufQueue queue; + queue.append(std::move(sslSock->preReceivedData_)); + queue.trimStart(len); + sslSock->preReceivedData_ = queue.move(); + return len; + } else { + auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0); + if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) { + BIO_set_retry_read(b); + } + return result; + } } int AsyncSSLSocket::sslVerifyCallback( @@ -1697,12 +1825,20 @@ void AsyncSSLSocket::clientHelloParsingCallback(int written, sock->clientHelloInfo_-> clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg); } + } else if (extensionType == ssl::TLSExtension::SUPPORTED_VERSIONS) { + cursor.skip(1); + extensionDataLength -= 1; + while (extensionDataLength) { + sock->clientHelloInfo_->clientHelloSupportedVersions_.push_back( + cursor.readBE()); + extensionDataLength -= 2; + } } else { cursor.skip(extensionDataLength); } } } - } catch (std::out_of_range& e) { + } catch (std::out_of_range&) { // we'll use what we found and cleanup below. VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): " << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock; @@ -1790,6 +1926,13 @@ std::string AsyncSSLSocket::getSSLClientSigAlgs() const { return sigAlgs; } +std::string AsyncSSLSocket::getSSLClientSupportedVersions() const { + if (!parseClientHello_) { + return ""; + } + return folly::join(":", clientHelloInfo_->clientHelloSupportedVersions_); +} + std::string AsyncSSLSocket::getSSLAlertsReceived() const { std::string ret; @@ -1803,6 +1946,14 @@ std::string AsyncSSLSocket::getSSLAlertsReceived() const { return ret; } +void AsyncSSLSocket::setSSLCertVerificationAlert(std::string alert) { + sslVerificationAlert_ = std::move(alert); +} + +std::string AsyncSSLSocket::getSSLCertVerificationAlert() const { + return sslVerificationAlert_; +} + void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const { char ciphersBuffer[1024]; ciphersBuffer[0] = '\0';