X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FSSLContext.cpp;h=4a95a5723c267e5f715567454416de25675061b8;hp=3abb5d0b8a307439a2ec09d3a7c19ac32f014c06;hb=3764b633f977129f8ee3bca60db7c5d1bb969eec;hpb=2d466553db47cf7c3d7d98837466d3b3d6874bae diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 3abb5d0b..4a95a572 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -19,43 +19,22 @@ #include #include #include +#include #include -#include +#include +#include // --------------------------------------------------------------------- // SSLContext implementation // --------------------------------------------------------------------- - -struct CRYPTO_dynlock_value { - std::mutex mutex; -}; - namespace folly { // // For OpenSSL portability API using namespace folly::ssl; -bool SSLContext::initialized_ = false; - -namespace { - -std::mutex& initMutex() { - static std::mutex m; - return m; -} - -} // anonymous namespace - -#ifdef OPENSSL_NPN_NEGOTIATED -int SSLContext::sNextProtocolsExDataIndex_ = -1; -#endif - // SSLContext implementation SSLContext::SSLContext(SSLVersion version) { - { - std::lock_guard g(initMutex()); - initializeOpenSSLLocked(); - } + folly::ssl::init(); ctx_ = SSL_CTX_new(SSLv23_method()); if (ctx_ == nullptr) { @@ -70,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) { case SSLv3: opt = SSL_OP_NO_SSLv2; break; + case TLSv1_2: + opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 | + SSL_OP_NO_TLSv1_1; + break; default: // do nothing break; @@ -101,34 +84,9 @@ SSLContext::~SSLContext() { } void SSLContext::ciphers(const std::string& ciphers) { - providedCiphersString_ = ciphers; setCiphersOrThrow(ciphers); } -void SSLContext::setCipherList(const std::vector& ciphers) { - if (ciphers.size() == 0) { - return; - } - std::string opensslCipherList; - join(":", ciphers, opensslCipherList); - setCiphersOrThrow(opensslCipherList); -} - -void SSLContext::setSignatureAlgorithms( - const std::vector& sigalgs) { - if (sigalgs.size() == 0) { - return; - } -#if OPENSSL_VERSION_NUMBER >= 0x1000200fL - std::string opensslSigAlgsList; - join(":", sigalgs, opensslSigAlgsList); - int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str()); - if (rc == 0) { - throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors()); - } -#endif -} - void SSLContext::setClientECCurvesList( const std::vector& ecCurves) { if (ecCurves.size() == 0) { @@ -187,6 +145,7 @@ void SSLContext::setCiphersOrThrow(const std::string& ciphers) { if (rc == 0) { throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors()); } + providedCiphersString_ = ciphers; } void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum& @@ -229,7 +188,8 @@ void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName, const std::string& peerName) { int mode; if (checkPeerCert) { - mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE; + mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | + SSL_VERIFY_CLIENT_ONCE; checkPeerName_ = checkPeerName; peerFixedName_ = peerName; } else { @@ -246,7 +206,7 @@ void SSLContext::loadCertificate(const char* path, const char* format) { "loadCertificateChain: either or is nullptr"); } if (strcmp(format, "PEM") == 0) { - if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) { + if (SSL_CTX_use_certificate_chain_file(ctx_, path) != 1) { int errnoCopy = errno; std::string reason("SSL_CTX_use_certificate_chain_file: "); reason.append(path); @@ -255,7 +215,8 @@ void SSLContext::loadCertificate(const char* path, const char* format) { throw std::runtime_error(reason); } } else { - throw std::runtime_error("Unsupported certificate format: " + std::string(format)); + throw std::runtime_error( + "Unsupported certificate format: " + std::string(format)); } } @@ -295,7 +256,8 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) { throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors()); } } else { - throw std::runtime_error("Unsupported private key format: " + std::string(format)); + throw std::runtime_error( + "Unsupported private key format: " + std::string(format)); } } @@ -325,6 +287,32 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) { } } +void SSLContext::loadCertKeyPairFromBufferPEM( + folly::StringPiece cert, + folly::StringPiece pkey) { + loadCertificateFromBufferPEM(cert); + loadPrivateKeyFromBufferPEM(pkey); + if (!isCertKeyPairValid()) { + throw std::runtime_error("SSL certificate and private key do not match"); + } +} + +void SSLContext::loadCertKeyPairFromFiles( + const char* certPath, + const char* keyPath, + const char* certFormat, + const char* keyFormat) { + loadCertificate(certPath, certFormat); + loadPrivateKey(keyPath, keyFormat); + if (!isCertKeyPairValid()) { + throw std::runtime_error("SSL certificate and private key do not match"); + } +} + +bool SSLContext::isCertKeyPairValid() const { + return SSL_CTX_check_private_key(ctx_) == 1; +} + void SSLContext::loadTrustedCertificates(const char* path) { if (path == nullptr) { throw std::invalid_argument("loadTrustedCertificates: is nullptr"); @@ -348,11 +336,8 @@ void SSLContext::loadClientCAList(const char* path) { SSL_CTX_set_client_CA_list(ctx_, clientCAs); } -void SSLContext::randomize() { - RAND_poll(); -} - -void SSLContext::passwordCollector(std::shared_ptr collector) { +void SSLContext::passwordCollector( + std::shared_ptr collector) { if (collector == nullptr) { LOG(ERROR) << "passwordCollector: ignore invalid password collector"; return; @@ -408,55 +393,6 @@ int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) { return SSL_TLSEXT_ERR_NOACK; } - -void SSLContext::switchCiphersIfTLS11( - SSL* ssl, - const std::string& tls11CipherString, - const std::vector>& tls11AltCipherlist) { - CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty())) - << "Shouldn't call if empty ciphers / alt ciphers"; - - if (TLS1_get_client_version(ssl) <= TLS1_VERSION) { - // We only do this for TLS v 1.1 and later - return; - } - - const std::string* ciphers = &tls11CipherString; - if (!tls11AltCipherlist.empty()) { - if (!cipherListPicker_) { - std::vector weights; - std::for_each( - tls11AltCipherlist.begin(), - tls11AltCipherlist.end(), - [&](const std::pair& e) { - weights.push_back(e.second); - }); - cipherListPicker_.reset( - new std::discrete_distribution(weights.begin(), weights.end())); - } - auto rng = ThreadLocalPRNG(); - auto index = (*cipherListPicker_)(rng); - if ((size_t)index >= tls11AltCipherlist.size()) { - LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index - << ", but tls11AltCipherlist is of length " - << tls11AltCipherlist.size(); - } else { - ciphers = &tls11AltCipherlist[size_t(index)].first; - } - } - - // Prefer AES for TLS versions 1.1 and later since these are not - // vulnerable to BEAST attacks on AES. Note that we're setting the - // cipher list on the SSL object, not the SSL_CTX object, so it will - // only last for this request. - int rc = SSL_set_cipher_list(ssl, ciphers->c_str()); - if ((rc == 0) || ERR_peek_error() != 0) { - // This shouldn't happen since we checked for this when proxygen - // started up. - LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch"; - SSL_set_cipher_list(ssl, providedCiphersString_.c_str()); - } -} #endif // FOLLY_OPENSSL_HAS_SNI #if FOLLY_OPENSSL_HAS_ALPN @@ -581,6 +517,9 @@ size_t SSLContext::pickNextProtocols() { int SSLContext::advertisedNextProtocolCallback(SSL* ssl, const unsigned char** out, unsigned int* outlen, void* data) { + static int nextProtocolsExDataIndex = SSL_get_ex_new_index( + 0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr); + SSLContext* context = (SSLContext*)data; if (context == nullptr || context->advertisedNextProtocols_.empty()) { *out = nullptr; @@ -589,8 +528,8 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, *out = context->advertisedNextProtocols_[0].protocols; *outlen = context->advertisedNextProtocols_[0].length; } else { - uintptr_t selected_index = reinterpret_cast(SSL_get_ex_data(ssl, - sNextProtocolsExDataIndex_)); + uintptr_t selected_index = reinterpret_cast( + SSL_get_ex_data(ssl, nextProtocolsExDataIndex)); if (selected_index) { --selected_index; *out = context->advertisedNextProtocols_[selected_index].protocols; @@ -598,7 +537,7 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, } else { auto i = context->pickNextProtocols(); uintptr_t selected = i + 1; - SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected); + SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected); *out = context->advertisedNextProtocols_[i].protocols; *outlen = context->advertisedNextProtocols_[i].length; } @@ -660,8 +599,7 @@ void SSLContext::setSessionCacheContext(const std::string& context) { ctx_, reinterpret_cast(context.data()), std::min( - static_cast(context.length()), - SSL_MAX_SSL_SESSION_ID_LENGTH)); + static_cast(context.length()), SSL_MAX_SID_CTX_LENGTH)); } /** @@ -708,107 +646,9 @@ int SSLContext::passwordCallback(char* password, std::string userPassword; // call user defined password collector to get password context->passwordCollector()->getPassword(userPassword, size); - auto length = int(userPassword.size()); - if (length > size) { - length = size; - } - strncpy(password, userPassword.c_str(), size_t(length)); - return length; -} - -struct SSLLock { - explicit SSLLock( - SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) : - lockType(inLockType) { - } - - void lock() { - if (lockType == SSLContext::LOCK_MUTEX) { - mutex.lock(); - } else if (lockType == SSLContext::LOCK_SPINLOCK) { - spinLock.lock(); - } - // lockType == LOCK_NONE, no-op - } - - void unlock() { - if (lockType == SSLContext::LOCK_MUTEX) { - mutex.unlock(); - } else if (lockType == SSLContext::LOCK_SPINLOCK) { - spinLock.unlock(); - } - // lockType == LOCK_NONE, no-op - } - - SSLContext::SSLLockType lockType; - folly::SpinLock spinLock{}; - std::mutex mutex; -}; - -// Statics are unsafe in environments that call exit(). -// If one thread calls exit() while another thread is -// references a member of SSLContext, bad things can happen. -// SSLContext runs in such environments. -// Instead of declaring a static member we "new" the static -// member so that it won't be destructed on exit(). -static std::unique_ptr& locks() { - static auto locksInst = new std::unique_ptr(); - return *locksInst; -} - -static std::map& lockTypes() { - static auto lockTypesInst = new std::map(); - return *lockTypesInst; -} - -static void callbackLocking(int mode, int n, const char*, int) { - if (mode & CRYPTO_LOCK) { - locks()[size_t(n)].lock(); - } else { - locks()[size_t(n)].unlock(); - } -} - -static unsigned long callbackThreadID() { - return static_cast(folly::getCurrentThreadID()); -} - -static CRYPTO_dynlock_value* dyn_create(const char*, int) { - return new CRYPTO_dynlock_value; -} - -static void dyn_lock(int mode, - struct CRYPTO_dynlock_value* lock, - const char*, int) { - if (lock != nullptr) { - if (mode & CRYPTO_LOCK) { - lock->mutex.lock(); - } else { - lock->mutex.unlock(); - } - } -} - -static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) { - delete lock; -} - -void SSLContext::setSSLLockTypes(std::map inLockTypes) { - if (initialized_) { - // We set the locks on initialization, so if we are already initialized - // this would have no affect. - LOG(INFO) << "Ignoring setSSLLockTypes after initialization"; - return; - } - - lockTypes() = inLockTypes; -} - -bool SSLContext::isSSLLockDisabled(int lockId) { - const auto& sslLocks = lockTypes(); - const auto it = sslLocks.find(lockId); - return it != sslLocks.end() && - it->second == SSLContext::SSLLockType::LOCK_NONE; + auto const length = std::min(userPassword.size(), size_t(size)); + std::memcpy(password, userPassword.data(), length); + return int(length); } #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) @@ -817,63 +657,8 @@ void SSLContext::enableFalseStart() { } #endif -void SSLContext::markInitialized() { - std::lock_guard g(initMutex()); - initialized_ = true; -} - void SSLContext::initializeOpenSSL() { - std::lock_guard g(initMutex()); - initializeOpenSSLLocked(); -} - -void SSLContext::initializeOpenSSLLocked() { - if (initialized_) { - return; - } - SSL_library_init(); - SSL_load_error_strings(); - ERR_load_crypto_strings(); - // static locking - locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]); - for (auto it: lockTypes()) { - locks()[size_t(it.first)].lockType = it.second; - } - CRYPTO_set_id_callback(callbackThreadID); - CRYPTO_set_locking_callback(callbackLocking); - // dynamic locking - CRYPTO_set_dynlock_create_callback(dyn_create); - CRYPTO_set_dynlock_lock_callback(dyn_lock); - CRYPTO_set_dynlock_destroy_callback(dyn_destroy); - randomize(); -#ifdef OPENSSL_NPN_NEGOTIATED - sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0, - (void*)"Advertised next protocol index", nullptr, nullptr, nullptr); -#endif - initialized_ = true; -} - -void SSLContext::cleanupOpenSSL() { - std::lock_guard g(initMutex()); - cleanupOpenSSLLocked(); -} - -void SSLContext::cleanupOpenSSLLocked() { - if (!initialized_) { - return; - } - - CRYPTO_set_id_callback(nullptr); - CRYPTO_set_locking_callback(nullptr); - CRYPTO_set_dynlock_create_callback(nullptr); - CRYPTO_set_dynlock_lock_callback(nullptr); - CRYPTO_set_dynlock_destroy_callback(nullptr); - CRYPTO_cleanup_all_ex_data(); - ERR_free_strings(); - EVP_cleanup(); - ERR_clear_error(); - locks().reset(); - initialized_ = false; + folly::ssl::init(); } void SSLContext::setOptions(long options) { @@ -912,4 +697,4 @@ operator<<(std::ostream& os, const PasswordCollector& collector) { return os; } -} // folly +} // namespace folly