#include "SSLContext.h"
-#include <openssl/err.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-#include <openssl/x509v3.h>
-
#include <folly/Format.h>
#include <folly/Memory.h>
#include <folly/Random.h>
+#include <folly/SharedMutex.h>
#include <folly/SpinLock.h>
+#include <folly/ThreadId.h>
+#include <folly/ssl/Init.h>
// ---------------------------------------------------------------------
// SSLContext implementation
// ---------------------------------------------------------------------
-
-struct CRYPTO_dynlock_value {
- std::mutex mutex;
-};
-
namespace folly {
//
// For OpenSSL portability API
using namespace folly::ssl;
-bool SSLContext::initialized_ = false;
-
-namespace {
-
-std::mutex& initMutex() {
- static std::mutex m;
- return m;
-}
-
-} // anonymous namespace
-
-#ifdef OPENSSL_NPN_NEGOTIATED
-int SSLContext::sNextProtocolsExDataIndex_ = -1;
-#endif
-
// SSLContext implementation
SSLContext::SSLContext(SSLVersion version) {
- {
- std::lock_guard<std::mutex> g(initMutex());
- initializeOpenSSLLocked();
- }
+ folly::ssl::init();
ctx_ = SSL_CTX_new(SSLv23_method());
if (ctx_ == nullptr) {
case SSLv3:
opt = SSL_OP_NO_SSLv2;
break;
+ case TLSv1_2:
+ opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
+ SSL_OP_NO_TLSv1_1;
+ break;
default:
// do nothing
break;
}
void SSLContext::ciphers(const std::string& ciphers) {
- providedCiphersString_ = ciphers;
setCiphersOrThrow(ciphers);
}
-void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
- if (ciphers.size() == 0) {
- return;
- }
- std::string opensslCipherList;
- join(":", ciphers, opensslCipherList);
- setCiphersOrThrow(opensslCipherList);
-}
-
-void SSLContext::setSignatureAlgorithms(
- const std::vector<std::string>& sigalgs) {
- if (sigalgs.size() == 0) {
- return;
- }
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
- std::string opensslSigAlgsList;
- join(":", sigalgs, opensslSigAlgsList);
- int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
- if (rc == 0) {
- throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
- }
-#endif
-}
-
void SSLContext::setClientECCurvesList(
const std::vector<std::string>& ecCurves) {
if (ecCurves.size() == 0) {
}
void SSLContext::setServerECCurve(const std::string& curveName) {
- bool validCall = false;
-#if OPENSSL_VERSION_NUMBER >= 0x0090800fL
-#ifndef OPENSSL_NO_ECDH
- validCall = true;
-#endif
-#endif
- if (!validCall) {
- throw std::runtime_error("Elliptic curve encryption not allowed");
- }
-
+#if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
EC_KEY* ecdh = nullptr;
int nid;
SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
EC_KEY_free(ecdh);
+#else
+ throw std::runtime_error("Elliptic curve encryption not allowed");
+#endif
}
void SSLContext::setX509VerifyParam(
if (rc == 0) {
throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
}
+ providedCiphersString_ = ciphers;
}
void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
const std::string& peerName) {
int mode;
if (checkPeerCert) {
- mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
+ mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
+ SSL_VERIFY_CLIENT_ONCE;
checkPeerName_ = checkPeerName;
peerFixedName_ = peerName;
} else {
throw std::runtime_error(reason);
}
} else {
- throw std::runtime_error("Unsupported certificate format: " + std::string(format));
+ throw std::runtime_error(
+ "Unsupported certificate format: " + std::string(format));
}
}
throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
}
} else {
- throw std::runtime_error("Unsupported private key format: " + std::string(format));
+ throw std::runtime_error(
+ "Unsupported private key format: " + std::string(format));
}
}
SSL_CTX_set_client_CA_list(ctx_, clientCAs);
}
-void SSLContext::randomize() {
- RAND_poll();
-}
-
-void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
+void SSLContext::passwordCollector(
+ std::shared_ptr<PasswordCollector> collector) {
if (collector == nullptr) {
LOG(ERROR) << "passwordCollector: ignore invalid password collector";
return;
return SSL_TLSEXT_ERR_NOACK;
}
-
-void SSLContext::switchCiphersIfTLS11(
- SSL* ssl,
- const std::string& tls11CipherString,
- const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
- CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
- << "Shouldn't call if empty ciphers / alt ciphers";
-
- if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
- // We only do this for TLS v 1.1 and later
- return;
- }
-
- const std::string* ciphers = &tls11CipherString;
- if (!tls11AltCipherlist.empty()) {
- if (!cipherListPicker_) {
- std::vector<int> weights;
- std::for_each(
- tls11AltCipherlist.begin(),
- tls11AltCipherlist.end(),
- [&](const std::pair<std::string, int>& e) {
- weights.push_back(e.second);
- });
- cipherListPicker_.reset(
- new std::discrete_distribution<int>(weights.begin(), weights.end()));
- }
- auto rng = ThreadLocalPRNG();
- auto index = (*cipherListPicker_)(rng);
- if ((size_t)index >= tls11AltCipherlist.size()) {
- LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
- << ", but tls11AltCipherlist is of length "
- << tls11AltCipherlist.size();
- } else {
- ciphers = &tls11AltCipherlist[size_t(index)].first;
- }
- }
-
- // Prefer AES for TLS versions 1.1 and later since these are not
- // vulnerable to BEAST attacks on AES. Note that we're setting the
- // cipher list on the SSL object, not the SSL_CTX object, so it will
- // only last for this request.
- int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
- if ((rc == 0) || ERR_peek_error() != 0) {
- // This shouldn't happen since we checked for this when proxygen
- // started up.
- LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
- SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
- }
-}
#endif // FOLLY_OPENSSL_HAS_SNI
#if FOLLY_OPENSSL_HAS_ALPN
int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
const unsigned char** out, unsigned int* outlen, void* data) {
+ static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
+ 0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+
SSLContext* context = (SSLContext*)data;
if (context == nullptr || context->advertisedNextProtocols_.empty()) {
*out = nullptr;
*out = context->advertisedNextProtocols_[0].protocols;
*outlen = context->advertisedNextProtocols_[0].length;
} else {
- uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
- sNextProtocolsExDataIndex_));
+ uintptr_t selected_index = reinterpret_cast<uintptr_t>(
+ SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
if (selected_index) {
--selected_index;
*out = context->advertisedNextProtocols_[selected_index].protocols;
} else {
auto i = context->pickNextProtocols();
uintptr_t selected = i + 1;
- SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+ SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
*out = context->advertisedNextProtocols_[i].protocols;
*outlen = context->advertisedNextProtocols_[i].length;
}
ctx_,
reinterpret_cast<const unsigned char*>(context.data()),
std::min<unsigned int>(
- static_cast<unsigned int>(context.length()),
- SSL_MAX_SSL_SESSION_ID_LENGTH));
+ static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH));
}
/**
std::string userPassword;
// call user defined password collector to get password
context->passwordCollector()->getPassword(userPassword, size);
- auto length = int(userPassword.size());
- if (length > size) {
- length = size;
- }
- strncpy(password, userPassword.c_str(), size_t(length));
- return length;
-}
-
-struct SSLLock {
- explicit SSLLock(
- SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
- lockType(inLockType) {
- }
-
- void lock() {
- if (lockType == SSLContext::LOCK_MUTEX) {
- mutex.lock();
- } else if (lockType == SSLContext::LOCK_SPINLOCK) {
- spinLock.lock();
- }
- // lockType == LOCK_NONE, no-op
- }
-
- void unlock() {
- if (lockType == SSLContext::LOCK_MUTEX) {
- mutex.unlock();
- } else if (lockType == SSLContext::LOCK_SPINLOCK) {
- spinLock.unlock();
- }
- // lockType == LOCK_NONE, no-op
- }
-
- SSLContext::SSLLockType lockType;
- folly::SpinLock spinLock{};
- std::mutex mutex;
-};
-
-// Statics are unsafe in environments that call exit().
-// If one thread calls exit() while another thread is
-// references a member of SSLContext, bad things can happen.
-// SSLContext runs in such environments.
-// Instead of declaring a static member we "new" the static
-// member so that it won't be destructed on exit().
-static std::unique_ptr<SSLLock[]>& locks() {
- static auto locksInst = new std::unique_ptr<SSLLock[]>();
- return *locksInst;
-}
-
-static std::map<int, SSLContext::SSLLockType>& lockTypes() {
- static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
- return *lockTypesInst;
-}
-
-static void callbackLocking(int mode, int n, const char*, int) {
- if (mode & CRYPTO_LOCK) {
- locks()[size_t(n)].lock();
- } else {
- locks()[size_t(n)].unlock();
- }
-}
-
-static unsigned long callbackThreadID() {
- return static_cast<unsigned long>(
-#ifdef __APPLE__
- pthread_mach_thread_np(pthread_self())
-#elif _MSC_VER
- pthread_getw32threadid_np(pthread_self())
-#else
- pthread_self()
-#endif
- );
-}
-
-static CRYPTO_dynlock_value* dyn_create(const char*, int) {
- return new CRYPTO_dynlock_value;
-}
-
-static void dyn_lock(int mode,
- struct CRYPTO_dynlock_value* lock,
- const char*, int) {
- if (lock != nullptr) {
- if (mode & CRYPTO_LOCK) {
- lock->mutex.lock();
- } else {
- lock->mutex.unlock();
- }
- }
-}
-
-static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
- delete lock;
-}
-
-void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
- lockTypes() = inLockTypes;
+ auto const length = std::min(userPassword.size(), size_t(size));
+ std::memcpy(password, userPassword.data(), length);
+ return int(length);
}
#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
}
#endif
-void SSLContext::markInitialized() {
- std::lock_guard<std::mutex> g(initMutex());
- initialized_ = true;
-}
-
void SSLContext::initializeOpenSSL() {
- std::lock_guard<std::mutex> g(initMutex());
- initializeOpenSSLLocked();
-}
-
-void SSLContext::initializeOpenSSLLocked() {
- if (initialized_) {
- return;
- }
- SSL_library_init();
- SSL_load_error_strings();
- ERR_load_crypto_strings();
- // static locking
- locks().reset(new SSLLock[size_t(::CRYPTO_num_locks())]);
- for (auto it: lockTypes()) {
- locks()[size_t(it.first)].lockType = it.second;
- }
- CRYPTO_set_id_callback(callbackThreadID);
- CRYPTO_set_locking_callback(callbackLocking);
- // dynamic locking
- CRYPTO_set_dynlock_create_callback(dyn_create);
- CRYPTO_set_dynlock_lock_callback(dyn_lock);
- CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
- randomize();
-#ifdef OPENSSL_NPN_NEGOTIATED
- sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
- (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
-#endif
- initialized_ = true;
-}
-
-void SSLContext::cleanupOpenSSL() {
- std::lock_guard<std::mutex> g(initMutex());
- cleanupOpenSSLLocked();
-}
-
-void SSLContext::cleanupOpenSSLLocked() {
- if (!initialized_) {
- return;
- }
-
- CRYPTO_set_id_callback(nullptr);
- CRYPTO_set_locking_callback(nullptr);
- CRYPTO_set_dynlock_create_callback(nullptr);
- CRYPTO_set_dynlock_lock_callback(nullptr);
- CRYPTO_set_dynlock_destroy_callback(nullptr);
- CRYPTO_cleanup_all_ex_data();
- ERR_free_strings();
- EVP_cleanup();
- ERR_remove_state(0);
- locks().reset();
- initialized_ = false;
+ folly::ssl::init();
}
void SSLContext::setOptions(long options) {
return os;
}
-} // folly
+} // namespace folly