X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FSSLContext.cpp;h=19c0aa53885fe10dce115b2b72c660e1f48a7b40;hb=08dba5714790020d2fa677e34e624eb4f34a20ca;hp=1b0018fdf5a5851a2ac4502585358da62bf151b2;hpb=fbdf69e683faf90b91dedf86e7198a1dc3fb8dc6;p=folly.git diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 1b0018fd..19c0aa53 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2014 Facebook, Inc. + * Copyright 2015 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,9 @@ #include #include +#include #include +#include // --------------------------------------------------------------------- // SSLContext implementation @@ -35,22 +37,27 @@ struct CRYPTO_dynlock_value { namespace folly { bool SSLContext::initialized_ = false; -std::mutex SSLContext::mutex_; + +namespace { + +std::mutex& initMutex() { + static std::mutex m; + return m; +} + +inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); } +using BIO_deleter = folly::static_function_deleter; + +} // anonymous namespace + #ifdef OPENSSL_NPN_NEGOTIATED int SSLContext::sNextProtocolsExDataIndex_ = -1; #endif -#ifndef SSLCONTEXT_NO_REFCOUNT -uint64_t SSLContext::count_ = 0; -#endif - // SSLContext implementation SSLContext::SSLContext(SSLVersion version) { { - std::lock_guard g(mutex_); -#ifndef SSLCONTEXT_NO_REFCOUNT - count_++; -#endif + std::lock_guard g(initMutex()); initializeOpenSSLLocked(); } @@ -82,6 +89,10 @@ SSLContext::SSLContext(SSLVersion version) { SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_arg(ctx_, this); #endif + +#ifdef OPENSSL_NPN_NEGOTIATED + Random::seed(nextProtocolPicker_); +#endif } SSLContext::~SSLContext() { @@ -93,15 +104,6 @@ SSLContext::~SSLContext() { #ifdef OPENSSL_NPN_NEGOTIATED deleteNextProtocolsStrings(); #endif - -#ifndef SSLCONTEXT_NO_REFCOUNT - { - std::lock_guard g(mutex_); - if (!--count_) { - cleanupOpenSSLLocked(); - } - } -#endif } void SSLContext::ciphers(const std::string& ciphers) { @@ -189,10 +191,35 @@ void SSLContext::loadCertificate(const char* path, const char* format) { } } +void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) { + if (cert.data() == nullptr) { + throw std::invalid_argument("loadCertificate: is nullptr"); + } + + std::unique_ptr bio(BIO_new(BIO_s_mem())); + if (bio == nullptr) { + throw std::runtime_error("BIO_new: " + getErrors()); + } + + int written = BIO_write(bio.get(), cert.data(), cert.size()); + if (written <= 0 || static_cast(written) != cert.size()) { + throw std::runtime_error("BIO_write: " + getErrors()); + } + + X509_UniquePtr x509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); + if (x509 == nullptr) { + throw std::runtime_error("PEM_read_bio_X509: " + getErrors()); + } + + if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) { + throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors()); + } +} + void SSLContext::loadPrivateKey(const char* path, const char* format) { if (path == nullptr || format == nullptr) { throw std::invalid_argument( - "loadPrivateKey: either or is nullptr"); + "loadPrivateKey: either or is nullptr"); } if (strcmp(format, "PEM") == 0) { if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) { @@ -203,10 +230,35 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) { } } +void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) { + if (pkey.data() == nullptr) { + throw std::invalid_argument("loadPrivateKey: is nullptr"); + } + + std::unique_ptr bio(BIO_new(BIO_s_mem())); + if (bio == nullptr) { + throw std::runtime_error("BIO_new: " + getErrors()); + } + + int written = BIO_write(bio.get(), pkey.data(), pkey.size()); + if (written <= 0 || static_cast(written) != pkey.size()) { + throw std::runtime_error("BIO_write: " + getErrors()); + } + + EVP_PKEY_UniquePtr key( + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); + if (key == nullptr) { + throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors()); + } + + if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) { + throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors()); + } +} + void SSLContext::loadTrustedCertificates(const char* path) { if (path == nullptr) { - throw std::invalid_argument( - "loadTrustedCertificates: is nullptr"); + throw std::invalid_argument("loadTrustedCertificates: is nullptr"); } if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) { throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors()); @@ -312,13 +364,43 @@ void SSLContext::switchCiphersIfTLS11( } #endif +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +int SSLContext::alpnSelectCallback(SSL* /* ssl */, + const unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen, + void* data) { + SSLContext* context = (SSLContext*)data; + CHECK(context); + if (context->advertisedNextProtocols_.empty()) { + *out = nullptr; + *outlen = 0; + } else { + auto i = context->pickNextProtocols(); + const auto& item = context->advertisedNextProtocols_[i]; + if (SSL_select_next_proto((unsigned char**)out, + outlen, + item.protocols, + item.length, + in, + inlen) != OPENSSL_NPN_NEGOTIATED) { + return SSL_TLSEXT_ERR_NOACK; + } + } + return SSL_TLSEXT_ERR_OK; +} +#endif + #ifdef OPENSSL_NPN_NEGOTIATED -bool SSLContext::setAdvertisedNextProtocols(const std::list& protocols) { - return setRandomizedAdvertisedNextProtocols({{1, protocols}}); + +bool SSLContext::setAdvertisedNextProtocols( + const std::list& protocols, NextProtocolType protocolType) { + return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType); } bool SSLContext::setRandomizedAdvertisedNextProtocols( - const std::list& items) { + const std::list& items, NextProtocolType protocolType) { unsetNextProtocols(); if (items.size() == 0) { return false; @@ -351,20 +433,30 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( dst += protoLength; } total_weight += item.weight; - advertised_item.probability = item.weight; advertisedNextProtocols_.push_back(advertised_item); + advertisedNextProtocolWeights_.push_back(item.weight); } if (total_weight == 0) { deleteNextProtocolsStrings(); return false; } - for (auto &advertised_item : advertisedNextProtocols_) { - advertised_item.probability /= total_weight; + nextProtocolDistribution_ = + std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(), + advertisedNextProtocolWeights_.end()); + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { + SSL_CTX_set_next_protos_advertised_cb( + ctx_, advertisedNextProtocolCallback, this); + SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this); + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) { + SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this); + // Client cannot really use randomized alpn + SSL_CTX_set_alpn_protos(ctx_, + advertisedNextProtocols_[0].protocols, + advertisedNextProtocols_[0].length); } - SSL_CTX_set_next_protos_advertised_cb( - ctx_, advertisedNextProtocolCallback, this); - SSL_CTX_set_next_proto_select_cb( - ctx_, selectNextProtocolCallback, this); +#endif return true; } @@ -373,12 +465,22 @@ void SSLContext::deleteNextProtocolsStrings() { delete[] protocols.protocols; } advertisedNextProtocols_.clear(); + advertisedNextProtocolWeights_.clear(); } void SSLContext::unsetNextProtocols() { deleteNextProtocolsStrings(); SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr); SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr); +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr); + SSL_CTX_set_alpn_protos(ctx_, nullptr, 0); +#endif +} + +size_t SSLContext::pickNextProtocols() { + CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols"; + return nextProtocolDistribution_(nextProtocolPicker_); } int SSLContext::advertisedNextProtocolCallback(SSL* ssl, @@ -398,31 +500,83 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, *out = context->advertisedNextProtocols_[selected_index].protocols; *outlen = context->advertisedNextProtocols_[selected_index].length; } else { - unsigned char random_byte; - RAND_bytes(&random_byte, 1); - double random_value = random_byte / 255.0; - double sum = 0; - for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) { - sum += context->advertisedNextProtocols_[i].probability; - if (sum < random_value && - i + 1 < context->advertisedNextProtocols_.size()) { - continue; - } - uintptr_t selected = i + 1; - SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected); - *out = context->advertisedNextProtocols_[i].protocols; - *outlen = context->advertisedNextProtocols_[i].length; - break; - } + auto i = context->pickNextProtocols(); + uintptr_t selected = i + 1; + SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected); + *out = context->advertisedNextProtocols_[i].protocols; + *outlen = context->advertisedNextProtocols_[i].length; } } return SSL_TLSEXT_ERR_OK; } -int SSLContext::selectNextProtocolCallback( - SSL* ssl, unsigned char **out, unsigned char *outlen, - const unsigned char *server, unsigned int server_len, void *data) { +#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ + FOLLY_SSLCONTEXT_USE_TLS_FALSE_START +SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() : + ciphers_{ + TLS1_CK_DHE_DSS_WITH_AES_128_SHA, + TLS1_CK_DHE_RSA_WITH_AES_128_SHA, + TLS1_CK_DHE_DSS_WITH_AES_256_SHA, + TLS1_CK_DHE_RSA_WITH_AES_256_SHA, + TLS1_CK_DHE_DSS_WITH_AES_128_SHA256, + TLS1_CK_DHE_RSA_WITH_AES_128_SHA256, + TLS1_CK_DHE_DSS_WITH_AES_256_SHA256, + TLS1_CK_DHE_RSA_WITH_AES_256_SHA256, + TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256, + TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384, + TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256, + TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384, + TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA, + TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256, + TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384, + TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256, + TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384, + TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256, + TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384, + TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + } { + length_ = sizeof(ciphers_)/sizeof(ciphers_[0]); + width_ = sizeof(ciphers_[0]); + qsort(ciphers_, length_, width_, compare_ulong); +} + +bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher( + const SSL_CIPHER *cipher) { + unsigned long cid = cipher->id; + unsigned long *r = + (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong); + return r != nullptr; +} + +int +SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) { + if (*(unsigned long *)x < *(unsigned long *)y) { + return -1; + } + if (*(unsigned long *)x > *(unsigned long *)y) { + return 1; + } + return 0; +}; + +bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) { + return falseStartChecker_.canUseFalseStartWithCipher(cipher); +} +#endif +int SSLContext::selectNextProtocolCallback(SSL* ssl, + unsigned char** out, + unsigned char* outlen, + const unsigned char* server, + unsigned int server_len, + void* data) { + (void)ssl; // Make -Wunused-parameters happy SSLContext* ctx = (SSLContext*)data; if (ctx->advertisedNextProtocols_.size() > 1) { VLOG(3) << "SSLContext::selectNextProcolCallback() " @@ -430,13 +584,21 @@ int SSLContext::selectNextProtocolCallback( } unsigned char *client; - int client_len; - if (ctx->advertisedNextProtocols_.empty()) { - client = (unsigned char *) ""; - client_len = 0; - } else { - client = ctx->advertisedNextProtocols_[0].protocols; - client_len = ctx->advertisedNextProtocols_[0].length; + unsigned int client_len; + bool filtered = false; + auto cpf = ctx->getClientProtocolFilterCallback(); + if (cpf) { + filtered = (*cpf)(&client, &client_len, server, server_len); + } + + if (!filtered) { + if (ctx->advertisedNextProtocols_.empty()) { + client = (unsigned char *) ""; + client_len = 0; + } else { + client = ctx->advertisedNextProtocols_[0].protocols; + client_len = ctx->advertisedNextProtocols_[0].length; + } } int retval = SSL_select_next_proto(out, outlen, server, server_len, @@ -444,6 +606,14 @@ int SSLContext::selectNextProtocolCallback( if (retval != OPENSSL_NPN_NEGOTIATED) { VLOG(3) << "SSLContext::selectNextProcolCallback() " << "unable to pick a next protocol."; +#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ + FOLLY_SSLCONTEXT_USE_TLS_FALSE_START + } else { + const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher; + if (cipher && ctx->canUseFalseStartWithCipher(cipher)) { + SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH); + } +#endif } return SSL_TLSEXT_ERR_OK; } @@ -544,17 +714,13 @@ struct SSLLock { // SSLContext runs in such environments. // Instead of declaring a static member we "new" the static // member so that it won't be destructed on exit(). -static std::map* lockTypesInst = - new std::map(); - -static std::unique_ptr* locksInst = - new std::unique_ptr(); - static std::unique_ptr& locks() { + static auto locksInst = new std::unique_ptr(); return *locksInst; } static std::map& lockTypes() { + static auto lockTypesInst = new std::map(); return *lockTypesInst; } @@ -600,8 +766,13 @@ void SSLContext::setSSLLockTypes(std::map inLockTypes) { lockTypes() = inLockTypes; } +void SSLContext::markInitialized() { + std::lock_guard g(initMutex()); + initialized_ = true; +} + void SSLContext::initializeOpenSSL() { - std::lock_guard g(mutex_); + std::lock_guard g(initMutex()); initializeOpenSSLLocked(); } @@ -632,7 +803,7 @@ void SSLContext::initializeOpenSSLLocked() { } void SSLContext::cleanupOpenSSL() { - std::lock_guard g(mutex_); + std::lock_guard g(initMutex()); cleanupOpenSSLLocked(); } @@ -715,7 +886,7 @@ bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx, bool OpenSSLUtils::validatePeerCertNames(X509* cert, const sockaddr* addr, - socklen_t addrLen) { + socklen_t /* addrLen */) { // Try to extract the names within the SAN extension from the certificate auto altNames = reinterpret_cast(