X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FSSLContext.cpp;h=7ab01c301f8be6d56ed9b35a95a0202330b2773b;hb=17d04308e64ee7a11ad68f4b4b4c03498c3c8844;hp=38dddf971a31d880d59a20e967be6b3093a9d4ef;hpb=483f7edce871905077d99b975cf006db37e7114b;p=folly.git diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 38dddf97..7ab01c30 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -35,7 +35,16 @@ struct CRYPTO_dynlock_value { namespace folly { bool SSLContext::initialized_ = false; -std::mutex SSLContext::mutex_; + +namespace { + +std::mutex& initMutex() { + static std::mutex m; + return m; +} + +} // anonymous namespace + #ifdef OPENSSL_NPN_NEGOTIATED int SSLContext::sNextProtocolsExDataIndex_ = -1; #endif @@ -43,7 +52,7 @@ int SSLContext::sNextProtocolsExDataIndex_ = -1; // SSLContext implementation SSLContext::SSLContext(SSLVersion version) { { - std::lock_guard g(mutex_); + std::lock_guard g(initMutex()); initializeOpenSSLLocked(); } @@ -296,13 +305,43 @@ void SSLContext::switchCiphersIfTLS11( } #endif +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +int SSLContext::alpnSelectCallback(SSL* ssl, + const unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen, + void* data) { + SSLContext* context = (SSLContext*)data; + CHECK(context); + if (context->advertisedNextProtocols_.empty()) { + *out = nullptr; + *outlen = 0; + } else { + auto i = context->pickNextProtocols(); + const auto& item = context->advertisedNextProtocols_[i]; + if (SSL_select_next_proto((unsigned char**)out, + outlen, + item.protocols, + item.length, + in, + inlen) != OPENSSL_NPN_NEGOTIATED) { + return SSL_TLSEXT_ERR_NOACK; + } + } + return SSL_TLSEXT_ERR_OK; +} +#endif + #ifdef OPENSSL_NPN_NEGOTIATED -bool SSLContext::setAdvertisedNextProtocols(const std::list& protocols) { - return setRandomizedAdvertisedNextProtocols({{1, protocols}}); + +bool SSLContext::setAdvertisedNextProtocols( + const std::list& protocols, NextProtocolType protocolType) { + return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType); } bool SSLContext::setRandomizedAdvertisedNextProtocols( - const std::list& items) { + const std::list& items, NextProtocolType protocolType) { unsetNextProtocols(); if (items.size() == 0) { return false; @@ -342,13 +381,23 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( deleteNextProtocolsStrings(); return false; } - for (auto &advertised_item : advertisedNextProtocols_) { + for (auto& advertised_item : advertisedNextProtocols_) { advertised_item.probability /= total_weight; } - SSL_CTX_set_next_protos_advertised_cb( - ctx_, advertisedNextProtocolCallback, this); - SSL_CTX_set_next_proto_select_cb( - ctx_, selectNextProtocolCallback, this); + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { + SSL_CTX_set_next_protos_advertised_cb( + ctx_, advertisedNextProtocolCallback, this); + SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this); + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) { + SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this); + // Client cannot really use randomized alpn + SSL_CTX_set_alpn_protos(ctx_, + advertisedNextProtocols_[0].protocols, + advertisedNextProtocols_[0].length); + } +#endif return true; } @@ -363,6 +412,25 @@ void SSLContext::unsetNextProtocols() { deleteNextProtocolsStrings(); SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr); SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr); +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr); + SSL_CTX_set_alpn_protos(ctx_, nullptr, 0); +#endif +} + +size_t SSLContext::pickNextProtocols() { + unsigned char random_byte; + RAND_bytes(&random_byte, 1); + double random_value = random_byte / 255.0; + double sum = 0; + for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) { + sum += advertisedNextProtocols_[i].probability; + if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) { + continue; + } + return i; + } + CHECK(false) << "Failed to pickNextProtocols"; } int SSLContext::advertisedNextProtocolCallback(SSL* ssl, @@ -382,22 +450,11 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, *out = context->advertisedNextProtocols_[selected_index].protocols; *outlen = context->advertisedNextProtocols_[selected_index].length; } else { - unsigned char random_byte; - RAND_bytes(&random_byte, 1); - double random_value = random_byte / 255.0; - double sum = 0; - for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) { - sum += context->advertisedNextProtocols_[i].probability; - if (sum < random_value && - i + 1 < context->advertisedNextProtocols_.size()) { - continue; - } - uintptr_t selected = i + 1; - SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected); - *out = context->advertisedNextProtocols_[i].protocols; - *outlen = context->advertisedNextProtocols_[i].length; - break; - } + auto i = context->pickNextProtocols(); + uintptr_t selected = i + 1; + SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected); + *out = context->advertisedNextProtocols_[i].protocols; + *outlen = context->advertisedNextProtocols_[i].length; } } return SSL_TLSEXT_ERR_OK; @@ -406,46 +463,21 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ FOLLY_SSLCONTEXT_USE_TLS_FALSE_START SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() : - /** - * The list was generated as follows: - * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 | - * while read A && read B && read C && read D && read E && read F; do - * echo $A $B $C $D $E; done | - * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\| - * SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES | - * awk -F, '{ print $1"," }' - */ ciphers_{ - TLS1_CK_DH_DSS_WITH_AES_128_SHA, - TLS1_CK_DH_RSA_WITH_AES_128_SHA, TLS1_CK_DHE_DSS_WITH_AES_128_SHA, TLS1_CK_DHE_RSA_WITH_AES_128_SHA, - TLS1_CK_DH_DSS_WITH_AES_256_SHA, - TLS1_CK_DH_RSA_WITH_AES_256_SHA, TLS1_CK_DHE_DSS_WITH_AES_256_SHA, TLS1_CK_DHE_RSA_WITH_AES_256_SHA, - TLS1_CK_DH_DSS_WITH_AES_128_SHA256, - TLS1_CK_DH_RSA_WITH_AES_128_SHA256, TLS1_CK_DHE_DSS_WITH_AES_128_SHA256, TLS1_CK_DHE_RSA_WITH_AES_128_SHA256, - TLS1_CK_DH_DSS_WITH_AES_256_SHA256, - TLS1_CK_DH_RSA_WITH_AES_256_SHA256, TLS1_CK_DHE_DSS_WITH_AES_256_SHA256, TLS1_CK_DHE_RSA_WITH_AES_256_SHA256, TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256, TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384, TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256, TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384, - TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256, - TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA, TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA, TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA, TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA, TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256, @@ -454,15 +486,10 @@ SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() : TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384, TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256, TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384, - TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256, - TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384, TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256, } { length_ = sizeof(ciphers_)/sizeof(ciphers_[0]); width_ = sizeof(ciphers_[0]); @@ -504,13 +531,21 @@ int SSLContext::selectNextProtocolCallback( } unsigned char *client; - int client_len; - if (ctx->advertisedNextProtocols_.empty()) { - client = (unsigned char *) ""; - client_len = 0; - } else { - client = ctx->advertisedNextProtocols_[0].protocols; - client_len = ctx->advertisedNextProtocols_[0].length; + unsigned int client_len; + bool filtered = false; + auto cpf = ctx->getClientProtocolFilterCallback(); + if (cpf) { + filtered = (*cpf)(&client, &client_len, server, server_len); + } + + if (!filtered) { + if (ctx->advertisedNextProtocols_.empty()) { + client = (unsigned char *) ""; + client_len = 0; + } else { + client = ctx->advertisedNextProtocols_[0].protocols; + client_len = ctx->advertisedNextProtocols_[0].length; + } } int retval = SSL_select_next_proto(out, outlen, server, server_len, @@ -626,17 +661,13 @@ struct SSLLock { // SSLContext runs in such environments. // Instead of declaring a static member we "new" the static // member so that it won't be destructed on exit(). -static std::map* lockTypesInst = - new std::map(); - -static std::unique_ptr* locksInst = - new std::unique_ptr(); - static std::unique_ptr& locks() { + static auto locksInst = new std::unique_ptr(); return *locksInst; } static std::map& lockTypes() { + static auto lockTypesInst = new std::map(); return *lockTypesInst; } @@ -683,12 +714,12 @@ void SSLContext::setSSLLockTypes(std::map inLockTypes) { } void SSLContext::markInitialized() { - std::lock_guard g(mutex_); + std::lock_guard g(initMutex()); initialized_ = true; } void SSLContext::initializeOpenSSL() { - std::lock_guard g(mutex_); + std::lock_guard g(initMutex()); initializeOpenSSLLocked(); } @@ -719,7 +750,7 @@ void SSLContext::initializeOpenSSLLocked() { } void SSLContext::cleanupOpenSSL() { - std::lock_guard g(mutex_); + std::lock_guard g(initMutex()); cleanupOpenSSLLocked(); }