X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.cpp;h=8c7ab10c3076d9104768936cf93aa5e2f36fb49f;hb=c5b9338ec192ed46907905d173b65d158a038842;hp=8c60902055be92282d2023ef024f3e79eea50054;hpb=193eb597fb6e9cd1bea0269ab04ca30750183785;p=folly.git diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 8c609020..8c7ab10c 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2015 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,25 +17,20 @@ #include #include +#include #include #include #include -#include -#include -#include -#include -#include #include -#include -#include #include #include #include #include -#include #include +#include +#include using folly::SocketAddress; using folly::SSLContext; @@ -55,16 +50,17 @@ using folly::AsyncSocket; using folly::AsyncSocketException; using folly::AsyncSSLSocket; using folly::Optional; +using folly::SSLContext; +// For OpenSSL portability API +using namespace folly::ssl; +using folly::ssl::OpenSSLUtils; + // We have one single dummy SSL context so that we can implement attach // and detach methods in a thread safe fashion without modifying opnessl. static SSLContext *dummyCtx = nullptr; static SpinLock dummyCtxLock; -// Numbers chosen as to not collide with functions in ssl.h -const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90; -const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91; - // If given min write size is less than this, buffer will be allocated on // stack, otherwise it is allocated on heap const size_t MAX_STACK_BUF_SIZE = 2048; @@ -84,8 +80,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, int64_t startTime_; protected: - virtual ~AsyncSSLSocketConnector() { - } + ~AsyncSSLSocketConnector() override {} public: AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket, @@ -98,7 +93,7 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, std::chrono::steady_clock::now().time_since_epoch()).count()) { } - virtual void connectSuccess() noexcept { + void connectSuccess() noexcept override { VLOG(7) << "client socket connected"; int64_t timeoutLeft = 0; @@ -108,23 +103,24 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, timeoutLeft = timeout_ - (curTime - startTime_); if (timeoutLeft <= 0) { - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "SSL connect timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, + folly::sformat("SSL connect timed out after {}ms", timeout_)); fail(ex); delete this; return; } } - sslSocket_->sslConn(this, timeoutLeft); + sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft)); } - virtual void connectErr(const AsyncSocketException& ex) noexcept { - LOG(ERROR) << "TCP connect failed: " << ex.what(); + void connectErr(const AsyncSocketException& ex) noexcept override { + VLOG(1) << "TCP connect failed: " << ex.what(); fail(ex); delete this; } - virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept { + void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override { VLOG(7) << "client handshake success"; if (callback_) { callback_->connectSuccess(); @@ -132,9 +128,9 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, delete this; } - virtual void handshakeErr(AsyncSSLSocket *socket, - const AsyncSocketException& ex) noexcept { - LOG(ERROR) << "client handshakeErr: " << ex.what(); + void handshakeErr(AsyncSSLSocket* /* socket */, + const AsyncSocketException& ex) noexcept override { + VLOG(1) << "client handshakeErr: " << ex.what(); fail(ex); delete this; } @@ -156,57 +152,6 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, } }; -// XXX: implement an equivalent to corking for platforms with TCP_NOPUSH? -#ifdef TCP_CORK // Linux-only -/** - * Utility class that corks a TCP socket upon construction or uncorks - * the socket upon destruction - */ -class CorkGuard : private boost::noncopyable { - public: - CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked): - fd_(fd), haveMore_(haveMore), corked_(corked) { - if (*corked_) { - // socket is already corked; nothing to do - return; - } - if (multipleWrites || haveMore) { - // We are performing multiple writes in this performWrite() call, - // and/or there are more calls to performWrite() that will be invoked - // later, so enable corking - int flag = 1; - setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag)); - *corked_ = true; - } - } - - ~CorkGuard() { - if (haveMore_) { - // more data to come; don't uncork yet - return; - } - if (!*corked_) { - // socket isn't corked; nothing to do - return; - } - - int flag = 0; - setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag)); - *corked_ = false; - } - - private: - int fd_; - bool haveMore_; - bool* corked_; -}; -#else -class CorkGuard : private boost::noncopyable { - public: - CorkGuard(int, bool, bool, bool*) {} -}; -#endif - void setup_SSL_CTX(SSL_CTX *ctx) { #ifdef SSL_MODE_RELEASE_BUFFERS SSL_CTX_set_mode(ctx, @@ -229,16 +174,24 @@ void setup_SSL_CTX(SSL_CTX *ctx) { } -BIO_METHOD eorAwareBioMethod; +// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky +// thing is because we will be setting this BIO_METHOD* inside BIOs owned by +// various SSL objects which may get callbacks even during teardown. We may +// eventually try to fix this +static BIO_METHOD* getSSLBioMethod() { + static auto const instance = OpenSSLUtils::newSocketBioMethod().release(); + return instance; +} -void* initEorBioMethod(void) { - memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod)); +void* initsslBioMethod(void) { + auto sslBioMethod = getSSLBioMethod(); // override the bwrite method for MSG_EOR support - eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite; + OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite); + OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead); - // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not + // Note that the sslBioMethod.type and sslBioMethod.name are not // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and - // then have specific handlings. The eorAwareBioWrite should be compatible + // then have specific handlings. The sslWriteBioWrite should be compatible // with the one in openssl. // Return something here to enable AsyncSSLSocket to call this method using @@ -250,12 +203,6 @@ void* initEorBioMethod(void) { namespace folly { -SSLException::SSLException(int sslError, int errno_copy): - AsyncSocketException( - AsyncSocketException::SSL_ERROR, - ERR_error_string(sslError, msg_), - sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {} - /** * Create a client AsyncSSLSocket */ @@ -263,7 +210,8 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, EventBase* evb, bool deferSecurityNegotiation) : AsyncSocket(evb), ctx_(ctx), - handshakeTimeout_(this, evb) { + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { init(); if (deferSecurityNegotiation) { sslState_ = STATE_UNENCRYPTED; @@ -273,13 +221,39 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, /** * Create a server/client AsyncSSLSocket */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, bool server, - bool deferSecurityNegotiation) : - AsyncSocket(evb, fd), - server_(server), - ctx_(ctx), - handshakeTimeout_(this, evb) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(evb, fd), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { + noTransparentTls_ = true; + init(); + if (server) { + SSL_CTX_set_info_callback( + ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback); + } + if (deferSecurityNegotiation) { + sslState_ = STATE_UNENCRYPTED; + } +} + +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(std::move(oldAsyncSocket)), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, AsyncSocket::getEventBase()), + connectionTimeout_(this, AsyncSocket::getEventBase()) { + noTransparentTls_ = true; init(); if (server) { SSL_CTX_set_info_callback(ctx_->getSSLCtx(), @@ -290,7 +264,7 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, } } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Create a client AsyncSSLSocket and allow tlsext_hostname * to be sent in Client Hello. @@ -307,14 +281,16 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, * Create a client AsyncSSLSocket from an already connected fd * and allow tlsext_hostname to be sent in Client Hello. */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, - const std::string& serverName, - bool deferSecurityNegotiation) : - AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation) + : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { tlsextHostname_ = serverName; } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI AsyncSSLSocket::~AsyncSSLSocket() { VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this @@ -326,7 +302,9 @@ AsyncSSLSocket::~AsyncSSLSocket() { void AsyncSSLSocket::init() { // Do this here to ensure we initialize this once before any use of // AsyncSSLSocket instances and not as part of library load. - static const auto eorAwareBioMethodInitializer = initEorBioMethod(); + static const auto sslBioMethodInitializer = initsslBioMethod(); + (void)sslBioMethodInitializer; + setup_SSL_CTX(ctx_->getSSLCtx()); } @@ -355,13 +333,10 @@ void AsyncSSLSocket::closeNow() { DestructorGuard dg(this); - if (handshakeCallback_) { - AsyncSocketException ex(AsyncSocketException::END_OF_FILE, - "SSL connection closed locally"); - HandshakeCB* callback = handshakeCallback_; - handshakeCallback_ = nullptr; - callback->handshakeErr(this, ex); - } + invokeHandshakeErr( + AsyncSocketException( + AsyncSocketException::END_OF_FILE, + "SSL connection closed locally")); if (ssl_ != nullptr) { SSL_free(ssl_); @@ -404,42 +379,37 @@ bool AsyncSSLSocket::connecting() const { sslState_ == STATE_CONNECTING)))); } -bool AsyncSSLSocket::isEorTrackingEnabled() const { - const BIO *wb = SSL_get_wbio(ssl_); - return wb && wb->method == &eorAwareBioMethod; +std::string AsyncSSLSocket::getApplicationProtocol() noexcept { + const unsigned char* protoName = nullptr; + unsigned protoLength; + if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) { + return std::string(reinterpret_cast(protoName), protoLength); + } + return ""; } void AsyncSSLSocket::setEorTracking(bool track) { - BIO *wb = SSL_get_wbio(ssl_); - if (!wb) { - throw AsyncSocketException(AsyncSocketException::INVALID_STATE, - "setting EOR tracking without an initialized " - "BIO"); - } - - if (track) { - if (wb->method != &eorAwareBioMethod) { - // only do this if we didn't - wb->method = &eorAwareBioMethod; - BIO_set_app_data(wb, this); - appEorByteNo_ = 0; - minEorRawByteNo_ = 0; - } - } else if (wb->method == &eorAwareBioMethod) { - wb->method = BIO_s_socket(); - BIO_set_app_data(wb, nullptr); + if (isEorTrackingEnabled() != track) { + AsyncSocket::setEorTracking(track); appEorByteNo_ = 0; minEorRawByteNo_ = 0; - } else { - CHECK(wb->method == BIO_s_socket()); } } size_t AsyncSSLSocket::getRawBytesWritten() const { + // The bio(s) in the write path are in a chain + // each bio flushes to the next and finally written into the socket + // to get the rawBytesWritten on the socket, + // get the write bytes of the last bio BIO *b; if (!ssl_ || !(b = SSL_get_wbio(ssl_))) { return 0; } + BIO* next = BIO_next(b); + while (next != NULL) { + b = next; + next = BIO_next(b); + } return BIO_number_written(b); } @@ -457,28 +427,28 @@ size_t AsyncSSLSocket::getRawBytesReceived() const { void AsyncSSLSocket::invalidState(HandshakeCB* callback) { LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", " - << "events=" << eventFlags_ << ", server=" << short(server_) << "): " - << "sslAccept/Connect() called in invalid " - << "state, handshake callback " << handshakeCallback_ << ", new callback " - << callback; + << "events=" << eventFlags_ << ", server=" << short(server_) + << "): " << "sslAccept/Connect() called in invalid " + << "state, handshake callback " << handshakeCallback_ + << ", new callback " << callback; assert(!handshakeTimeout_.isScheduled()); sslState_ = STATE_ERROR; AsyncSocketException ex(AsyncSocketException::INVALID_STATE, "sslAccept() called with socket in invalid state"); + handshakeEndTime_ = std::chrono::steady_clock::now(); if (callback) { callback->handshakeErr(this, ex); } - // Check the socket state not the ssl state here. - if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) { - failHandshake(__func__, ex); - } + failHandshake(__func__, ex); } -void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, - const SSLContext::SSLVerifyPeerEnum& verifyPeer) { +void AsyncSSLSocket::sslAccept( + HandshakeCB* callback, + std::chrono::milliseconds timeout, + const SSLContext::SSLVerifyPeerEnum& verifyPeer) { DestructorGuard dg(this); assert(eventBase_->isInEventBaseThread()); verifyPeer_ = verifyPeer; @@ -490,18 +460,29 @@ void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout, return invalidState(callback); } + // Cache local and remote socket addresses to keep them available + // after socket file descriptor is closed. + if (cacheAddrOnFailure_ && -1 != getFd()) { + cacheLocalPeerAddr(); + } + + handshakeStartTime_ = std::chrono::steady_clock::now(); + // Make end time at least >= start time. + handshakeEndTime_ = handshakeStartTime_; + sslState_ = STATE_ACCEPTING; handshakeCallback_ = callback; - if (timeout > 0) { + if (timeout > std::chrono::milliseconds::zero()) { handshakeTimeout_.scheduleTimeout(timeout); } /* register for a read operation (waiting for CLIENT HELLO) */ updateEventRegistration(EventHandler::READ, EventHandler::WRITE); + + checkForImmediateRead(); } -#if OPENSSL_VERSION_NUMBER >= 0x009080bfL void AsyncSSLSocket::attachSSLContext( const std::shared_ptr& ctx) { @@ -514,24 +495,54 @@ void AsyncSSLSocket::attachSSLContext( DCHECK(ctx->getSSLCtx()); ctx_ = ctx; + // It's possible this could be attached before ssl_ is set up + if (!ssl_) { + return; + } + // In order to call attachSSLContext, detachSSLContext must have been - // previously called which sets the socket's context to the dummy - // context. Thus we must acquire this lock. + // previously called. + // We need to update the initial_ctx if necessary + auto sslCtx = ctx->getSSLCtx(); + SSL_CTX_up_ref(sslCtx); + + // The 'initial_ctx' inside an SSL* points to the context that it was created + // with, which is also where session callbacks and servername callbacks + // happen. + // When we switch to a different SSL_CTX, we want to update the initial_ctx as + // well so that any callbacks don't go to a different object + // NOTE: this will only work if we have access to ssl_ internals, so it may + // not work on + // OpenSSL version >= 1.1.0 + OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx); + // Detach sets the socket's context to the dummy context. Thus we must acquire + // this lock. SpinLockGuard guard(dummyCtxLock); - SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx()); + SSL_set_SSL_CTX(ssl_, sslCtx); } void AsyncSSLSocket::detachSSLContext() { DCHECK(ctx_); ctx_.reset(); - // We aren't using the initial_ctx for now, and it can introduce race - // conditions in the destructor of the SSL object. -#ifndef OPENSSL_NO_TLSEXT - if (ssl_->initial_ctx) { - SSL_CTX_free(ssl_->initial_ctx); - ssl_->initial_ctx = nullptr; + // It's possible for this to be called before ssl_ has been + // set up + if (!ssl_) { + return; } -#endif + // The 'initial_ctx' inside an SSL* points to the context that it was created + // with, which is also where session callbacks and servername callbacks + // happen. + // Detach the initial_ctx as well. It will be reattached in attachSSLContext + // it is used for session info. + // NOTE: this will only work if we have access to ssl_ internals, so it may + // not work on + // OpenSSL version >= 1.1.0 + SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_); + if (initialCtx) { + SSL_CTX_free(initialCtx); + OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr); + } + SpinLockGuard guard(dummyCtxLock); if (nullptr == dummyCtx) { // We need to lazily initialize the dummy context so we don't @@ -544,9 +555,8 @@ void AsyncSSLSocket::detachSSLContext() { // would not be thread safe. SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx()); } -#endif -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI void AsyncSSLSocket::switchServerSSLContext( const std::shared_ptr& handshakeCtx) { CHECK(server_); @@ -577,49 +587,48 @@ bool AsyncSSLSocket::isServerNameMatch() const { return false; } - if(!ss->tlsext_hostname) { - return false; - } - return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true); + auto tlsextHostname = SSL_SESSION_get0_hostname(ss); + return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname)); } void AsyncSSLSocket::setServerName(std::string serverName) noexcept { tlsextHostname_ = std::move(serverName); } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI -void AsyncSSLSocket::timeoutExpired() noexcept { +void AsyncSSLSocket::timeoutExpired( + std::chrono::milliseconds timeout) noexcept { if (state_ == StateEnum::ESTABLISHED && - (sslState_ == STATE_CACHE_LOOKUP || - sslState_ == STATE_RSA_ASYNC_PENDING)) { + (sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING)) { sslState_ = STATE_ERROR; // We are expecting a callback in restartSSLAccept. The cache lookup // and rsa-call necessarily have pointers to this ssl socket, so delay // the cleanup until he calls us back. + } else if (state_ == StateEnum::CONNECTING) { + assert(sslState_ == STATE_CONNECTING); + DestructorGuard dg(this); + AsyncSocketException ex(AsyncSocketException::TIMED_OUT, + "Fallback connect timed out during TFO"); + failHandshake(__func__, ex); } else { assert(state_ == StateEnum::ESTABLISHED && (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING)); DestructorGuard dg(this); - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - (sslState_ == STATE_CONNECTING) ? - "SSL connect timed out" : "SSL accept timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, + folly::sformat( + "SSL {} timed out after {}ms", + (sslState_ == STATE_CONNECTING) ? "connect" : "accept", + timeout.count())); failHandshake(__func__, ex); } } -int AsyncSSLSocket::sslExDataIndex_ = -1; -std::mutex AsyncSSLSocket::mutex_; - int AsyncSSLSocket::getSSLExDataIndex() { - if (sslExDataIndex_ < 0) { - std::lock_guard g(mutex_); - if (sslExDataIndex_ < 0) { - sslExDataIndex_ = SSL_get_ex_new_index(0, - (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr); - } - } - return sslExDataIndex_; + static auto index = SSL_get_ex_new_index( + 0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr); + return index; } AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) { @@ -627,23 +636,27 @@ AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) { getSSLExDataIndex())); } -void AsyncSSLSocket::failHandshake(const char* fn, - const AsyncSocketException& ex) { +void AsyncSSLSocket::failHandshake(const char* /* fn */, + const AsyncSocketException& ex) { startFail(); - if (handshakeTimeout_.isScheduled()) { handshakeTimeout_.cancelTimeout(); } + invokeHandshakeErr(ex); + finishFail(); +} + +void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) { + handshakeEndTime_ = std::chrono::steady_clock::now(); if (handshakeCallback_ != nullptr) { HandshakeCB* callback = handshakeCallback_; handshakeCallback_ = nullptr; callback->handshakeErr(this, ex); } - - finishFail(); } void AsyncSSLSocket::invokeHandshakeCB() { + handshakeEndTime_ = std::chrono::steady_clock::now(); if (handshakeTimeout_.isScheduled()) { handshakeTimeout_.cancelTimeout(); } @@ -654,18 +667,54 @@ void AsyncSSLSocket::invokeHandshakeCB() { } } -void AsyncSSLSocket::connect(ConnectCallback* callback, - const folly::SocketAddress& address, - int timeout, - const OptionMap &options, - const folly::SocketAddress& bindAddr) - noexcept { +void AsyncSSLSocket::cacheLocalPeerAddr() { + SocketAddress address; + try { + getLocalAddress(&address); + getPeerAddress(&address); + } catch (const std::system_error& e) { + // The handle can be still valid while the connection is already closed. + if (e.code() != std::error_code(ENOTCONN, std::system_category())) { + throw; + } + } +} + +void AsyncSSLSocket::connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + int timeout, + const OptionMap& options, + const folly::SocketAddress& bindAddr) noexcept { + auto timeoutChrono = std::chrono::milliseconds(timeout); + connect(callback, address, timeoutChrono, timeoutChrono, options, bindAddr); +} + +void AsyncSSLSocket::connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + std::chrono::milliseconds connectTimeout, + std::chrono::milliseconds totalConnectTimeout, + const OptionMap& options, + const folly::SocketAddress& bindAddr) noexcept { assert(!server_); assert(state_ == StateEnum::UNINIT); assert(sslState_ == STATE_UNINIT); - AsyncSSLSocketConnector *connector = - new AsyncSSLSocketConnector(this, callback, timeout); - AsyncSocket::connect(connector, address, timeout, options, bindAddr); + noTransparentTls_ = true; + totalConnectTimeout_ = totalConnectTimeout; + AsyncSSLSocketConnector* connector = + new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count()); + AsyncSocket::connect( + connector, address, connectTimeout.count(), options, bindAddr); +} + +bool AsyncSSLSocket::needsPeerVerification() const { + if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) { + return ctx_->needsPeerVerification(); + } + return ( + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY || + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); } void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { @@ -684,11 +733,32 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { } } -void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, - const SSLContext::SSLVerifyPeerEnum& verifyPeer) { +bool AsyncSSLSocket::setupSSLBio() { + auto sslBio = BIO_new(getSSLBioMethod()); + + if (!sslBio) { + return false; + } + + OpenSSLUtils::setBioAppData(sslBio, this); + OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE); + SSL_set_bio(ssl_, sslBio, sslBio); + return true; +} + +void AsyncSSLSocket::sslConn( + HandshakeCB* callback, + std::chrono::milliseconds timeout, + const SSLContext::SSLVerifyPeerEnum& verifyPeer) { DestructorGuard dg(this); assert(eventBase_->isInEventBaseThread()); + // Cache local and remote socket addresses to keep them available + // after socket file descriptor is closed. + if (cacheAddrOnFailure_ && -1 != getFd()) { + cacheLocalPeerAddr(); + } + verifyPeer_ = verifyPeer; // Make sure we're in the uninitialized state @@ -712,15 +782,22 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, return failHandshake(__func__, ex); } + if (!setupSSLBio()) { + sslState_ = STATE_ERROR; + AsyncSocketException ex( + AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio"); + return failHandshake(__func__, ex); + } + applyVerificationOptions(ssl_); - SSL_set_fd(ssl_, fd_); if (sslSession_ != nullptr) { + sessionResumptionAttempted_ = true; SSL_set_session(ssl_, sslSession_); SSL_SESSION_free(sslSession_); sslSession_ = nullptr; } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI if (tlsextHostname_.size()) { SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str()); } @@ -728,10 +805,19 @@ void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout, SSL_set_ex_data(ssl_, getSSLExDataIndex(), this); - if (timeout > 0) { - handshakeTimeout_.scheduleTimeout(timeout); - } + handshakeConnectTimeout_ = timeout; + startSSLConnect(); +} +// This could be called multiple times, during normal ssl connections +// and after TFO fallback. +void AsyncSSLSocket::startSSLConnect() { + handshakeStartTime_ = std::chrono::steady_clock::now(); + // Make end time at least >= start time. + handshakeEndTime_ = handshakeStartTime_; + if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) { + handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_); + } handleConnect(); } @@ -743,31 +829,52 @@ SSL_SESSION *AsyncSSLSocket::getSSLSession() { return sslSession_; } +const SSL* AsyncSSLSocket::getSSL() const { + return ssl_; +} + void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) { sslSession_ = session; if (!takeOwnership && session != nullptr) { // Increment the reference count - CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + // This API exists in BoringSSL and OpenSSL 1.1.0 + SSL_SESSION_up_ref(session); } } -void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName, - unsigned* protoLen) const { - if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) { +void AsyncSSLSocket::getSelectedNextProtocol( + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType) const { + if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) { throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED, "NPN not supported"); } } bool AsyncSSLSocket::getSelectedNextProtocolNoThrow( - const unsigned char** protoName, - unsigned* protoLen) const { + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType) const { *protoName = nullptr; *protoLen = 0; +#if FOLLY_OPENSSL_HAS_ALPN + SSL_get0_alpn_selected(ssl_, protoName, protoLen); + if (*protoLen > 0) { + if (protoType) { + *protoType = SSLContext::NextProtocolType::ALPN; + } + return true; + } +#endif #ifdef OPENSSL_NPN_NEGOTIATED SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen); + if (protoType) { + *protoType = SSLContext::NextProtocolType::NPN; + } return true; #else + (void)protoType; return false; #endif } @@ -783,28 +890,44 @@ const char *AsyncSSLSocket::getNegotiatedCipherName() const { return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr; } +/* static */ +const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) { + if (ssl == nullptr) { + return nullptr; + } +#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB + return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); +#else + return nullptr; +#endif +} + const char *AsyncSSLSocket::getSSLServerName() const { #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB - return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name) - : nullptr; + return getSSLServerNameFromSSL(ssl_); #else throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED, - "SNI not supported"); + "SNI not supported"); #endif } const char *AsyncSSLSocket::getSSLServerNameNoThrow() const { - try { - return getSSLServerName(); - } catch (AsyncSocketException& ex) { - return nullptr; - } + return getSSLServerNameFromSSL(ssl_); } int AsyncSSLSocket::getSSLVersion() const { return (ssl_ != nullptr) ? SSL_version(ssl_) : 0; } +const char *AsyncSSLSocket::getSSLCertSigAlgName() const { + X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; + if (cert) { + int nid = X509_get_signature_nid(cert); + return OBJ_nid2ln(nid); + } + return nullptr; +} + int AsyncSSLSocket::getSSLCertSize() const { int certSize = 0; X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; @@ -816,8 +939,15 @@ int AsyncSSLSocket::getSSLCertSize() const { return certSize; } -bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { - int error = *errorOut = SSL_get_error(ssl_, ret); +const X509* AsyncSSLSocket::getSelfCert() const { + return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; +} + +bool AsyncSSLSocket::willBlock(int ret, + int* sslErrorOut, + unsigned long* errErrorOut) noexcept { + *errErrorOut = 0; + int error = *sslErrorOut = SSL_get_error(ssl_, ret); if (error == SSL_ERROR_WANT_READ) { // Register for read event if not already. updateEventRegistration(EventHandler::READ, EventHandler::WRITE); @@ -846,12 +976,18 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { // The timeout (if set) keeps running here return true; #endif + } else if (0 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING - } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) { + || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING +#endif +#ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING + || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING +#endif + ) { // Our custom openssl function has kicked off an async request to do - // modular exponentiation. When that call returns, a callback will + // rsa/ecdsa private key operation. When that call returns, a callback will // be invoked that will re-call handleAccept. - sslState_ = STATE_RSA_ASYNC_PENDING; + sslState_ = STATE_ASYNC_PENDING; // Unregister for all events while blocked here updateEventRegistration( @@ -861,11 +997,8 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { // The timeout (if set) keeps running here return true; -#endif } else { - // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail - // in the log - long lastError = ERR_get_error(); + unsigned long lastError = *errErrorOut = ERR_get_error(); VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " << "state=" << state_ << ", " << "sslState=" << sslState_ << ", " @@ -877,17 +1010,6 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept { << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", " << "func: " << ERR_func_error_string(lastError) << ", " << "reason: " << ERR_reason_error_string(lastError); - if (error != SSL_ERROR_SYSCALL) { - if (error == SSL_ERROR_SSL) { - *errorOut = lastError; - } - if ((unsigned long)lastError < 0x8000) { - errno = ENOSYS; - } else { - errno = lastError; - } - } - ERR_clear_error(); return false; } } @@ -898,22 +1020,23 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept { // the socket to become readable again. if (ssl_ != nullptr && SSL_pending(ssl_) > 0) { AsyncSocket::handleRead(); + } else { + AsyncSocket::checkForImmediateRead(); } } void AsyncSSLSocket::restartSSLAccept() { - VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_ - << ", state=" << int(state_) << ", " + VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this + << ", fd=" << fd_ << ", state=" << int(state_) << ", " << "sslState=" << sslState_ << ", events=" << eventFlags_; DestructorGuard dg(this); assert( sslState_ == STATE_CACHE_LOOKUP || - sslState_ == STATE_RSA_ASYNC_PENDING || + sslState_ == STATE_ASYNC_PENDING || sslState_ == STATE_ERROR || - sslState_ == STATE_CLOSED - ); + sslState_ == STATE_CLOSED); if (sslState_ == STATE_CLOSED) { // I sure hope whoever closed this socket didn't delete it already, // but this is not strictly speaking an error @@ -921,8 +1044,8 @@ AsyncSSLSocket::restartSSLAccept() } if (sslState_ == STATE_ERROR) { // go straight to fail if timeout expired during lookup - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "SSL accept timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, "SSL accept timed out"); failHandshake(__func__, ex); return; } @@ -950,26 +1073,34 @@ AsyncSSLSocket::handleAccept() noexcept { << ", fd=" << fd_ << "): " << e.what(); return failHandshake(__func__, ex); } - SSL_set_fd(ssl_, fd_); + + if (!setupSSLBio()) { + sslState_ = STATE_ERROR; + AsyncSocketException ex( + AsyncSocketException::INTERNAL_ERROR, "error creating write bio"); + return failHandshake(__func__, ex); + } + SSL_set_ex_data(ssl_, getSSLExDataIndex(), this); applyVerificationOptions(ssl_); } if (server_ && parseClientHello_) { - SSL_set_msg_callback_arg(ssl_, this); SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback); + SSL_set_msg_callback_arg(ssl_, this); } - errno = 0; int ret = SSL_accept(ssl_); if (ret <= 0) { - int error; - if (willBlock(ret, &error)) { + int sslError; + unsigned long errError; + int errnoCopy = errno; + if (willBlock(ret, &sslError, &errError)) { return; } else { sslState_ = STATE_ERROR; - SSLException ex(error, errno); + SSLException ex(sslError, errError, ret, errnoCopy); return failHandshake(__func__, ex); } } @@ -1019,19 +1150,29 @@ AsyncSSLSocket::handleConnect() noexcept { return AsyncSocket::handleConnect(); } - assert(state_ == StateEnum::ESTABLISHED && - sslState_ == STATE_CONNECTING); + assert( + (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) && + sslState_ == STATE_CONNECTING); assert(ssl_); - errno = 0; + auto originalState = state_; int ret = SSL_connect(ssl_); if (ret <= 0) { - int error; - if (willBlock(ret, &error)) { + int sslError; + unsigned long errError; + int errnoCopy = errno; + if (willBlock(ret, &sslError, &errError)) { + // We fell back to connecting state due to TFO + if (state_ == StateEnum::CONNECTING) { + DCHECK_EQ(StateEnum::FAST_OPEN, originalState); + if (handshakeTimeout_.isScheduled()) { + handshakeTimeout_.cancelTimeout(); + } + } return; } else { sslState_ = STATE_ERROR; - SSLException ex(error, errno); + SSLException ex(sslError, errError, ret, errnoCopy); return failHandshake(__func__, ex); } } @@ -1043,7 +1184,8 @@ AsyncSSLSocket::handleConnect() noexcept { // STATE_CONNECTING. sslState_ = STATE_ESTABLISHED; - VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; " + VLOG(3) << "AsyncSSLSocket " << this << ": " + << "fd " << fd_ << " successfully connected; " << "state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_; @@ -1071,6 +1213,81 @@ AsyncSSLSocket::handleConnect() noexcept { AsyncSocket::handleInitialReadWrite(); } +void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) { + connectionTimeout_.cancelTimeout(); + AsyncSocket::invokeConnectErr(ex); + if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + if (handshakeTimeout_.isScheduled()) { + handshakeTimeout_.cancelTimeout(); + } + // If we fell back to connecting state during TFO and the connection + // failed, it would be an SSL failure as well. + invokeHandshakeErr(ex); + } +} + +void AsyncSSLSocket::invokeConnectSuccess() { + connectionTimeout_.cancelTimeout(); + if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + assert(tfoAttempted_); + // If we failed TFO, we'd fall back to trying to connect the socket, + // to setup things like timeouts. + startSSLConnect(); + } + // still invoke the base class since it re-sets the connect time. + AsyncSocket::invokeConnectSuccess(); +} + +void AsyncSSLSocket::scheduleConnectTimeout() { + if (sslState_ == SSLStateEnum::STATE_CONNECTING) { + // We fell back from TFO, and need to set the timeouts. + // We will not have a connect callback in this case, thus if the timer + // expires we would have no-one to notify. + // Thus we should reset even the connect timers to point to the handshake + // timeouts. + assert(connectCallback_ == nullptr); + // We use a different connect timeout here than the handshake timeout, so + // that we can disambiguate the 2 timers. + if (connectTimeout_.count() > 0) { + if (!connectionTimeout_.scheduleTimeout(connectTimeout_)) { + throw AsyncSocketException( + AsyncSocketException::INTERNAL_ERROR, + withAddr("failed to schedule AsyncSSLSocket connect timeout")); + } + } + return; + } + AsyncSocket::scheduleConnectTimeout(); +} + +void AsyncSSLSocket::setReadCB(ReadCallback *callback) { +#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP + // turn on the buffer movable in openssl + if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ && + callback != nullptr && callback->isBufferMovable()) { + SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP); + isBufferMovable_ = true; + } +#endif + + AsyncSocket::setReadCB(callback); +} + +void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) { + bufferMovableEnabled_ = enabled; +} + +void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) { + CHECK(readCallback_); + if (isBufferMovable_) { + *buf = nullptr; + *buflen = 0; + } else { + // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP + readCallback_->getReadBuffer(buf, buflen); + } +} + void AsyncSSLSocket::handleRead() noexcept { VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_ @@ -1096,46 +1313,57 @@ AsyncSSLSocket::handleRead() noexcept { AsyncSocket::handleRead(); } -ssize_t -AsyncSSLSocket::performRead(void* buf, size_t buflen) { +AsyncSocket::ReadResult +AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { + VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf + << ", buflen=" << *buflen; + if (sslState_ == STATE_UNENCRYPTED) { - return AsyncSocket::performRead(buf, buflen); + return AsyncSocket::performRead(buf, buflen, offset); + } + + int bytes = 0; + if (!isBufferMovable_) { + bytes = SSL_read(ssl_, *buf, int(*buflen)); } +#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP + else { + bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen); + } +#endif - errno = 0; - ssize_t bytes = SSL_read(ssl_, buf, buflen); if (server_ && renegotiateAttempted_) { LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): client intitiated SSL renegotiation not permitted"; - // We pack our own SSLerr here with a dummy function - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ, - SSL_CLIENT_RENEGOTIATION_ATTEMPT); - ERR_clear_error(); - return READ_ERROR; + return ReadResult( + READ_ERROR, + folly::make_unique(SSLError::CLIENT_RENEGOTIATION)); } if (bytes <= 0) { int error = SSL_get_error(ssl_, bytes); if (error == SSL_ERROR_WANT_READ) { // The caller will register for read event if not already. - return READ_BLOCKING; + if (errno == EWOULDBLOCK || errno == EAGAIN) { + return ReadResult(READ_BLOCKING); + } else { + return ReadResult(READ_ERROR); + } } else if (error == SSL_ERROR_WANT_WRITE) { // TODO: Even though we are attempting to read data, SSL_read() may // need to write data if renegotiation is being performed. We currently // don't support this and just fail the read. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ - << "): unsupported SSL renegotiation during read", - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ, - SSL_INVALID_RENEGOTIATION); - ERR_clear_error(); - return READ_ERROR; + << "): unsupported SSL renegotiation during read"; + return ReadResult( + READ_ERROR, + folly::make_unique(SSLError::INVALID_RENEGOTIATION)); } else { - // TODO: Fix this code so that it can return a proper error message - // to the callback, rather than relying on AsyncSocket code which - // can't handle SSL errors. - long lastError = ERR_get_error(); - + if (zero_return(error, bytes)) { + return ReadResult(bytes); + } + auto errError = ERR_get_error(); VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " << "state=" << state_ << ", " << "sslState=" << sslState_ << ", " @@ -1143,24 +1371,15 @@ AsyncSSLSocket::performRead(void* buf, size_t buflen) { << "bytes: " << bytes << ", " << "error: " << error << ", " << "errno: " << errno << ", " - << "func: " << ERR_func_error_string(lastError) << ", " - << "reason: " << ERR_reason_error_string(lastError); - ERR_clear_error(); - if (zero_return(error, bytes)) { - return bytes; - } - if (error != SSL_ERROR_SYSCALL) { - if ((unsigned long)lastError < 0x8000) { - errno = ENOSYS; - } else { - errno = lastError; - } - } - return READ_ERROR; + << "func: " << ERR_func_error_string(errError) << ", " + << "reason: " << ERR_reason_error_string(errError); + return ReadResult( + READ_ERROR, + folly::make_unique(error, errError, bytes, errno)); } } else { appBytesReceived_ += bytes; - return bytes; + return ReadResult(bytes); } } @@ -1188,49 +1407,40 @@ void AsyncSSLSocket::handleWrite() noexcept { AsyncSocket::handleWrite(); } -int AsyncSSLSocket::interpretSSLError(int rc, int error) { +AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) { if (error == SSL_ERROR_WANT_READ) { - // TODO: Even though we are attempting to write data, SSL_write() may + // Even though we are attempting to write data, SSL_write() may // need to read data if renegotiation is being performed. We currently // don't support this and just fail the write. LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ - << "): " << "unsupported SSL renegotiation during write", - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE, - SSL_INVALID_RENEGOTIATION); - ERR_clear_error(); - return -1; + << "): " + << "unsupported SSL renegotiation during write"; + return WriteResult( + WRITE_ERROR, + folly::make_unique(SSLError::INVALID_RENEGOTIATION)); } else { - // TODO: Fix this code so that it can return a proper error message - // to the callback, rather than relying on AsyncSocket code which - // can't handle SSL errors. - long lastError = ERR_get_error(); + if (zero_return(error, rc)) { + return WriteResult(0); + } + auto errError = ERR_get_error(); VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_) << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): " << "SSL error: " << error << ", errno: " << errno - << ", func: " << ERR_func_error_string(lastError) - << ", reason: " << ERR_reason_error_string(lastError); - if (error != SSL_ERROR_SYSCALL) { - if ((unsigned long)lastError < 0x8000) { - errno = ENOSYS; - } else { - errno = lastError; - } - } - ERR_clear_error(); - if (!zero_return(error, rc)) { - return -1; - } else { - return 0; - } + << ", func: " << ERR_func_error_string(errError) + << ", reason: " << ERR_reason_error_string(errError); + return WriteResult( + WRITE_ERROR, + folly::make_unique(error, errError, rc, errno)); } } -ssize_t AsyncSSLSocket::performWrite(const iovec* vec, - uint32_t count, - WriteFlags flags, - uint32_t* countWritten, - uint32_t* partialWritten) { +AsyncSocket::WriteResult AsyncSSLSocket::performWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) { if (sslState_ == STATE_UNENCRYPTED) { return AsyncSocket::performWrite( vec, count, flags, countWritten, partialWritten); @@ -1241,20 +1451,9 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, << ", events=" << eventFlags_ << "): " << "TODO: AsyncSSLSocket currently does not support calling " << "write() before the handshake has fully completed"; - errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE, - SSL_EARLY_WRITE); - return -1; - } - - bool cork = isSet(flags, WriteFlags::CORK); - CorkGuard guard(fd_, count > 1, cork, &corked_); - -#if 0 -//#ifdef SSL_MODE_WRITE_IOVEC - if (ssl_->mode & SSL_MODE_WRITE_IOVEC) { - return performWriteIovec(vec, count, flags, countWritten, partialWritten); + return WriteResult( + WRITE_ERROR, folly::make_unique(SSLError::EARLY_WRITE)); } -#endif // Declare a buffer used to hold small write requests. It could point to a // memory block either on stack or on heap. If it is on heap, we release it @@ -1284,8 +1483,8 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, buf = ((const char*)v->iov_base) + offset; ssize_t bytes; - errno = 0; uint32_t buffersStolen = 0; + auto sslWriteBuf = buf; if ((len < minWriteSize_) && ((i + 1) < count)) { // Combine this buffer with part or all of the next buffers in // order to avoid really small-grained calls to SSL_write(). @@ -1306,6 +1505,7 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, } } assert(combinedBuf != nullptr); + sslWriteBuf = combinedBuf; memcpy(combinedBuf, buf, len); do { @@ -1313,8 +1513,13 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, uint32_t nextIndex = i + buffersStolen + 1; bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len, minWriteSize_ - len); - memcpy(combinedBuf + len, vec[nextIndex].iov_base, - bytesStolenFromNextBuffer); + if (bytesStolenFromNextBuffer > 0) { + assert(vec[nextIndex].iov_base != nullptr); + ::memcpy( + combinedBuf + len, + vec[nextIndex].iov_base, + bytesStolenFromNextBuffer); + } len += bytesStolenFromNextBuffer; if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) { // couldn't steal the whole buffer @@ -1324,25 +1529,34 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, buffersStolen++; } } while ((i + buffersStolen + 1) < count && (len < minWriteSize_)); - bytes = eorAwareSSLWrite( - ssl_, combinedBuf, len, - (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count)); + } - } else { - bytes = eorAwareSSLWrite(ssl_, buf, len, - (isSet(flags, WriteFlags::EOR) && i + 1 == count)); + // Advance any empty buffers immediately after. + if (bytesStolenFromNextBuffer == 0) { + while ((i + buffersStolen + 1) < count && + vec[i + buffersStolen + 1].iov_len == 0) { + buffersStolen++; + } } + corkCurrentWrite_ = + isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count); + bytes = eorAwareSSLWrite( + ssl_, + sslWriteBuf, + int(len), + (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count)); + if (bytes <= 0) { - int error = SSL_get_error(ssl_, bytes); + int error = SSL_get_error(ssl_, int(bytes)); if (error == SSL_ERROR_WANT_WRITE) { // The caller will register for write event if not already. - *partialWritten = offset; - return totalWritten; + *partialWritten = uint32_t(offset); + return WriteResult(totalWritten); } - int rc = interpretSSLError(bytes, error); - if (rc < 0) { - return rc; + auto writeResult = interpretSSLError(int(bytes), error); + if (writeResult.writeReturn < 0) { + return writeResult; } // else fall through to below to correctly record totalWritten } @@ -1363,55 +1577,17 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec, (*countWritten)++; v = &(vec[++i]); } - *partialWritten = bytes; - return totalWritten; + *partialWritten = uint32_t(bytes); + return WriteResult(totalWritten); } } - return totalWritten; + return WriteResult(totalWritten); } -#if 0 -//#ifdef SSL_MODE_WRITE_IOVEC -ssize_t AsyncSSLSocket::performWriteIovec(const iovec* vec, - uint32_t count, - WriteFlags flags, - uint32_t* countWritten, - uint32_t* partialWritten) { - size_t tot = 0; - for (uint32_t j = 0; j < count; j++) { - tot += vec[j].iov_len; - } - - ssize_t totalWritten = SSL_write_iovec(ssl_, vec, count); - - *countWritten = 0; - *partialWritten = 0; - if (totalWritten <= 0) { - return interpretSSLError(totalWritten, SSL_get_error(ssl_, totalWritten)); - } else { - ssize_t bytes = totalWritten, i = 0; - while (i < count && bytes >= (ssize_t)vec[i].iov_len) { - // we managed to write all of this buf - bytes -= vec[i].iov_len; - (*countWritten)++; - i++; - } - *partialWritten = bytes; - - VLOG(4) << "SSL_write_iovec() writes " << tot - << ", returns " << totalWritten << " bytes" - << ", max_send_fragment=" << ssl_->max_send_fragment - << ", count=" << count << ", countWritten=" << *countWritten; - - return totalWritten; - } -} -#endif - int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, bool eor) { - if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) { + if (eor && isEorTrackingEnabled()) { if (appEorByteNo_) { // cannot track for more than one app byte EOR CHECK(appEorByteNo_ == appBytesWritten_ + n); @@ -1441,47 +1617,103 @@ int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, return n; } -void -AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) { +void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl); if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) { sslSocket->renegotiateAttempted_ = true; } + if (where & SSL_CB_READ_ALERT) { + const char* type = SSL_alert_type_string(ret); + if (type) { + const char* desc = SSL_alert_desc_string(ret); + sslSocket->alertsReceived_.emplace_back( + *type, StringPiece(desc, std::strlen(desc))); + } + } } -int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) { - int ret; +int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { struct msghdr msg; struct iovec iov; - int flags = 0; - AsyncSSLSocket *tsslSock; + AsyncSSLSocket* tsslSock; - iov.iov_base = const_cast(in); - iov.iov_len = inl; + iov.iov_base = const_cast(in); + iov.iov_len = size_t(inl); memset(&msg, 0, sizeof(msg)); msg.msg_iov = &iov; msg.msg_iovlen = 1; - tsslSock = - reinterpret_cast(BIO_get_app_data(b)); - if (tsslSock && - tsslSock->minEorRawByteNo_ && + auto appData = OpenSSLUtils::getBioAppData(b); + CHECK(appData); + + tsslSock = reinterpret_cast(appData); + CHECK(tsslSock); + + WriteFlags flags = WriteFlags::NONE; + if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ && tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { - flags = MSG_EOR; + flags |= WriteFlags::EOR; } - errno = 0; - ret = sendmsg(b->num, &msg, flags); + if (tsslSock->corkCurrentWrite_) { + flags |= WriteFlags::CORK; + } + + int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags); + msg.msg_controllen = + tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags); + CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize, + msg.msg_controllen); + if (msg.msg_controllen != 0) { + msg.msg_control = reinterpret_cast(alloca(msg.msg_controllen)); + tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control); + } + + auto result = tsslSock->sendSocketMessage( + OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags); BIO_clear_retry_flags(b); - if (ret <= 0) { - if (BIO_sock_should_retry(ret)) + if (!result.exception && result.writeReturn <= 0) { + if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) { BIO_set_retry_write(b); + } } - return(ret); + return int(result.writeReturn); } -int AsyncSSLSocket::sslVerifyCallback(int preverifyOk, - X509_STORE_CTX* x509Ctx) { +int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) { + if (!out) { + return 0; + } + BIO_clear_retry_flags(b); + + auto appData = OpenSSLUtils::getBioAppData(b); + CHECK(appData); + auto sslSock = reinterpret_cast(appData); + + if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) { + VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock + << ", reading pre-received data"; + + Cursor cursor(sslSock->preReceivedData_.get()); + auto len = cursor.pullAtMost(out, outl); + + IOBufQueue queue; + queue.append(std::move(sslSock->preReceivedData_)); + queue.trimStart(len); + sslSock->preReceivedData_ = queue.move(); + return len; + } else { + auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0); + if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) { + BIO_set_retry_read(b); + } + return result; + } +} + +int AsyncSSLSocket::sslVerifyCallback( + int preverifyOk, + X509_STORE_CTX* x509Ctx) { SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data( x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx()); AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl); @@ -1495,7 +1727,7 @@ int AsyncSSLSocket::sslVerifyCallback(int preverifyOk, void AsyncSSLSocket::enableClientHelloParsing() { parseClientHello_ = true; - clientHelloInfo_.reset(new ClientHelloInfo()); + clientHelloInfo_.reset(new ssl::ClientHelloInfo()); } void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl) { @@ -1504,17 +1736,19 @@ void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl) { clientHelloInfo_->clientHelloBuf_.clear(); } -void -AsyncSSLSocket::clientHelloParsingCallback(int written, int version, - int contentType, const void *buf, size_t len, SSL *ssl, void *arg) -{ +void AsyncSSLSocket::clientHelloParsingCallback(int written, + int /* version */, + int contentType, + const void* buf, + size_t len, + SSL* ssl, + void* arg) { AsyncSSLSocket *sock = static_cast(arg); if (written != 0) { sock->resetClientHelloParsing(ssl); return; } if (contentType != SSL3_RT_HANDSHAKE) { - sock->resetClientHelloParsing(ssl); return; } if (len == 0) { @@ -1570,16 +1804,41 @@ AsyncSSLSocket::clientHelloParsingCallback(int written, int version, if (cursor.totalLength() > 0) { uint16_t extensionsLength = cursor.readBE(); while (extensionsLength) { + ssl::TLSExtension extensionType = + static_cast(cursor.readBE()); sock->clientHelloInfo_-> - clientHelloExtensions_.push_back(cursor.readBE()); + clientHelloExtensions_.push_back(extensionType); extensionsLength -= 2; uint16_t extensionDataLength = cursor.readBE(); extensionsLength -= 2; - cursor.skip(extensionDataLength); extensionsLength -= extensionDataLength; + + if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) { + cursor.skip(2); + extensionDataLength -= 2; + while (extensionDataLength) { + ssl::HashAlgorithm hashAlg = + static_cast(cursor.readBE()); + ssl::SignatureAlgorithm sigAlg = + static_cast(cursor.readBE()); + extensionDataLength -= 2; + sock->clientHelloInfo_-> + clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg); + } + } else if (extensionType == ssl::TLSExtension::SUPPORTED_VERSIONS) { + cursor.skip(1); + extensionDataLength -= 1; + while (extensionDataLength) { + sock->clientHelloInfo_->clientHelloSupportedVersions_.push_back( + cursor.readBE()); + extensionDataLength -= 2; + } + } else { + cursor.skip(extensionDataLength); + } } } - } catch (std::out_of_range& e) { + } catch (std::out_of_range&) { // we'll use what we found and cleanup below. VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): " << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock; @@ -1588,4 +1847,129 @@ AsyncSSLSocket::clientHelloParsingCallback(int written, int version, sock->resetClientHelloParsing(ssl); } +void AsyncSSLSocket::getSSLClientCiphers( + std::string& clientCiphers, + bool convertToString) const { + std::string ciphers; + + if (parseClientHello_ == false + || clientHelloInfo_->clientHelloCipherSuites_.empty()) { + clientCiphers = ""; + return; + } + + bool first = true; + for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_) + { + if (first) { + first = false; + } else { + ciphers += ":"; + } + + bool nameFound = convertToString; + + if (convertToString) { + const auto& name = OpenSSLUtils::getCipherName(originalCipherCode); + if (name.empty()) { + nameFound = false; + } else { + ciphers += name; + } + } + + if (!nameFound) { + folly::hexlify( + std::array{{ + static_cast((originalCipherCode >> 8) & 0xffL), + static_cast(originalCipherCode & 0x00ffL) }}, + ciphers, + /* append to ciphers = */ true); + } + } + + clientCiphers = std::move(ciphers); +} + +std::string AsyncSSLSocket::getSSLClientComprMethods() const { + if (!parseClientHello_) { + return ""; + } + return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_); +} + +std::string AsyncSSLSocket::getSSLClientExts() const { + if (!parseClientHello_) { + return ""; + } + return folly::join(":", clientHelloInfo_->clientHelloExtensions_); +} + +std::string AsyncSSLSocket::getSSLClientSigAlgs() const { + if (!parseClientHello_) { + return ""; + } + + std::string sigAlgs; + sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4); + for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) { + if (i) { + sigAlgs.push_back(':'); + } + sigAlgs.append(folly::to( + clientHelloInfo_->clientHelloSigAlgs_[i].first)); + sigAlgs.push_back(','); + sigAlgs.append(folly::to( + clientHelloInfo_->clientHelloSigAlgs_[i].second)); + } + + return sigAlgs; +} + +std::string AsyncSSLSocket::getSSLClientSupportedVersions() const { + if (!parseClientHello_) { + return ""; + } + return folly::join(":", clientHelloInfo_->clientHelloSupportedVersions_); +} + +std::string AsyncSSLSocket::getSSLAlertsReceived() const { + std::string ret; + + for (const auto& alert : alertsReceived_) { + if (!ret.empty()) { + ret.append(","); + } + ret.append(folly::to(alert.first, ": ", alert.second)); + } + + return ret; +} + +void AsyncSSLSocket::setSSLCertVerificationAlert(std::string alert) { + sslVerificationAlert_ = std::move(alert); +} + +std::string AsyncSSLSocket::getSSLCertVerificationAlert() const { + return sslVerificationAlert_; +} + +void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const { + char ciphersBuffer[1024]; + ciphersBuffer[0] = '\0'; + SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1); + sharedCiphers = ciphersBuffer; +} + +void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const { + serverCiphers = SSL_get_cipher_list(ssl_, 0); + int i = 1; + const char *cipher; + while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) { + serverCiphers.append(":"); + serverCiphers.append(cipher); + i++; + } +} + } // namespace