X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FSSLContext.cpp;h=7ce05f44a7a062c974ffe8a65dd17c1bc1293f40;hp=7ab01c301f8be6d56ed9b35a95a0202330b2773b;hb=3e19d28a142149241d81c5e736aa4117fe7cbec8;hpb=17d04308e64ee7a11ad68f4b4b4c03498c3c8844 diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 7ab01c30..7ce05f44 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2015 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,45 +16,25 @@ #include "SSLContext.h" -#include -#include -#include -#include - #include +#include +#include +#include #include +#include +#include // --------------------------------------------------------------------- // SSLContext implementation // --------------------------------------------------------------------- - -struct CRYPTO_dynlock_value { - std::mutex mutex; -}; - namespace folly { - -bool SSLContext::initialized_ = false; - -namespace { - -std::mutex& initMutex() { - static std::mutex m; - return m; -} - -} // anonymous namespace - -#ifdef OPENSSL_NPN_NEGOTIATED -int SSLContext::sNextProtocolsExDataIndex_ = -1; -#endif +// +// For OpenSSL portability API +using namespace folly::ssl; // SSLContext implementation SSLContext::SSLContext(SSLVersion version) { - { - std::lock_guard g(initMutex()); - initializeOpenSSLLocked(); - } + folly::ssl::init(); ctx_ = SSL_CTX_new(SSLv23_method()); if (ctx_ == nullptr) { @@ -69,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) { case SSLv3: opt = SSL_OP_NO_SSLv2; break; + case TLSv1_2: + opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | + SSL_OP_NO_TLSv1_1; + break; default: // do nothing break; @@ -80,7 +64,9 @@ SSLContext::SSLContext(SSLVersion version) { checkPeerName_ = false; -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION); + +#if FOLLY_OPENSSL_HAS_SNI SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_arg(ctx_, this); #endif @@ -98,18 +84,68 @@ SSLContext::~SSLContext() { } void SSLContext::ciphers(const std::string& ciphers) { - providedCiphersString_ = ciphers; setCiphersOrThrow(ciphers); } +void SSLContext::setClientECCurvesList( + const std::vector& ecCurves) { + if (ecCurves.size() == 0) { + return; + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL + std::string ecCurvesList; + join(":", ecCurves, ecCurvesList); + int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str()); + if (rc == 0) { + throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors()); + } +#endif +} + +void SSLContext::setServerECCurve(const std::string& curveName) { +#if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH) + EC_KEY* ecdh = nullptr; + int nid; + + /* + * Elliptic-Curve Diffie-Hellman parameters are either "named curves" + * from RFC 4492 section 5.1.1, or explicitly described curves over + * binary fields. OpenSSL only supports the "named curves", which provide + * maximum interoperability. + */ + + nid = OBJ_sn2nid(curveName.c_str()); + if (nid == 0) { + LOG(FATAL) << "Unknown curve name:" << curveName.c_str(); + } + ecdh = EC_KEY_new_by_curve_name(nid); + if (ecdh == nullptr) { + LOG(FATAL) << "Unable to create curve:" << curveName.c_str(); + } + + SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + EC_KEY_free(ecdh); +#else + throw std::runtime_error("Elliptic curve encryption not allowed"); +#endif +} + +void SSLContext::setX509VerifyParam( + const ssl::X509VerifyParam& x509VerifyParam) { + if (!x509VerifyParam) { + return; + } + if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) { + throw std::runtime_error("SSL_CTX_set1_param " + getErrors()); + } +} + void SSLContext::setCiphersOrThrow(const std::string& ciphers) { int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str()); - if (ERR_peek_error() != 0) { - throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors()); - } if (rc == 0) { - throw std::runtime_error("None of specified ciphers are supported"); + throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors()); } + providedCiphersString_ = ciphers; } void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum& @@ -152,7 +188,8 @@ void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName, const std::string& peerName) { int mode; if (checkPeerCert) { - mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE; + mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | + SSL_VERIFY_CLIENT_ONCE; checkPeerName_ = checkPeerName; peerFixedName_ = peerName; } else { @@ -178,32 +215,106 @@ void SSLContext::loadCertificate(const char* path, const char* format) { throw std::runtime_error(reason); } } else { - throw std::runtime_error("Unsupported certificate format: " + std::string(format)); + throw std::runtime_error( + "Unsupported certificate format: " + std::string(format)); + } +} + +void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) { + if (cert.data() == nullptr) { + throw std::invalid_argument("loadCertificate: is nullptr"); + } + + ssl::BioUniquePtr bio(BIO_new(BIO_s_mem())); + if (bio == nullptr) { + throw std::runtime_error("BIO_new: " + getErrors()); + } + + int written = BIO_write(bio.get(), cert.data(), int(cert.size())); + if (written <= 0 || static_cast(written) != cert.size()) { + throw std::runtime_error("BIO_write: " + getErrors()); + } + + ssl::X509UniquePtr x509( + PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); + if (x509 == nullptr) { + throw std::runtime_error("PEM_read_bio_X509: " + getErrors()); + } + + if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) { + throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors()); } } void SSLContext::loadPrivateKey(const char* path, const char* format) { if (path == nullptr || format == nullptr) { throw std::invalid_argument( - "loadPrivateKey: either or is nullptr"); + "loadPrivateKey: either or is nullptr"); } if (strcmp(format, "PEM") == 0) { if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) { throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors()); } } else { - throw std::runtime_error("Unsupported private key format: " + std::string(format)); + throw std::runtime_error( + "Unsupported private key format: " + std::string(format)); + } +} + +void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) { + if (pkey.data() == nullptr) { + throw std::invalid_argument("loadPrivateKey: is nullptr"); + } + + ssl::BioUniquePtr bio(BIO_new(BIO_s_mem())); + if (bio == nullptr) { + throw std::runtime_error("BIO_new: " + getErrors()); + } + + int written = BIO_write(bio.get(), pkey.data(), int(pkey.size())); + if (written <= 0 || static_cast(written) != pkey.size()) { + throw std::runtime_error("BIO_write: " + getErrors()); + } + + ssl::EvpPkeyUniquePtr key( + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); + if (key == nullptr) { + throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors()); + } + + if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) { + throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors()); } } +void SSLContext::loadCertKeyPairFromBufferPEM( + folly::StringPiece cert, + folly::StringPiece pkey) { + loadCertificateFromBufferPEM(cert); + loadPrivateKeyFromBufferPEM(pkey); +} + +void SSLContext::loadCertKeyPairFromFiles( + const char* certPath, + const char* keyPath, + const char* certFormat, + const char* keyFormat) { + loadCertificate(certPath, certFormat); + loadPrivateKey(keyPath, keyFormat); +} + +bool SSLContext::isCertKeyPairValid() const { + return SSL_CTX_check_private_key(ctx_) == 1; +} + void SSLContext::loadTrustedCertificates(const char* path) { if (path == nullptr) { - throw std::invalid_argument( - "loadTrustedCertificates: is nullptr"); + throw std::invalid_argument("loadTrustedCertificates: is nullptr"); } if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) { throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors()); } + ERR_clear_error(); } void SSLContext::loadTrustedCertificates(X509_STORE* store) { @@ -219,11 +330,8 @@ void SSLContext::loadClientCAList(const char* path) { SSL_CTX_set_client_CA_list(ctx_, clientCAs); } -void SSLContext::randomize() { - RAND_poll(); -} - -void SSLContext::passwordCollector(std::shared_ptr collector) { +void SSLContext::passwordCollector( + std::shared_ptr collector) { if (collector == nullptr) { LOG(ERROR) << "passwordCollector: ignore invalid password collector"; return; @@ -233,7 +341,7 @@ void SSLContext::passwordCollector(std::shared_ptr collector) SSL_CTX_set_default_passwd_cb_userdata(ctx_, this); } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI void SSLContext::setServerNameCallback(const ServerNameCallback& cb) { serverNameCb_ = cb; @@ -279,34 +387,10 @@ int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) { return SSL_TLSEXT_ERR_NOACK; } +#endif // FOLLY_OPENSSL_HAS_SNI -void SSLContext::switchCiphersIfTLS11( - SSL* ssl, - const std::string& tls11CipherString) { - - CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers"; - - if (TLS1_get_client_version(ssl) <= TLS1_VERSION) { - // We only do this for TLS v 1.1 and later - return; - } - - // Prefer AES for TLS versions 1.1 and later since these are not - // vulnerable to BEAST attacks on AES. Note that we're setting the - // cipher list on the SSL object, not the SSL_CTX object, so it will - // only last for this request. - int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str()); - if ((rc == 0) || ERR_peek_error() != 0) { - // This shouldn't happen since we checked for this when proxygen - // started up. - LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch"; - SSL_set_cipher_list(ssl, providedCiphersString_.c_str()); - } -} -#endif - -#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) -int SSLContext::alpnSelectCallback(SSL* ssl, +#if FOLLY_OPENSSL_HAS_ALPN +int SSLContext::alpnSelectCallback(SSL* /* ssl */, const unsigned char** out, unsigned char* outlen, const unsigned char* in, @@ -331,7 +415,7 @@ int SSLContext::alpnSelectCallback(SSL* ssl, } return SSL_TLSEXT_ERR_OK; } -#endif +#endif // FOLLY_OPENSSL_HAS_ALPN #ifdef OPENSSL_NPN_NEGOTIATED @@ -355,12 +439,12 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( advertised_item.length = 0; for (const auto& proto : item.protocols) { ++advertised_item.length; - unsigned protoLength = proto.length(); + auto protoLength = proto.length(); if (protoLength >= 256) { deleteNextProtocolsStrings(); return false; } - advertised_item.length += protoLength; + advertised_item.length += unsigned(protoLength); } advertised_item.protocols = new unsigned char[advertised_item.length]; if (!advertised_item.protocols) { @@ -368,28 +452,28 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( } unsigned char* dst = advertised_item.protocols; for (auto& proto : item.protocols) { - unsigned protoLength = proto.length(); + uint8_t protoLength = uint8_t(proto.length()); *dst++ = (unsigned char)protoLength; memcpy(dst, proto.data(), protoLength); dst += protoLength; } total_weight += item.weight; - advertised_item.probability = item.weight; advertisedNextProtocols_.push_back(advertised_item); + advertisedNextProtocolWeights_.push_back(item.weight); } if (total_weight == 0) { deleteNextProtocolsStrings(); return false; } - for (auto& advertised_item : advertisedNextProtocols_) { - advertised_item.probability /= total_weight; - } + nextProtocolDistribution_ = + std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(), + advertisedNextProtocolWeights_.end()); if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { SSL_CTX_set_next_protos_advertised_cb( ctx_, advertisedNextProtocolCallback, this); SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this); } -#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_ALPN if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) { SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this); // Client cannot really use randomized alpn @@ -406,35 +490,30 @@ void SSLContext::deleteNextProtocolsStrings() { delete[] protocols.protocols; } advertisedNextProtocols_.clear(); + advertisedNextProtocolWeights_.clear(); } void SSLContext::unsetNextProtocols() { deleteNextProtocolsStrings(); SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr); SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr); -#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_ALPN SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr); SSL_CTX_set_alpn_protos(ctx_, nullptr, 0); #endif } size_t SSLContext::pickNextProtocols() { - unsigned char random_byte; - RAND_bytes(&random_byte, 1); - double random_value = random_byte / 255.0; - double sum = 0; - for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) { - sum += advertisedNextProtocols_[i].probability; - if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) { - continue; - } - return i; - } - CHECK(false) << "Failed to pickNextProtocols"; + CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols"; + auto rng = ThreadLocalPRNG(); + return size_t(nextProtocolDistribution_(rng)); } int SSLContext::advertisedNextProtocolCallback(SSL* ssl, const unsigned char** out, unsigned int* outlen, void* data) { + static int nextProtocolsExDataIndex = SSL_get_ex_new_index( + 0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr); + SSLContext* context = (SSLContext*)data; if (context == nullptr || context->advertisedNextProtocols_.empty()) { *out = nullptr; @@ -443,8 +522,8 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, *out = context->advertisedNextProtocols_[0].protocols; *outlen = context->advertisedNextProtocols_[0].length; } else { - uintptr_t selected_index = reinterpret_cast(SSL_get_ex_data(ssl, - sNextProtocolsExDataIndex_)); + uintptr_t selected_index = reinterpret_cast( + SSL_get_ex_data(ssl, nextProtocolsExDataIndex)); if (selected_index) { --selected_index; *out = context->advertisedNextProtocols_[selected_index].protocols; @@ -452,7 +531,7 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, } else { auto i = context->pickNextProtocols(); uintptr_t selected = i + 1; - SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected); + SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected); *out = context->advertisedNextProtocols_[i].protocols; *outlen = context->advertisedNextProtocols_[i].length; } @@ -460,78 +539,21 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, return SSL_TLSEXT_ERR_OK; } -#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ - FOLLY_SSLCONTEXT_USE_TLS_FALSE_START -SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() : - ciphers_{ - TLS1_CK_DHE_DSS_WITH_AES_128_SHA, - TLS1_CK_DHE_RSA_WITH_AES_128_SHA, - TLS1_CK_DHE_DSS_WITH_AES_256_SHA, - TLS1_CK_DHE_RSA_WITH_AES_256_SHA, - TLS1_CK_DHE_DSS_WITH_AES_128_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_128_SHA256, - TLS1_CK_DHE_DSS_WITH_AES_256_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_256_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256, - TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256, - TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384, - TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256, - TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384, - TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256, - TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384, - TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - } { - length_ = sizeof(ciphers_)/sizeof(ciphers_[0]); - width_ = sizeof(ciphers_[0]); - qsort(ciphers_, length_, width_, compare_ulong); -} - -bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher( - const SSL_CIPHER *cipher) { - unsigned long cid = cipher->id; - unsigned long *r = - (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong); - return r != nullptr; -} - -int -SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) { - if (*(unsigned long *)x < *(unsigned long *)y) { - return -1; - } - if (*(unsigned long *)x > *(unsigned long *)y) { - return 1; - } - return 0; -}; - -bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) { - return falseStartChecker_.canUseFalseStartWithCipher(cipher); -} -#endif - -int SSLContext::selectNextProtocolCallback( - SSL* ssl, unsigned char **out, unsigned char *outlen, - const unsigned char *server, unsigned int server_len, void *data) { - +int SSLContext::selectNextProtocolCallback(SSL* ssl, + unsigned char** out, + unsigned char* outlen, + const unsigned char* server, + unsigned int server_len, + void* data) { + (void)ssl; // Make -Wunused-parameters happy SSLContext* ctx = (SSLContext*)data; if (ctx->advertisedNextProtocols_.size() > 1) { VLOG(3) << "SSLContext::selectNextProcolCallback() " << "client should be deterministic in selecting protocols."; } - unsigned char *client; - unsigned int client_len; + unsigned char* client = nullptr; + unsigned int client_len = 0; bool filtered = false; auto cpf = ctx->getClientProtocolFilterCallback(); if (cpf) { @@ -553,14 +575,6 @@ int SSLContext::selectNextProtocolCallback( if (retval != OPENSSL_NPN_NEGOTIATED) { VLOG(3) << "SSLContext::selectNextProcolCallback() " << "unable to pick a next protocol."; -#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ - FOLLY_SSLCONTEXT_USE_TLS_FALSE_START - } else { - const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher; - if (cipher && ctx->canUseFalseStartWithCipher(cipher)) { - SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH); - } -#endif } return SSL_TLSEXT_ERR_OK; } @@ -574,6 +588,14 @@ SSL* SSLContext::createSSL() const { return ssl; } +void SSLContext::setSessionCacheContext(const std::string& context) { + SSL_CTX_set_session_id_context( + ctx_, + reinterpret_cast(context.data()), + std::min( + static_cast(context.length()), SSL_MAX_SID_CTX_LENGTH)); +} + /** * Match a name with a pattern. The pattern may include wildcard. A single * wildcard "*" can match up to one component in the domain name. @@ -618,158 +640,19 @@ int SSLContext::passwordCallback(char* password, std::string userPassword; // call user defined password collector to get password context->passwordCollector()->getPassword(userPassword, size); - int length = userPassword.size(); - if (length > size) { - length = size; - } - strncpy(password, userPassword.c_str(), length); - return length; + auto const length = std::min(userPassword.size(), size_t(size)); + std::memcpy(password, userPassword.data(), length); + return int(length); } -struct SSLLock { - explicit SSLLock( - SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) : - lockType(inLockType) { - } - - void lock() { - if (lockType == SSLContext::LOCK_MUTEX) { - mutex.lock(); - } else if (lockType == SSLContext::LOCK_SPINLOCK) { - spinLock.lock(); - } - // lockType == LOCK_NONE, no-op - } - - void unlock() { - if (lockType == SSLContext::LOCK_MUTEX) { - mutex.unlock(); - } else if (lockType == SSLContext::LOCK_SPINLOCK) { - spinLock.unlock(); - } - // lockType == LOCK_NONE, no-op - } - - SSLContext::SSLLockType lockType; - folly::SpinLock spinLock{}; - std::mutex mutex; -}; - -// Statics are unsafe in environments that call exit(). -// If one thread calls exit() while another thread is -// references a member of SSLContext, bad things can happen. -// SSLContext runs in such environments. -// Instead of declaring a static member we "new" the static -// member so that it won't be destructed on exit(). -static std::unique_ptr& locks() { - static auto locksInst = new std::unique_ptr(); - return *locksInst; -} - -static std::map& lockTypes() { - static auto lockTypesInst = new std::map(); - return *lockTypesInst; -} - -static void callbackLocking(int mode, int n, const char*, int) { - if (mode & CRYPTO_LOCK) { - locks()[n].lock(); - } else { - locks()[n].unlock(); - } +#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) +void SSLContext::enableFalseStart() { + SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH); } - -static unsigned long callbackThreadID() { - return static_cast( -#ifdef __APPLE__ - pthread_mach_thread_np(pthread_self()) -#else - pthread_self() #endif - ); -} - -static CRYPTO_dynlock_value* dyn_create(const char*, int) { - return new CRYPTO_dynlock_value; -} - -static void dyn_lock(int mode, - struct CRYPTO_dynlock_value* lock, - const char*, int) { - if (lock != nullptr) { - if (mode & CRYPTO_LOCK) { - lock->mutex.lock(); - } else { - lock->mutex.unlock(); - } - } -} - -static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) { - delete lock; -} - -void SSLContext::setSSLLockTypes(std::map inLockTypes) { - lockTypes() = inLockTypes; -} - -void SSLContext::markInitialized() { - std::lock_guard g(initMutex()); - initialized_ = true; -} void SSLContext::initializeOpenSSL() { - std::lock_guard g(initMutex()); - initializeOpenSSLLocked(); -} - -void SSLContext::initializeOpenSSLLocked() { - if (initialized_) { - return; - } - SSL_library_init(); - SSL_load_error_strings(); - ERR_load_crypto_strings(); - // static locking - locks().reset(new SSLLock[::CRYPTO_num_locks()]); - for (auto it: lockTypes()) { - locks()[it.first].lockType = it.second; - } - CRYPTO_set_id_callback(callbackThreadID); - CRYPTO_set_locking_callback(callbackLocking); - // dynamic locking - CRYPTO_set_dynlock_create_callback(dyn_create); - CRYPTO_set_dynlock_lock_callback(dyn_lock); - CRYPTO_set_dynlock_destroy_callback(dyn_destroy); - randomize(); -#ifdef OPENSSL_NPN_NEGOTIATED - sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0, - (void*)"Advertised next protocol index", nullptr, nullptr, nullptr); -#endif - initialized_ = true; -} - -void SSLContext::cleanupOpenSSL() { - std::lock_guard g(initMutex()); - cleanupOpenSSLLocked(); -} - -void SSLContext::cleanupOpenSSLLocked() { - if (!initialized_) { - return; - } - - CRYPTO_set_id_callback(nullptr); - CRYPTO_set_locking_callback(nullptr); - CRYPTO_set_dynlock_create_callback(nullptr); - CRYPTO_set_dynlock_lock_callback(nullptr); - CRYPTO_set_dynlock_destroy_callback(nullptr); - CRYPTO_cleanup_all_ex_data(); - ERR_free_strings(); - EVP_cleanup(); - ERR_remove_state(0); - locks().reset(); - initialized_ = false; + folly::ssl::init(); } void SSLContext::setOptions(long options) { @@ -808,83 +691,4 @@ operator<<(std::ostream& os, const PasswordCollector& collector) { return os; } -bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx, - sockaddr_storage* addrStorage, - socklen_t* addrLen) { - // Grab the ssl idx and then the ssl object so that we can get the peer - // name to compare against the ips in the subjectAltName - auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx(); - auto ssl = - reinterpret_cast(X509_STORE_CTX_get_ex_data(ctx, sslIdx)); - int fd = SSL_get_fd(ssl); - if (fd < 0) { - LOG(ERROR) << "Inexplicably couldn't get fd from SSL"; - return false; - } - - *addrLen = sizeof(*addrStorage); - if (getpeername(fd, reinterpret_cast(addrStorage), addrLen) != 0) { - PLOG(ERROR) << "Unable to get peer name"; - return false; - } - CHECK(*addrLen <= sizeof(*addrStorage)); - return true; -} - -bool OpenSSLUtils::validatePeerCertNames(X509* cert, - const sockaddr* addr, - socklen_t addrLen) { - // Try to extract the names within the SAN extension from the certificate - auto altNames = - reinterpret_cast( - X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); - SCOPE_EXIT { - if (altNames != nullptr) { - sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free); - } - }; - if (altNames == nullptr) { - LOG(WARNING) << "No subjectAltName provided and we only support ip auth"; - return false; - } - - const sockaddr_in* addr4 = nullptr; - const sockaddr_in6* addr6 = nullptr; - if (addr != nullptr) { - if (addr->sa_family == AF_INET) { - addr4 = reinterpret_cast(addr); - } else if (addr->sa_family == AF_INET6) { - addr6 = reinterpret_cast(addr); - } else { - LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family; - } - } - - - for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) { - auto name = sk_GENERAL_NAME_value(altNames, i); - if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) { - // Extra const-ness for paranoia - unsigned char const * const rawIpStr = name->d.iPAddress->data; - int const rawIpLen = name->d.iPAddress->length; - - if (rawIpLen == 4 && addr4 != nullptr) { - if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) { - return true; - } - } else if (rawIpLen == 16 && addr6 != nullptr) { - if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) { - return true; - } - } else if (rawIpLen != 4 && rawIpLen != 16) { - LOG(WARNING) << "Unexpected IP length: " << rawIpLen; - } - } - } - - LOG(WARNING) << "Unable to match client cert against alt name ip"; - return false; -} - - -} // folly +} // namespace folly