/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include "SSLContext.h"
-#include <openssl/err.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-#include <openssl/x509v3.h>
-
#include <folly/Format.h>
+#include <folly/Memory.h>
+#include <folly/Random.h>
+#include <folly/SharedMutex.h>
#include <folly/SpinLock.h>
+#include <folly/ThreadId.h>
// ---------------------------------------------------------------------
// SSLContext implementation
};
namespace folly {
+//
+// For OpenSSL portability API
+using namespace folly::ssl;
bool SSLContext::initialized_ = false;
return m;
}
-} // anonymous namespace
+} // anonymous namespace
#ifdef OPENSSL_NPN_NEGOTIATED
int SSLContext::sNextProtocolsExDataIndex_ = -1;
checkPeerName_ = false;
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+ SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
+
+#if FOLLY_OPENSSL_HAS_SNI
SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif
}
void SSLContext::ciphers(const std::string& ciphers) {
- providedCiphersString_ = ciphers;
setCiphersOrThrow(ciphers);
}
+void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
+ if (ciphers.size() == 0) {
+ return;
+ }
+ std::string opensslCipherList;
+ join(":", ciphers, opensslCipherList);
+ setCiphersOrThrow(opensslCipherList);
+}
+
+void SSLContext::setSignatureAlgorithms(
+ const std::vector<std::string>& sigalgs) {
+ if (sigalgs.size() == 0) {
+ return;
+ }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
+ std::string opensslSigAlgsList;
+ join(":", sigalgs, opensslSigAlgsList);
+ int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
+ if (rc == 0) {
+ throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
+ }
+#endif
+}
+
+void SSLContext::setClientECCurvesList(
+ const std::vector<std::string>& ecCurves) {
+ if (ecCurves.size() == 0) {
+ return;
+ }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
+ std::string ecCurvesList;
+ join(":", ecCurves, ecCurvesList);
+ int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
+ if (rc == 0) {
+ throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
+ }
+#endif
+}
+
+void SSLContext::setServerECCurve(const std::string& curveName) {
+#if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
+ EC_KEY* ecdh = nullptr;
+ int nid;
+
+ /*
+ * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
+ * from RFC 4492 section 5.1.1, or explicitly described curves over
+ * binary fields. OpenSSL only supports the "named curves", which provide
+ * maximum interoperability.
+ */
+
+ nid = OBJ_sn2nid(curveName.c_str());
+ if (nid == 0) {
+ LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
+ }
+ ecdh = EC_KEY_new_by_curve_name(nid);
+ if (ecdh == nullptr) {
+ LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
+ }
+
+ SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
+ EC_KEY_free(ecdh);
+#else
+ throw std::runtime_error("Elliptic curve encryption not allowed");
+#endif
+}
+
+void SSLContext::setX509VerifyParam(
+ const ssl::X509VerifyParam& x509VerifyParam) {
+ if (!x509VerifyParam) {
+ return;
+ }
+ if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
+ throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
+ }
+}
+
void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
- if (ERR_peek_error() != 0) {
- throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
- }
if (rc == 0) {
- throw std::runtime_error("None of specified ciphers are supported");
+ throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
}
+ providedCiphersString_ = ciphers;
}
void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
}
}
+void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
+ if (cert.data() == nullptr) {
+ throw std::invalid_argument("loadCertificate: <cert> is nullptr");
+ }
+
+ ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
+ if (bio == nullptr) {
+ throw std::runtime_error("BIO_new: " + getErrors());
+ }
+
+ int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
+ if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
+ throw std::runtime_error("BIO_write: " + getErrors());
+ }
+
+ ssl::X509UniquePtr x509(
+ PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+ if (x509 == nullptr) {
+ throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
+ }
+
+ if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
+ throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
+ }
+}
+
void SSLContext::loadPrivateKey(const char* path, const char* format) {
if (path == nullptr || format == nullptr) {
throw std::invalid_argument(
- "loadPrivateKey: either <path> or <format> is nullptr");
+ "loadPrivateKey: either <path> or <format> is nullptr");
}
if (strcmp(format, "PEM") == 0) {
if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
}
}
+void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
+ if (pkey.data() == nullptr) {
+ throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
+ }
+
+ ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
+ if (bio == nullptr) {
+ throw std::runtime_error("BIO_new: " + getErrors());
+ }
+
+ int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
+ if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
+ throw std::runtime_error("BIO_write: " + getErrors());
+ }
+
+ ssl::EvpPkeyUniquePtr key(
+ PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
+ if (key == nullptr) {
+ throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
+ }
+
+ if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
+ throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
+ }
+}
+
void SSLContext::loadTrustedCertificates(const char* path) {
if (path == nullptr) {
- throw std::invalid_argument(
- "loadTrustedCertificates: <path> is nullptr");
+ throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
}
if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
}
+ ERR_clear_error();
}
void SSLContext::loadTrustedCertificates(X509_STORE* store) {
SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
}
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
serverNameCb_ = cb;
void SSLContext::switchCiphersIfTLS11(
SSL* ssl,
- const std::string& tls11CipherString) {
-
- CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
+ const std::string& tls11CipherString,
+ const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
+ CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
+ << "Shouldn't call if empty ciphers / alt ciphers";
if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
// We only do this for TLS v 1.1 and later
return;
}
+ const std::string* ciphers = &tls11CipherString;
+ if (!tls11AltCipherlist.empty()) {
+ if (!cipherListPicker_) {
+ std::vector<int> weights;
+ std::for_each(
+ tls11AltCipherlist.begin(),
+ tls11AltCipherlist.end(),
+ [&](const std::pair<std::string, int>& e) {
+ weights.push_back(e.second);
+ });
+ cipherListPicker_.reset(
+ new std::discrete_distribution<int>(weights.begin(), weights.end()));
+ }
+ auto rng = ThreadLocalPRNG();
+ auto index = (*cipherListPicker_)(rng);
+ if ((size_t)index >= tls11AltCipherlist.size()) {
+ LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
+ << ", but tls11AltCipherlist is of length "
+ << tls11AltCipherlist.size();
+ } else {
+ ciphers = &tls11AltCipherlist[size_t(index)].first;
+ }
+ }
+
// Prefer AES for TLS versions 1.1 and later since these are not
// vulnerable to BEAST attacks on AES. Note that we're setting the
// cipher list on the SSL object, not the SSL_CTX object, so it will
// only last for this request.
- int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
+ int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
if ((rc == 0) || ERR_peek_error() != 0) {
// This shouldn't happen since we checked for this when proxygen
// started up.
SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
}
}
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
+
+#if FOLLY_OPENSSL_HAS_ALPN
+int SSLContext::alpnSelectCallback(SSL* /* ssl */,
+ const unsigned char** out,
+ unsigned char* outlen,
+ const unsigned char* in,
+ unsigned int inlen,
+ void* data) {
+ SSLContext* context = (SSLContext*)data;
+ CHECK(context);
+ if (context->advertisedNextProtocols_.empty()) {
+ *out = nullptr;
+ *outlen = 0;
+ } else {
+ auto i = context->pickNextProtocols();
+ const auto& item = context->advertisedNextProtocols_[i];
+ if (SSL_select_next_proto((unsigned char**)out,
+ outlen,
+ item.protocols,
+ item.length,
+ in,
+ inlen) != OPENSSL_NPN_NEGOTIATED) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+ }
+ return SSL_TLSEXT_ERR_OK;
+}
+#endif // FOLLY_OPENSSL_HAS_ALPN
#ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
- return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+ const std::list<std::string>& protocols, NextProtocolType protocolType) {
+ return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
}
bool SSLContext::setRandomizedAdvertisedNextProtocols(
- const std::list<NextProtocolsItem>& items) {
+ const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
unsetNextProtocols();
if (items.size() == 0) {
return false;
advertised_item.length = 0;
for (const auto& proto : item.protocols) {
++advertised_item.length;
- unsigned protoLength = proto.length();
+ auto protoLength = proto.length();
if (protoLength >= 256) {
deleteNextProtocolsStrings();
return false;
}
- advertised_item.length += protoLength;
+ advertised_item.length += unsigned(protoLength);
}
advertised_item.protocols = new unsigned char[advertised_item.length];
if (!advertised_item.protocols) {
}
unsigned char* dst = advertised_item.protocols;
for (auto& proto : item.protocols) {
- unsigned protoLength = proto.length();
+ uint8_t protoLength = uint8_t(proto.length());
*dst++ = (unsigned char)protoLength;
memcpy(dst, proto.data(), protoLength);
dst += protoLength;
}
total_weight += item.weight;
- advertised_item.probability = item.weight;
advertisedNextProtocols_.push_back(advertised_item);
+ advertisedNextProtocolWeights_.push_back(item.weight);
}
if (total_weight == 0) {
deleteNextProtocolsStrings();
return false;
}
- for (auto &advertised_item : advertisedNextProtocols_) {
- advertised_item.probability /= total_weight;
+ nextProtocolDistribution_ =
+ std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
+ advertisedNextProtocolWeights_.end());
+ if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+ SSL_CTX_set_next_protos_advertised_cb(
+ ctx_, advertisedNextProtocolCallback, this);
+ SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+ }
+#if FOLLY_OPENSSL_HAS_ALPN
+ if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+ SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+ // Client cannot really use randomized alpn
+ SSL_CTX_set_alpn_protos(ctx_,
+ advertisedNextProtocols_[0].protocols,
+ advertisedNextProtocols_[0].length);
}
- SSL_CTX_set_next_protos_advertised_cb(
- ctx_, advertisedNextProtocolCallback, this);
- SSL_CTX_set_next_proto_select_cb(
- ctx_, selectNextProtocolCallback, this);
+#endif
return true;
}
delete[] protocols.protocols;
}
advertisedNextProtocols_.clear();
+ advertisedNextProtocolWeights_.clear();
}
void SSLContext::unsetNextProtocols() {
deleteNextProtocolsStrings();
SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if FOLLY_OPENSSL_HAS_ALPN
+ SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+ SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+ CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
+ auto rng = ThreadLocalPRNG();
+ return size_t(nextProtocolDistribution_(rng));
}
int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
*out = context->advertisedNextProtocols_[selected_index].protocols;
*outlen = context->advertisedNextProtocols_[selected_index].length;
} else {
- unsigned char random_byte;
- RAND_bytes(&random_byte, 1);
- double random_value = random_byte / 255.0;
- double sum = 0;
- for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
- sum += context->advertisedNextProtocols_[i].probability;
- if (sum < random_value &&
- i + 1 < context->advertisedNextProtocols_.size()) {
- continue;
- }
- uintptr_t selected = i + 1;
- SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
- *out = context->advertisedNextProtocols_[i].protocols;
- *outlen = context->advertisedNextProtocols_[i].length;
- break;
- }
+ auto i = context->pickNextProtocols();
+ uintptr_t selected = i + 1;
+ SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+ *out = context->advertisedNextProtocols_[i].protocols;
+ *outlen = context->advertisedNextProtocols_[i].length;
}
}
return SSL_TLSEXT_ERR_OK;
}
-#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
- FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
-SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
- /**
- * The list was generated as follows:
- * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 |
- * while read A && read B && read C && read D && read E && read F; do
- * echo $A $B $C $D $E; done |
- * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\|
- * SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES |
- * awk -F, '{ print $1"," }'
- */
- ciphers_{
- TLS1_CK_DH_DSS_WITH_AES_128_SHA,
- TLS1_CK_DH_RSA_WITH_AES_128_SHA,
- TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
- TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
- TLS1_CK_DH_DSS_WITH_AES_256_SHA,
- TLS1_CK_DH_RSA_WITH_AES_256_SHA,
- TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
- TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
- TLS1_CK_DH_DSS_WITH_AES_128_SHA256,
- TLS1_CK_DH_RSA_WITH_AES_128_SHA256,
- TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
- TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
- TLS1_CK_DH_DSS_WITH_AES_256_SHA256,
- TLS1_CK_DH_RSA_WITH_AES_256_SHA256,
- TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
- TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
- TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
- TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
- TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256,
- TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384,
- TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
- TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
- TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256,
- TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384,
- TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
- TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
- TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
- TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
- TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA,
- TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA,
- TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
- TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
- TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
- TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
- TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
- TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
- TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
- TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
- TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256,
- TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384,
- TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
- TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
- TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
- TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
- TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
- TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
- TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256,
- } {
- length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
- width_ = sizeof(ciphers_[0]);
- qsort(ciphers_, length_, width_, compare_ulong);
-}
-
-bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
- const SSL_CIPHER *cipher) {
- unsigned long cid = cipher->id;
- unsigned long *r =
- (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
- return r != nullptr;
-}
-
-int
-SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
- if (*(unsigned long *)x < *(unsigned long *)y) {
- return -1;
- }
- if (*(unsigned long *)x > *(unsigned long *)y) {
- return 1;
- }
- return 0;
-};
-
-bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
- return falseStartChecker_.canUseFalseStartWithCipher(cipher);
-}
-#endif
-
-int SSLContext::selectNextProtocolCallback(
- SSL* ssl, unsigned char **out, unsigned char *outlen,
- const unsigned char *server, unsigned int server_len, void *data) {
-
+int SSLContext::selectNextProtocolCallback(SSL* ssl,
+ unsigned char** out,
+ unsigned char* outlen,
+ const unsigned char* server,
+ unsigned int server_len,
+ void* data) {
+ (void)ssl; // Make -Wunused-parameters happy
SSLContext* ctx = (SSLContext*)data;
if (ctx->advertisedNextProtocols_.size() > 1) {
VLOG(3) << "SSLContext::selectNextProcolCallback() "
<< "client should be deterministic in selecting protocols.";
}
- unsigned char *client;
- unsigned int client_len;
+ unsigned char* client = nullptr;
+ unsigned int client_len = 0;
bool filtered = false;
auto cpf = ctx->getClientProtocolFilterCallback();
if (cpf) {
if (retval != OPENSSL_NPN_NEGOTIATED) {
VLOG(3) << "SSLContext::selectNextProcolCallback() "
<< "unable to pick a next protocol.";
-#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
- FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
- } else {
- const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
- if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
- SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
- }
-#endif
}
return SSL_TLSEXT_ERR_OK;
}
return ssl;
}
+void SSLContext::setSessionCacheContext(const std::string& context) {
+ SSL_CTX_set_session_id_context(
+ ctx_,
+ reinterpret_cast<const unsigned char*>(context.data()),
+ std::min<unsigned int>(
+ static_cast<unsigned int>(context.length()),
+ SSL_MAX_SSL_SESSION_ID_LENGTH));
+}
+
/**
* Match a name with a pattern. The pattern may include wildcard. A single
* wildcard "*" can match up to one component in the domain name.
std::string userPassword;
// call user defined password collector to get password
context->passwordCollector()->getPassword(userPassword, size);
- int length = userPassword.size();
+ auto length = int(userPassword.size());
if (length > size) {
length = size;
}
- strncpy(password, userPassword.c_str(), length);
+ strncpy(password, userPassword.c_str(), size_t(length));
return length;
}
lockType(inLockType) {
}
- void lock() {
+ void lock(bool read) {
if (lockType == SSLContext::LOCK_MUTEX) {
mutex.lock();
} else if (lockType == SSLContext::LOCK_SPINLOCK) {
spinLock.lock();
+ } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
+ if (read) {
+ sharedMutex.lock_shared();
+ } else {
+ sharedMutex.lock();
+ }
}
// lockType == LOCK_NONE, no-op
}
- void unlock() {
+ void unlock(bool read) {
if (lockType == SSLContext::LOCK_MUTEX) {
mutex.unlock();
} else if (lockType == SSLContext::LOCK_SPINLOCK) {
spinLock.unlock();
+ } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
+ if (read) {
+ sharedMutex.unlock_shared();
+ } else {
+ sharedMutex.unlock();
+ }
}
// lockType == LOCK_NONE, no-op
}
SSLContext::SSLLockType lockType;
folly::SpinLock spinLock{};
std::mutex mutex;
+ SharedMutex sharedMutex;
};
// Statics are unsafe in environments that call exit().
static void callbackLocking(int mode, int n, const char*, int) {
if (mode & CRYPTO_LOCK) {
- locks()[n].lock();
+ locks()[size_t(n)].lock(mode & CRYPTO_READ);
} else {
- locks()[n].unlock();
+ locks()[size_t(n)].unlock(mode & CRYPTO_READ);
}
}
static unsigned long callbackThreadID() {
- return static_cast<unsigned long>(
-#ifdef __APPLE__
- pthread_mach_thread_np(pthread_self())
-#else
- pthread_self()
-#endif
- );
+ return static_cast<unsigned long>(folly::getCurrentThreadID());
}
static CRYPTO_dynlock_value* dyn_create(const char*, int) {
delete lock;
}
-void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
+void SSLContext::setSSLLockTypesLocked(std::map<int, SSLLockType> inLockTypes) {
lockTypes() = inLockTypes;
}
+void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
+ std::lock_guard<std::mutex> g(initMutex());
+ if (initialized_) {
+ // We set the locks on initialization, so if we are already initialized
+ // this would have no affect.
+ LOG(INFO) << "Ignoring setSSLLockTypes after initialization";
+ return;
+ }
+ setSSLLockTypesLocked(std::move(inLockTypes));
+}
+
+void SSLContext::setSSLLockTypesAndInitOpenSSL(
+ std::map<int, SSLLockType> inLockTypes) {
+ std::lock_guard<std::mutex> g(initMutex());
+ CHECK(!initialized_) << "OpenSSL is already initialized";
+ setSSLLockTypesLocked(std::move(inLockTypes));
+ initializeOpenSSLLocked();
+}
+
+bool SSLContext::isSSLLockDisabled(int lockId) {
+ std::lock_guard<std::mutex> g(initMutex());
+ CHECK(initialized_) << "OpenSSL is not initialized yet";
+ const auto& sslLocks = lockTypes();
+ const auto it = sslLocks.find(lockId);
+ return it != sslLocks.end() &&
+ it->second == SSLContext::SSLLockType::LOCK_NONE;
+}
+
+#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
+void SSLContext::enableFalseStart() {
+ SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
+}
+#endif
+
void SSLContext::markInitialized() {
std::lock_guard<std::mutex> g(initMutex());
initialized_ = true;
SSL_load_error_strings();
ERR_load_crypto_strings();
// static locking
- locks().reset(new SSLLock[::CRYPTO_num_locks()]);
+ locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
for (auto it: lockTypes()) {
- locks()[it.first].lockType = it.second;
+ locks()[size_t(it.first)].lockType = it.second;
}
CRYPTO_set_id_callback(callbackThreadID);
CRYPTO_set_locking_callback(callbackLocking);
CRYPTO_cleanup_all_ex_data();
ERR_free_strings();
EVP_cleanup();
- ERR_remove_state(0);
+ ERR_clear_error();
locks().reset();
initialized_ = false;
}
return os;
}
-bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
- sockaddr_storage* addrStorage,
- socklen_t* addrLen) {
- // Grab the ssl idx and then the ssl object so that we can get the peer
- // name to compare against the ips in the subjectAltName
- auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
- auto ssl =
- reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
- int fd = SSL_get_fd(ssl);
- if (fd < 0) {
- LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
- return false;
- }
-
- *addrLen = sizeof(*addrStorage);
- if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
- PLOG(ERROR) << "Unable to get peer name";
- return false;
- }
- CHECK(*addrLen <= sizeof(*addrStorage));
- return true;
-}
-
-bool OpenSSLUtils::validatePeerCertNames(X509* cert,
- const sockaddr* addr,
- socklen_t addrLen) {
- // Try to extract the names within the SAN extension from the certificate
- auto altNames =
- reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
- X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
- SCOPE_EXIT {
- if (altNames != nullptr) {
- sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
- }
- };
- if (altNames == nullptr) {
- LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
- return false;
- }
-
- const sockaddr_in* addr4 = nullptr;
- const sockaddr_in6* addr6 = nullptr;
- if (addr != nullptr) {
- if (addr->sa_family == AF_INET) {
- addr4 = reinterpret_cast<const sockaddr_in*>(addr);
- } else if (addr->sa_family == AF_INET6) {
- addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
- } else {
- LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
- }
- }
-
-
- for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
- auto name = sk_GENERAL_NAME_value(altNames, i);
- if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
- // Extra const-ness for paranoia
- unsigned char const * const rawIpStr = name->d.iPAddress->data;
- int const rawIpLen = name->d.iPAddress->length;
-
- if (rawIpLen == 4 && addr4 != nullptr) {
- if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
- return true;
- }
- } else if (rawIpLen == 16 && addr6 != nullptr) {
- if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
- return true;
- }
- } else if (rawIpLen != 4 && rawIpLen != 16) {
- LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
- }
- }
- }
-
- LOG(WARNING) << "Unable to match client cert against alt name ip";
- return false;
-}
-
-
} // folly