X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.cpp;h=8c7ab10c3076d9104768936cf93aa5e2f36fb49f;hb=c5b9338ec192ed46907905d173b65d158a038842;hp=4b9f31738daa91fefde2da90e41cab83dc52a0b9;hpb=1e53154792a1d188cc29b7c78433913f34714912;p=folly.git diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 4b9f3173..8c7ab10c 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -22,9 +22,6 @@ #include #include #include -#include -#include -#include #include #include @@ -34,7 +31,6 @@ #include #include #include -#include using folly::SocketAddress; using folly::SSLContext; @@ -59,6 +55,7 @@ using folly::SSLContext; using namespace folly::ssl; using folly::ssl::OpenSSLUtils; + // We have one single dummy SSL context so that we can implement attach // and detach methods in a thread safe fashion without modifying opnessl. static SSLContext *dummyCtx = nullptr; @@ -106,8 +103,9 @@ class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback, timeoutLeft = timeout_ - (curTime - startTime_); if (timeoutLeft <= 0) { - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "SSL connect timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, + folly::sformat("SSL connect timed out after {}ms", timeout_)); fail(ex); delete this; return; @@ -176,14 +174,20 @@ void setup_SSL_CTX(SSL_CTX *ctx) { } -BIO_METHOD sslBioMethod; +// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky +// thing is because we will be setting this BIO_METHOD* inside BIOs owned by +// various SSL objects which may get callbacks even during teardown. We may +// eventually try to fix this +static BIO_METHOD* getSSLBioMethod() { + static auto const instance = OpenSSLUtils::newSocketBioMethod().release(); + return instance; +} void* initsslBioMethod(void) { - memcpy(&sslBioMethod, BIO_s_socket(), sizeof(sslBioMethod)); + auto sslBioMethod = getSSLBioMethod(); // override the bwrite method for MSG_EOR support - OpenSSLUtils::setCustomBioWriteMethod( - &sslBioMethod, AsyncSSLSocket::bioWrite); - OpenSSLUtils::setCustomBioReadMethod(&sslBioMethod, AsyncSSLSocket::bioRead); + OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite); + OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead); // Note that the sslBioMethod.type and sslBioMethod.name are not // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and @@ -217,14 +221,38 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, /** * Create a server/client AsyncSSLSocket */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, bool server, - bool deferSecurityNegotiation) : - AsyncSocket(evb, fd), - server_(server), - ctx_(ctx), - handshakeTimeout_(this, evb), - connectionTimeout_(this, evb) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(evb, fd), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, evb), + connectionTimeout_(this, evb) { + noTransparentTls_ = true; + init(); + if (server) { + SSL_CTX_set_info_callback( + ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback); + } + if (deferSecurityNegotiation) { + sslState_ = STATE_UNENCRYPTED; + } +} + +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server, + bool deferSecurityNegotiation) + : AsyncSocket(std::move(oldAsyncSocket)), + server_(server), + ctx_(ctx), + handshakeTimeout_(this, AsyncSocket::getEventBase()), + connectionTimeout_(this, AsyncSocket::getEventBase()) { noTransparentTls_ = true; init(); if (server) { @@ -253,11 +281,13 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr &ctx, * Create a client AsyncSSLSocket from an already connected fd * and allow tlsext_hostname to be sent in Client Hello. */ -AsyncSSLSocket::AsyncSSLSocket(const shared_ptr& ctx, - EventBase* evb, int fd, - const std::string& serverName, - bool deferSecurityNegotiation) : - AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { +AsyncSSLSocket::AsyncSSLSocket( + const shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation) + : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { tlsextHostname_ = serverName; } #endif // FOLLY_OPENSSL_HAS_SNI @@ -358,13 +388,9 @@ std::string AsyncSSLSocket::getApplicationProtocol() noexcept { return ""; } -bool AsyncSSLSocket::isEorTrackingEnabled() const { - return trackEor_; -} - void AsyncSSLSocket::setEorTracking(bool track) { - if (trackEor_ != track) { - trackEor_ = track; + if (isEorTrackingEnabled() != track) { + AsyncSocket::setEorTracking(track); appEorByteNo_ = 0; minEorRawByteNo_ = 0; } @@ -454,12 +480,9 @@ void AsyncSSLSocket::sslAccept( /* register for a read operation (waiting for CLIENT HELLO) */ updateEventRegistration(EventHandler::READ, EventHandler::WRITE); - if (preReceivedData_) { - handleRead(); - } + checkForImmediateRead(); } -#if OPENSSL_VERSION_NUMBER >= 0x009080bfL void AsyncSSLSocket::attachSSLContext( const std::shared_ptr& ctx) { @@ -482,10 +505,16 @@ void AsyncSSLSocket::attachSSLContext( // We need to update the initial_ctx if necessary auto sslCtx = ctx->getSSLCtx(); SSL_CTX_up_ref(sslCtx); -#ifndef OPENSSL_NO_TLSEXT - // note that detachSSLContext has already freed ssl_->initial_ctx - ssl_->initial_ctx = sslCtx; -#endif + + // The 'initial_ctx' inside an SSL* points to the context that it was created + // with, which is also where session callbacks and servername callbacks + // happen. + // When we switch to a different SSL_CTX, we want to update the initial_ctx as + // well so that any callbacks don't go to a different object + // NOTE: this will only work if we have access to ssl_ internals, so it may + // not work on + // OpenSSL version >= 1.1.0 + OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx); // Detach sets the socket's context to the dummy context. Thus we must acquire // this lock. SpinLockGuard guard(dummyCtxLock); @@ -500,14 +529,20 @@ void AsyncSSLSocket::detachSSLContext() { if (!ssl_) { return; } -// Detach the initial_ctx as well. Internally w/ OPENSSL_NO_TLSEXT -// it is used for session info. It will be reattached in attachSSLContext -#ifndef OPENSSL_NO_TLSEXT - if (ssl_->initial_ctx) { - SSL_CTX_free(ssl_->initial_ctx); - ssl_->initial_ctx = nullptr; + // The 'initial_ctx' inside an SSL* points to the context that it was created + // with, which is also where session callbacks and servername callbacks + // happen. + // Detach the initial_ctx as well. It will be reattached in attachSSLContext + // it is used for session info. + // NOTE: this will only work if we have access to ssl_ internals, so it may + // not work on + // OpenSSL version >= 1.1.0 + SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_); + if (initialCtx) { + SSL_CTX_free(initialCtx); + OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr); } -#endif + SpinLockGuard guard(dummyCtxLock); if (nullptr == dummyCtx) { // We need to lazily initialize the dummy context so we don't @@ -520,7 +555,6 @@ void AsyncSSLSocket::detachSSLContext() { // would not be thread safe. SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx()); } -#endif #if FOLLY_OPENSSL_HAS_SNI void AsyncSSLSocket::switchServerSSLContext( @@ -553,10 +587,8 @@ bool AsyncSSLSocket::isServerNameMatch() const { return false; } - if(!ss->tlsext_hostname) { - return false; - } - return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true); + auto tlsextHostname = SSL_SESSION_get0_hostname(ss); + return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname)); } void AsyncSSLSocket::setServerName(std::string serverName) noexcept { @@ -565,10 +597,10 @@ void AsyncSSLSocket::setServerName(std::string serverName) noexcept { #endif // FOLLY_OPENSSL_HAS_SNI -void AsyncSSLSocket::timeoutExpired() noexcept { +void AsyncSSLSocket::timeoutExpired( + std::chrono::milliseconds timeout) noexcept { if (state_ == StateEnum::ESTABLISHED && - (sslState_ == STATE_CACHE_LOOKUP || - sslState_ == STATE_ASYNC_PENDING)) { + (sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING)) { sslState_ = STATE_ERROR; // We are expecting a callback in restartSSLAccept. The cache lookup // and rsa-call necessarily have pointers to this ssl socket, so delay @@ -583,9 +615,12 @@ void AsyncSSLSocket::timeoutExpired() noexcept { assert(state_ == StateEnum::ESTABLISHED && (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING)); DestructorGuard dg(this); - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - (sslState_ == STATE_CONNECTING) ? - "SSL connect timed out" : "SSL accept timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, + folly::sformat( + "SSL {} timed out after {}ms", + (sslState_ == STATE_CONNECTING) ? "connect" : "accept", + timeout.count())); failHandshake(__func__, ex); } } @@ -645,19 +680,41 @@ void AsyncSSLSocket::cacheLocalPeerAddr() { } } -void AsyncSSLSocket::connect(ConnectCallback* callback, - const folly::SocketAddress& address, - int timeout, - const OptionMap &options, - const folly::SocketAddress& bindAddr) - noexcept { +void AsyncSSLSocket::connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + int timeout, + const OptionMap& options, + const folly::SocketAddress& bindAddr) noexcept { + auto timeoutChrono = std::chrono::milliseconds(timeout); + connect(callback, address, timeoutChrono, timeoutChrono, options, bindAddr); +} + +void AsyncSSLSocket::connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + std::chrono::milliseconds connectTimeout, + std::chrono::milliseconds totalConnectTimeout, + const OptionMap& options, + const folly::SocketAddress& bindAddr) noexcept { assert(!server_); assert(state_ == StateEnum::UNINIT); assert(sslState_ == STATE_UNINIT); noTransparentTls_ = true; - AsyncSSLSocketConnector *connector = - new AsyncSSLSocketConnector(this, callback, timeout); - AsyncSocket::connect(connector, address, timeout, options, bindAddr); + totalConnectTimeout_ = totalConnectTimeout; + AsyncSSLSocketConnector* connector = + new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count()); + AsyncSocket::connect( + connector, address, connectTimeout.count(), options, bindAddr); +} + +bool AsyncSSLSocket::needsPeerVerification() const { + if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) { + return ctx_->needsPeerVerification(); + } + return ( + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY || + verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT); } void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { @@ -677,7 +734,7 @@ void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) { } bool AsyncSSLSocket::setupSSLBio() { - auto sslBio = BIO_new(&sslBioMethod); + auto sslBio = BIO_new(getSSLBioMethod()); if (!sslBio) { return false; @@ -865,7 +922,7 @@ int AsyncSSLSocket::getSSLVersion() const { const char *AsyncSSLSocket::getSSLCertSigAlgName() const { X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; if (cert) { - int nid = OBJ_obj2nid(cert->sig_alg->algorithm); + int nid = X509_get_signature_nid(cert); return OBJ_nid2ln(nid); } return nullptr; @@ -963,6 +1020,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept { // the socket to become readable again. if (ssl_ != nullptr && SSL_pending(ssl_) > 0) { AsyncSocket::handleRead(); + } else { + AsyncSocket::checkForImmediateRead(); } } @@ -985,8 +1044,8 @@ AsyncSSLSocket::restartSSLAccept() } if (sslState_ == STATE_ERROR) { // go straight to fail if timeout expired during lookup - AsyncSocketException ex(AsyncSocketException::TIMED_OUT, - "SSL accept timed out"); + AsyncSocketException ex( + AsyncSocketException::TIMED_OUT, "SSL accept timed out"); failHandshake(__func__, ex); return; } @@ -1032,7 +1091,6 @@ AsyncSSLSocket::handleAccept() noexcept { SSL_set_msg_callback_arg(ssl_, this); } - clearOpenSSLErrors(); int ret = SSL_accept(ssl_); if (ret <= 0) { int sslError; @@ -1082,18 +1140,6 @@ AsyncSSLSocket::handleAccept() noexcept { AsyncSocket::handleInitialReadWrite(); } -void AsyncSSLSocket::clearOpenSSLErrors() { - // Normally clearing out the error before calling into an openssl method - // is a bad idea. However there might be other code that we don't control - // calling into openssl in the same thread, which doesn't use openssl - // correctly. We want to safe-guard ourselves from that code. - // However touching the ERR stack each and every time has a cost of taking - // a lock, so we only do this when we've opted in. - if (clearOpenSSLErrors_) { - ERR_clear_error(); - } -} - void AsyncSSLSocket::handleConnect() noexcept { VLOG(3) << "AsyncSSLSocket::handleConnect() this=" << this @@ -1109,7 +1155,6 @@ AsyncSSLSocket::handleConnect() noexcept { sslState_ == STATE_CONNECTING); assert(ssl_); - clearOpenSSLErrors(); auto originalState = state_; int ret = SSL_connect(ssl_); if (ret <= 0) { @@ -1277,7 +1322,6 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { return AsyncSocket::performRead(buf, buflen, offset); } - clearOpenSSLErrors(); int bytes = 0; if (!isBufferMovable_) { bytes = SSL_read(ssl_, *buf, int(*buflen)); @@ -1319,7 +1363,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) { if (zero_return(error, bytes)) { return ReadResult(bytes); } - long errError = ERR_get_error(); + auto errError = ERR_get_error(); VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", " << "state=" << state_ << ", " << "sslState=" << sslState_ << ", " @@ -1469,8 +1513,13 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( uint32_t nextIndex = i + buffersStolen + 1; bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len, minWriteSize_ - len); - memcpy(combinedBuf + len, vec[nextIndex].iov_base, - bytesStolenFromNextBuffer); + if (bytesStolenFromNextBuffer > 0) { + assert(vec[nextIndex].iov_base != nullptr); + ::memcpy( + combinedBuf + len, + vec[nextIndex].iov_base, + bytesStolenFromNextBuffer); + } len += bytesStolenFromNextBuffer; if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) { // couldn't steal the whole buffer @@ -1538,7 +1587,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n, bool eor) { - if (eor && trackEor_) { + if (eor && isEorTrackingEnabled()) { if (appEorByteNo_) { // cannot track for more than one app byte EOR CHECK(appEorByteNo_ == appBytesWritten_ + n); @@ -1586,11 +1635,10 @@ void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) { int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { struct msghdr msg; struct iovec iov; - int flags = 0; AsyncSSLSocket* tsslSock; iov.iov_base = const_cast(in); - iov.iov_len = inl; + iov.iov_len = size_t(inl); memset(&msg, 0, sizeof(msg)); msg.msg_iov = &iov; msg.msg_iovlen = 1; @@ -1601,23 +1649,28 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { tsslSock = reinterpret_cast(appData); CHECK(tsslSock); - if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ && + WriteFlags flags = WriteFlags::NONE; + if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ && tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { - flags = MSG_EOR; + flags |= WriteFlags::EOR; } -#ifdef MSG_NOSIGNAL - flags |= MSG_NOSIGNAL; -#endif - -#ifdef MSG_MORE if (tsslSock->corkCurrentWrite_) { - flags |= MSG_MORE; + flags |= WriteFlags::CORK; + } + + int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags); + msg.msg_controllen = + tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags); + CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize, + msg.msg_controllen); + if (msg.msg_controllen != 0) { + msg.msg_control = reinterpret_cast(alloca(msg.msg_controllen)); + tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control); } -#endif auto result = tsslSock->sendSocketMessage( - OpenSSLUtils::getBioFd(b, nullptr), &msg, flags); + OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags); BIO_clear_retry_flags(b); if (!result.exception && result.writeReturn <= 0) { if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) { @@ -1672,12 +1725,6 @@ int AsyncSSLSocket::sslVerifyCallback( preverifyOk; } -void AsyncSSLSocket::setPreReceivedData(std::unique_ptr data) { - CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED); - CHECK(!preReceivedData_); - preReceivedData_ = std::move(data); -} - void AsyncSSLSocket::enableClientHelloParsing() { parseClientHello_ = true; clientHelloInfo_.reset(new ssl::ClientHelloInfo()); @@ -1899,6 +1946,14 @@ std::string AsyncSSLSocket::getSSLAlertsReceived() const { return ret; } +void AsyncSSLSocket::setSSLCertVerificationAlert(std::string alert) { + sslVerificationAlert_ = std::move(alert); +} + +std::string AsyncSSLSocket::getSSLCertVerificationAlert() const { + return sslVerificationAlert_; +} + void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const { char ciphersBuffer[1024]; ciphersBuffer[0] = '\0';