/*
- * Copyright 2014 Facebook, Inc.
+ * Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <openssl/x509v3.h>
#include <folly/Format.h>
-#include <folly/io/PortableSpinLock.h>
+#include <folly/Memory.h>
+#include <folly/SpinLock.h>
+#include <folly/io/async/OpenSSLPtrTypes.h>
// ---------------------------------------------------------------------
// SSLContext implementation
namespace folly {
bool SSLContext::initialized_ = false;
-std::mutex SSLContext::mutex_;
+
+namespace {
+
+std::mutex& initMutex() {
+ static std::mutex m;
+ return m;
+}
+
+inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
+using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
+
+} // anonymous namespace
+
#ifdef OPENSSL_NPN_NEGOTIATED
int SSLContext::sNextProtocolsExDataIndex_ = -1;
#endif
-#ifndef SSLCONTEXT_NO_REFCOUNT
-uint64_t SSLContext::count_ = 0;
-#endif
-
// SSLContext implementation
SSLContext::SSLContext(SSLVersion version) {
{
- std::lock_guard<std::mutex> g(mutex_);
-#ifndef SSLCONTEXT_NO_REFCOUNT
- count_++;
-#endif
+ std::lock_guard<std::mutex> g(initMutex());
initializeOpenSSLLocked();
}
SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+ Random::seed(nextProtocolPicker_);
+#endif
}
SSLContext::~SSLContext() {
#ifdef OPENSSL_NPN_NEGOTIATED
deleteNextProtocolsStrings();
#endif
-
-#ifndef SSLCONTEXT_NO_REFCOUNT
- {
- std::lock_guard<std::mutex> g(mutex_);
- if (!--count_) {
- cleanupOpenSSLLocked();
- }
- }
-#endif
}
void SSLContext::ciphers(const std::string& ciphers) {
}
}
+void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
+ if (cert.data() == nullptr) {
+ throw std::invalid_argument("loadCertificate: <cert> is nullptr");
+ }
+
+ std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
+ if (bio == nullptr) {
+ throw std::runtime_error("BIO_new: " + getErrors());
+ }
+
+ int written = BIO_write(bio.get(), cert.data(), cert.size());
+ if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
+ throw std::runtime_error("BIO_write: " + getErrors());
+ }
+
+ X509_UniquePtr x509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+ if (x509 == nullptr) {
+ throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
+ }
+
+ if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
+ throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
+ }
+}
+
void SSLContext::loadPrivateKey(const char* path, const char* format) {
if (path == nullptr || format == nullptr) {
throw std::invalid_argument(
- "loadPrivateKey: either <path> or <format> is nullptr");
+ "loadPrivateKey: either <path> or <format> is nullptr");
}
if (strcmp(format, "PEM") == 0) {
if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
}
}
+void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
+ if (pkey.data() == nullptr) {
+ throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
+ }
+
+ std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
+ if (bio == nullptr) {
+ throw std::runtime_error("BIO_new: " + getErrors());
+ }
+
+ int written = BIO_write(bio.get(), pkey.data(), pkey.size());
+ if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
+ throw std::runtime_error("BIO_write: " + getErrors());
+ }
+
+ EVP_PKEY_UniquePtr key(
+ PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
+ if (key == nullptr) {
+ throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
+ }
+
+ if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
+ throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
+ }
+}
+
void SSLContext::loadTrustedCertificates(const char* path) {
if (path == nullptr) {
- throw std::invalid_argument(
- "loadTrustedCertificates: <path> is nullptr");
+ throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
}
if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
}
#endif
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+int SSLContext::alpnSelectCallback(SSL* /* ssl */,
+ const unsigned char** out,
+ unsigned char* outlen,
+ const unsigned char* in,
+ unsigned int inlen,
+ void* data) {
+ SSLContext* context = (SSLContext*)data;
+ CHECK(context);
+ if (context->advertisedNextProtocols_.empty()) {
+ *out = nullptr;
+ *outlen = 0;
+ } else {
+ auto i = context->pickNextProtocols();
+ const auto& item = context->advertisedNextProtocols_[i];
+ if (SSL_select_next_proto((unsigned char**)out,
+ outlen,
+ item.protocols,
+ item.length,
+ in,
+ inlen) != OPENSSL_NPN_NEGOTIATED) {
+ return SSL_TLSEXT_ERR_NOACK;
+ }
+ }
+ return SSL_TLSEXT_ERR_OK;
+}
+#endif
+
#ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
- return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+ const std::list<std::string>& protocols, NextProtocolType protocolType) {
+ return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
}
bool SSLContext::setRandomizedAdvertisedNextProtocols(
- const std::list<NextProtocolsItem>& items) {
+ const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
unsetNextProtocols();
if (items.size() == 0) {
return false;
dst += protoLength;
}
total_weight += item.weight;
- advertised_item.probability = item.weight;
advertisedNextProtocols_.push_back(advertised_item);
+ advertisedNextProtocolWeights_.push_back(item.weight);
}
if (total_weight == 0) {
deleteNextProtocolsStrings();
return false;
}
- for (auto &advertised_item : advertisedNextProtocols_) {
- advertised_item.probability /= total_weight;
+ nextProtocolDistribution_ =
+ std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
+ advertisedNextProtocolWeights_.end());
+ if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+ SSL_CTX_set_next_protos_advertised_cb(
+ ctx_, advertisedNextProtocolCallback, this);
+ SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+ }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+ if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+ SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+ // Client cannot really use randomized alpn
+ SSL_CTX_set_alpn_protos(ctx_,
+ advertisedNextProtocols_[0].protocols,
+ advertisedNextProtocols_[0].length);
}
- SSL_CTX_set_next_protos_advertised_cb(
- ctx_, advertisedNextProtocolCallback, this);
- SSL_CTX_set_next_proto_select_cb(
- ctx_, selectNextProtocolCallback, this);
+#endif
return true;
}
delete[] protocols.protocols;
}
advertisedNextProtocols_.clear();
+ advertisedNextProtocolWeights_.clear();
}
void SSLContext::unsetNextProtocols() {
deleteNextProtocolsStrings();
SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+ SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+ SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+ CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
+ return nextProtocolDistribution_(nextProtocolPicker_);
}
int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
*out = context->advertisedNextProtocols_[selected_index].protocols;
*outlen = context->advertisedNextProtocols_[selected_index].length;
} else {
- unsigned char random_byte;
- RAND_bytes(&random_byte, 1);
- double random_value = random_byte / 255.0;
- double sum = 0;
- for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
- sum += context->advertisedNextProtocols_[i].probability;
- if (sum < random_value &&
- i + 1 < context->advertisedNextProtocols_.size()) {
- continue;
- }
- uintptr_t selected = i + 1;
- SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
- *out = context->advertisedNextProtocols_[i].protocols;
- *outlen = context->advertisedNextProtocols_[i].length;
- break;
- }
+ auto i = context->pickNextProtocols();
+ uintptr_t selected = i + 1;
+ SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+ *out = context->advertisedNextProtocols_[i].protocols;
+ *outlen = context->advertisedNextProtocols_[i].length;
}
}
return SSL_TLSEXT_ERR_OK;
}
-int SSLContext::selectNextProtocolCallback(
- SSL* ssl, unsigned char **out, unsigned char *outlen,
- const unsigned char *server, unsigned int server_len, void *data) {
+#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
+ FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
+SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
+ ciphers_{
+ TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
+ TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
+ TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
+ TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
+ TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
+ TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
+ TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
+ TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
+ TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
+ TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
+ TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+ TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+ TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
+ TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
+ TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
+ TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
+ TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
+ TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
+ TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ } {
+ length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
+ width_ = sizeof(ciphers_[0]);
+ qsort(ciphers_, length_, width_, compare_ulong);
+}
+
+bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
+ const SSL_CIPHER *cipher) {
+ unsigned long cid = cipher->id;
+ unsigned long *r =
+ (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
+ return r != nullptr;
+}
+
+int
+SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
+ if (*(unsigned long *)x < *(unsigned long *)y) {
+ return -1;
+ }
+ if (*(unsigned long *)x > *(unsigned long *)y) {
+ return 1;
+ }
+ return 0;
+};
+
+bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
+ return falseStartChecker_.canUseFalseStartWithCipher(cipher);
+}
+#endif
+int SSLContext::selectNextProtocolCallback(SSL* ssl,
+ unsigned char** out,
+ unsigned char* outlen,
+ const unsigned char* server,
+ unsigned int server_len,
+ void* data) {
+ (void)ssl; // Make -Wunused-parameters happy
SSLContext* ctx = (SSLContext*)data;
if (ctx->advertisedNextProtocols_.size() > 1) {
VLOG(3) << "SSLContext::selectNextProcolCallback() "
}
unsigned char *client;
- int client_len;
- if (ctx->advertisedNextProtocols_.empty()) {
- client = (unsigned char *) "";
- client_len = 0;
- } else {
- client = ctx->advertisedNextProtocols_[0].protocols;
- client_len = ctx->advertisedNextProtocols_[0].length;
+ unsigned int client_len;
+ bool filtered = false;
+ auto cpf = ctx->getClientProtocolFilterCallback();
+ if (cpf) {
+ filtered = (*cpf)(&client, &client_len, server, server_len);
+ }
+
+ if (!filtered) {
+ if (ctx->advertisedNextProtocols_.empty()) {
+ client = (unsigned char *) "";
+ client_len = 0;
+ } else {
+ client = ctx->advertisedNextProtocols_[0].protocols;
+ client_len = ctx->advertisedNextProtocols_[0].length;
+ }
}
int retval = SSL_select_next_proto(out, outlen, server, server_len,
if (retval != OPENSSL_NPN_NEGOTIATED) {
VLOG(3) << "SSLContext::selectNextProcolCallback() "
<< "unable to pick a next protocol.";
+#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
+ FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
+ } else {
+ const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
+ if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
+ SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
+ }
+#endif
}
return SSL_TLSEXT_ERR_OK;
}
}
SSLContext::SSLLockType lockType;
- folly::io::PortableSpinLock spinLock{};
+ folly::SpinLock spinLock{};
std::mutex mutex;
};
// SSLContext runs in such environments.
// Instead of declaring a static member we "new" the static
// member so that it won't be destructed on exit().
-static std::map<int, SSLContext::SSLLockType>* lockTypesInst =
- new std::map<int, SSLContext::SSLLockType>();
-
-static std::unique_ptr<SSLLock[]>* locksInst =
- new std::unique_ptr<SSLLock[]>();
-
static std::unique_ptr<SSLLock[]>& locks() {
+ static auto locksInst = new std::unique_ptr<SSLLock[]>();
return *locksInst;
}
static std::map<int, SSLContext::SSLLockType>& lockTypes() {
+ static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
return *lockTypesInst;
}
lockTypes() = inLockTypes;
}
+void SSLContext::markInitialized() {
+ std::lock_guard<std::mutex> g(initMutex());
+ initialized_ = true;
+}
+
void SSLContext::initializeOpenSSL() {
- std::lock_guard<std::mutex> g(mutex_);
+ std::lock_guard<std::mutex> g(initMutex());
initializeOpenSSLLocked();
}
}
void SSLContext::cleanupOpenSSL() {
- std::lock_guard<std::mutex> g(mutex_);
+ std::lock_guard<std::mutex> g(initMutex());
cleanupOpenSSLLocked();
}
return os;
}
+bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
+ sockaddr_storage* addrStorage,
+ socklen_t* addrLen) {
+ // Grab the ssl idx and then the ssl object so that we can get the peer
+ // name to compare against the ips in the subjectAltName
+ auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
+ auto ssl =
+ reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
+ int fd = SSL_get_fd(ssl);
+ if (fd < 0) {
+ LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
+ return false;
+ }
+
+ *addrLen = sizeof(*addrStorage);
+ if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
+ PLOG(ERROR) << "Unable to get peer name";
+ return false;
+ }
+ CHECK(*addrLen <= sizeof(*addrStorage));
+ return true;
+}
+
+bool OpenSSLUtils::validatePeerCertNames(X509* cert,
+ const sockaddr* addr,
+ socklen_t /* addrLen */) {
+ // Try to extract the names within the SAN extension from the certificate
+ auto altNames =
+ reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
+ X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
+ SCOPE_EXIT {
+ if (altNames != nullptr) {
+ sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
+ }
+ };
+ if (altNames == nullptr) {
+ LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
+ return false;
+ }
+
+ const sockaddr_in* addr4 = nullptr;
+ const sockaddr_in6* addr6 = nullptr;
+ if (addr != nullptr) {
+ if (addr->sa_family == AF_INET) {
+ addr4 = reinterpret_cast<const sockaddr_in*>(addr);
+ } else if (addr->sa_family == AF_INET6) {
+ addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
+ } else {
+ LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
+ }
+ }
+
+
+ for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
+ auto name = sk_GENERAL_NAME_value(altNames, i);
+ if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
+ // Extra const-ness for paranoia
+ unsigned char const * const rawIpStr = name->d.iPAddress->data;
+ int const rawIpLen = name->d.iPAddress->length;
+
+ if (rawIpLen == 4 && addr4 != nullptr) {
+ if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
+ return true;
+ }
+ } else if (rawIpLen == 16 && addr6 != nullptr) {
+ if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
+ return true;
+ }
+ } else if (rawIpLen != 4 && rawIpLen != 16) {
+ LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
+ }
+ }
+ }
+
+ LOG(WARNING) << "Unable to match client cert against alt name ip";
+ return false;
+}
+
+
} // folly