2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SSLContext.h"
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ssl/Init.h>
25 #include <folly/system/ThreadId.h>
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
32 // For OpenSSL portability API
33 using namespace folly::ssl;
35 // SSLContext implementation
// Creates the underlying SSL_CTX with SSLv23_method() (OpenSSL's
// version-flexible method), then narrows the accepted protocol versions by
// setting SSL_OP_NO_* options according to `version`.
// Throws std::runtime_error carrying the OpenSSL error text if SSL_CTX_new
// fails.
36 SSLContext::SSLContext(SSLVersion version) {
39 ctx_ = SSL_CTX_new(SSLv23_method());
40 if (ctx_ == nullptr) {
41 throw std::runtime_error("SSL_CTX_new: " + getErrors());
// `opt` is the mask of protocol versions to disable. Each assignment below is
// presumably one case of a dispatch on `version` (TODO confirm).
47 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
50 opt = SSL_OP_NO_SSLv2;
53 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
60 int newOpt = SSL_CTX_set_options(ctx_, opt);
// SSL_CTX_set_options returns the resulting option mask. Debug builds assert
// that every requested bit was set; setOptions() does the same check but
// throws instead.
61 DCHECK((newOpt & opt) == opt);
// Per the OpenSSL docs, read/write calls return only after any handshake
// work completes instead of asking the caller to retry (relevant for
// blocking transports).
63 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
65 checkPeerName_ = false;
// Disable TLS-level compression (commonly disabled because of CRIME-style
// attacks).
67 SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
69 #if FOLLY_OPENSSL_HAS_SNI
// Route SNI handling through baseServerNameOpenSSLCallback, which receives
// `this` as its `data` argument.
70 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
71 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
// Releases the SSL_CTX when one was created. When NPN support is compiled in,
// also frees the wire-format protocol buffers owned by
// advertisedNextProtocols_.
// NOTE(review): the ctx_ release itself is presumably SSL_CTX_free — confirm.
75 SSLContext::~SSLContext() {
76 if (ctx_ != nullptr) {
81 #ifdef OPENSSL_NPN_NEGOTIATED
82 deleteNextProtocolsStrings();
// Sets the OpenSSL cipher list for this context. This is a thin alias for
// setCiphersOrThrow(), so it throws std::runtime_error if OpenSSL rejects
// the list.
86 void SSLContext::ciphers(const std::string& ciphers) {
87 setCiphersOrThrow(ciphers);
// Sets the elliptic curves supported by this context, given as OpenSSL curve
// names (e.g. {"prime256v1", "secp384r1"}). The names are joined with ':'
// into the string form expected by SSL_CTX_set1_curves_list.
// Throws std::runtime_error if OpenSSL rejects the list.
// An empty vector is checked first (presumably an early return that leaves
// the current configuration untouched — TODO confirm). The OpenSSL call is
// only compiled for OpenSSL >= 1.0.2 (0x1000200f).
90 void SSLContext::setClientECCurvesList(
91 const std::vector<std::string>& ecCurves) {
92 if (ecCurves.size() == 0) {
95 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
96 std::string ecCurvesList;
97 join(":", ecCurves, ecCurvesList);
98 int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
// NOTE(review): unlike the other error messages in this file, this prefix
// has no ": " separator before the OpenSSL error text.
100 throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
// Selects the named curve used for server-side ephemeral ECDH key exchange.
// The curve is looked up by its OpenSSL short name (e.g. "prime256v1") and
// installed with SSL_CTX_set_tmp_ecdh.
// An unknown curve name, or a failure to build the EC_KEY, is LOG(FATAL):
// it terminates the process rather than throwing.
// The trailing std::runtime_error is presumably the branch compiled when
// OpenSSL < 0.9.8 or OPENSSL_NO_ECDH is defined (TODO confirm).
// NOTE(review): LOG(FATAL) on a caller-supplied name is harsh; throwing
// std::invalid_argument would let callers recover.
105 void SSLContext::setServerECCurve(const std::string& curveName) {
106 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
107 EC_KEY* ecdh = nullptr;
111 * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
112 * from RFC 4492 section 5.1.1, or explicitly described curves over
113 * binary fields. OpenSSL only supports the "named curves", which provide
114 * maximum interoperability.
117 nid = OBJ_sn2nid(curveName.c_str());
119 LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
121 ecdh = EC_KEY_new_by_curve_name(nid);
122 if (ecdh == nullptr) {
123 LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
// SSL_CTX_set_tmp_ecdh does not take ownership of `ecdh`, so this function
// must still free it.
// NOTE(review): verify that an EC_KEY_free(ecdh) follows; otherwise every
// call leaks the key.
126 SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
129 throw std::runtime_error("Elliptic curve encryption not allowed");
// Copies the verification parameters in `x509VerifyParam` into this context
// with SSL_CTX_set1_param. The "set1" form copies, so the caller keeps
// ownership of its object.
// A null param is checked first (presumably a no-op — TODO confirm).
// Throws std::runtime_error if OpenSSL rejects the parameters.
133 void SSLContext::setX509VerifyParam(
134 const ssl::X509VerifyParam& x509VerifyParam) {
135 if (!x509VerifyParam) {
138 if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
139 throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
// Applies an OpenSSL cipher-list string (e.g. "HIGH:!aNULL") to this context.
// Throws std::runtime_error with the OpenSSL error text when
// SSL_CTX_set_cipher_list fails (it returns 0 when no listed cipher could be
// selected).
// The string is remembered in providedCiphersString_ only after the call has
// succeeded.
143 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
144 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
146 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
148 providedCiphersString_ = ciphers;
// Stores the default peer-verification policy for this context.
// USE_CTX is rejected with CHECK (process abort): it means "defer to the
// context", which here would refer back to this very setting.
151 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
153 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // don't recurse
154 verifyPeer_ = verifyPeer;
// Translates an SSLVerifyPeerEnum into OpenSSL SSL_VERIFY_* flags.
// Intended mapping:
//   VERIFY                 -> SSL_VERIFY_PEER
//   VERIFY_REQ_CLIENT_CERT -> SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
//   NO_VERIFY              -> SSL_VERIFY_NONE (also the initial value)
// USE_CTX is not a concrete policy and is rejected with CHECK.
157 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
159 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
160 int mode = SSL_VERIFY_NONE;
162 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
165 case SSLVerifyPeerEnum::VERIFY:
166 mode = SSL_VERIFY_PEER;
169 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
170 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
173 case SSLVerifyPeerEnum::NO_VERIFY:
174 mode = SSL_VERIFY_NONE;
// Returns the SSL_VERIFY_* flags for this context's stored policy
// (verifyPeer_, set by setVerificationOption).
// CHECK-fails if verifyPeer_ is USE_CTX.
183 int SSLContext::getVerificationMode() {
184 return getVerificationMode(verifyPeer_);
// Configures certificate verification directly on the SSL_CTX. No verify
// callback is installed (nullptr), so OpenSSL's built-in chain verification
// result decides.
// - Peer-certificate branch (presumably taken when checkPeerCert is true):
//   require and verify a peer certificate. SSL_VERIFY_CLIENT_ONCE means a
//   server does not re-request a client certificate on renegotiation. The
//   requested name check against `peerName` is recorded.
// - Otherwise: disable verification and force name checking off, since a
//   name can't be checked without a certificate.
// NOTE(review): the visible code does not update verifyPeer_, so
// getVerificationMode() may not reflect the mode installed here.
187 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
188 const std::string& peerName) {
191 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
192 SSL_VERIFY_CLIENT_ONCE;
193 checkPeerName_ = checkPeerName;
194 peerFixedName_ = peerName;
196 mode = SSL_VERIFY_NONE;
197 checkPeerName_ = false; // can't check name without cert!
198 peerFixedName_.clear();
200 SSL_CTX_set_verify(ctx_, mode, nullptr);
// Loads a certificate chain from the file at `path` into this context. Per
// OpenSSL, the file holds the leaf certificate first, then any intermediates.
// Only the "PEM" format is supported; any other format throws
// std::runtime_error. Null arguments throw std::invalid_argument.
// NOTE(review): the invalid_argument message names "loadCertificateChain",
// not this function.
203 void SSLContext::loadCertificate(const char* path, const char* format) {
204 if (path == nullptr || format == nullptr) {
205 throw std::invalid_argument(
206 "loadCertificateChain: either <path> or <format> is nullptr");
208 if (strcmp(format, "PEM") == 0) {
209 if (SSL_CTX_use_certificate_chain_file(ctx_, path) != 1) {
// Capture errno right after the failing call, before anything can overwrite
// it. getErrors() reports it when the OpenSSL error queue is empty.
210 int errnoCopy = errno;
211 std::string reason("SSL_CTX_use_certificate_chain_file: ");
214 reason.append(getErrors(errnoCopy));
215 throw std::runtime_error(reason);
218 throw std::runtime_error(
219 "Unsupported certificate format: " + std::string(format));
// Parses one PEM certificate from the in-memory buffer `cert` and installs it
// as this context's certificate.
// Only the first certificate in the buffer is read (PEM_read_bio_X509 is
// called once), so any intermediates that follow it are not added.
// Throws std::invalid_argument when cert.data() is null, and
// std::runtime_error on any OpenSSL failure. A zero-length buffer presumably
// fails the BIO_write check.
// NOTE(review): cert.size() is narrowed to int for BIO_write; buffers larger
// than INT_MAX are not handled meaningfully.
223 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
224 if (cert.data() == nullptr) {
225 throw std::invalid_argument("loadCertificate: <cert> is nullptr");
228 ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
229 if (bio == nullptr) {
230 throw std::runtime_error("BIO_new: " + getErrors());
// The whole buffer must be copied into the memory BIO; a partial write is an
// error.
233 int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
234 if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
235 throw std::runtime_error("BIO_write: " + getErrors());
238 ssl::X509UniquePtr x509(
239 PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
240 if (x509 == nullptr) {
241 throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
244 if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
245 throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
// Loads a private key from the file at `path` into this context.
// Only "PEM" is supported; other formats throw std::runtime_error, and null
// arguments throw std::invalid_argument.
// An encrypted key is decrypted through the context's default password
// callback (see passwordCollector()).
// NOTE(review): unlike loadCertificate, errno is not captured for
// getErrors(), and the success test is `== 0` rather than `!= 1`.
249 void SSLContext::loadPrivateKey(const char* path, const char* format) {
250 if (path == nullptr || format == nullptr) {
251 throw std::invalid_argument(
252 "loadPrivateKey: either <path> or <format> is nullptr");
254 if (strcmp(format, "PEM") == 0) {
255 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
256 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
259 throw std::runtime_error(
260 "Unsupported private key format: " + std::string(format));
// Parses one PEM private key from the in-memory buffer `pkey` and installs it
// in this context.
// Throws std::invalid_argument when pkey.data() is null, and
// std::runtime_error on any OpenSSL failure.
// NOTE(review): PEM_read_bio_PrivateKey gets a null callback and null
// userdata. So, unlike loadPrivateKey(), the password collector installed via
// passwordCollector() is not consulted for encrypted keys; OpenSSL falls back
// to its default PEM passphrase callback. Confirm this is intended.
// NOTE(review): pkey.size() is narrowed to int for BIO_write.
264 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
265 if (pkey.data() == nullptr) {
266 throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
269 ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
270 if (bio == nullptr) {
271 throw std::runtime_error("BIO_new: " + getErrors());
// The whole buffer must be copied into the memory BIO; a partial write is an
// error.
274 int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
275 if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
276 throw std::runtime_error("BIO_write: " + getErrors());
279 ssl::EvpPkeyUniquePtr key(
280 PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
281 if (key == nullptr) {
282 throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
285 if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
286 throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
// Loads a certificate and a private key from PEM buffers, then verifies that
// they belong together.
// Throws std::runtime_error if they don't match; exceptions from the
// individual loaders propagate unchanged.
// NOTE(review): on mismatch the certificate and key stay installed in the
// context; nothing is rolled back.
290 void SSLContext::loadCertKeyPairFromBufferPEM(
291 folly::StringPiece cert,
292 folly::StringPiece pkey) {
293 loadCertificateFromBufferPEM(cert);
294 loadPrivateKeyFromBufferPEM(pkey);
295 if (!isCertKeyPairValid()) {
296 throw std::runtime_error("SSL certificate and private key do not match");
// Loads a certificate chain and a private key from files (in the given
// formats), then verifies that they belong together.
// Throws std::runtime_error if they don't match; exceptions from
// loadCertificate/loadPrivateKey propagate unchanged.
// NOTE(review): on mismatch the loaded certificate and key stay installed in
// the context; nothing is rolled back.
300 void SSLContext::loadCertKeyPairFromFiles(
301 const char* certPath,
303 const char* certFormat,
304 const char* keyFormat) {
305 loadCertificate(certPath, certFormat);
306 loadPrivateKey(keyPath, keyFormat);
307 if (!isCertKeyPairValid()) {
308 throw std::runtime_error("SSL certificate and private key do not match");
// Returns true when the context's private key is consistent with its
// certificate (SSL_CTX_check_private_key returns 1). OpenSSL also reports
// failure when either of them has not been loaded.
312 bool SSLContext::isCertKeyPairValid() const {
313 return SSL_CTX_check_private_key(ctx_) == 1;
// Adds the CA certificates in the PEM file at `path` to the trust store used
// to verify peers. The file is passed as CAfile; no CApath directory is used.
// Throws std::invalid_argument for a null path, and std::runtime_error when
// OpenSSL cannot load the file.
316 void SSLContext::loadTrustedCertificates(const char* path) {
317 if (path == nullptr) {
318 throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
320 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
321 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
// Replaces this context's certificate store with `store`.
// Per the OpenSSL documentation, the context takes over `store` without
// taking an extra reference, and frees any store previously set. The caller
// must therefore not free `store` afterwards.
326 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
327 SSL_CTX_set_cert_store(ctx_, store);
// Loads CA names from the PEM file at `path` and sets them as the list of
// acceptable CAs this context sends when requesting a client certificate.
// Ownership of the loaded list passes to the context.
// Failure is only logged (presumably followed by an early return — TODO
// confirm), whereas loadTrustedCertificates() throws.
// NOTE(review): unlike the sibling loaders, `path` is not null-checked.
// Streaming a null `const char*` into LOG(ERROR) is undefined behavior.
330 void SSLContext::loadClientCAList(const char* path) {
331 auto clientCAs = SSL_load_client_CA_file(path);
332 if (clientCAs == nullptr) {
333 LOG(ERROR) << "Unable to load ca file: " << path;
336 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
// Installs `collector` as the source of passwords for encrypted PEM material
// loaded through this context. passwordCallback is registered as the
// context's default password callback, with `this` as its userdata.
// A null collector is logged and ignored.
// NOTE(review): `collector` is taken by value and then copied;
// std::move(collector) would save a refcount round trip.
339 void SSLContext::passwordCollector(
340 std::shared_ptr<PasswordCollector> collector) {
341 if (collector == nullptr) {
342 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
345 collector_ = collector;
346 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
347 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
350 #if FOLLY_OPENSSL_HAS_SNI
// Sets the callback that baseServerNameOpenSSLCallback consults to resolve
// the SNI server name (presumably stored in serverNameCb_ — TODO confirm).
352 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
// Appends a callback that baseServerNameOpenSSLCallback runs on every
// ClientHello, in registration order, before SNI resolution.
356 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
357 clientHelloCbs_.push_back(cb);
// OpenSSL servername (SNI) callback, registered by the constructor with the
// owning SSLContext as `data`.
// 1. Runs every registered client-hello callback.
// 2. Maps serverNameCb_'s result onto OpenSSL's SSL_TLSEXT_ERR_* codes:
//      SERVER_NAME_FOUND                 -> SSL_TLSEXT_ERR_OK
//      SERVER_NAME_NOT_FOUND             -> SSL_TLSEXT_ERR_NOACK
//      SERVER_NAME_NOT_FOUND_ALERT_FATAL -> fatal unrecognized_name alert
// A null context, or an unset server-name callback, yields
// SSL_TLSEXT_ERR_NOACK: the handshake continues without acknowledging the
// name. The final fallback return is also NOACK (presumably reached for any
// other result value).
360 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
361 SSLContext* context = (SSLContext*)data;
363 if (context == nullptr) {
364 return SSL_TLSEXT_ERR_NOACK;
367 for (auto& cb : context->clientHelloCbs_) {
368 // Generic callbacks to happen after we receive the Client Hello.
369 // For example, we use one to switch which cipher we use depending
370 // on the user's TLS version. Because the primary purpose of
371 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
372 // are side-uses, we ignore any possible failures other than just logging
377 if (!context->serverNameCb_) {
378 return SSL_TLSEXT_ERR_NOACK;
381 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
383 case SERVER_NAME_FOUND:
384 return SSL_TLSEXT_ERR_OK;
385 case SERVER_NAME_NOT_FOUND:
386 return SSL_TLSEXT_ERR_NOACK;
387 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
388 *al = TLS1_AD_UNRECOGNIZED_NAME;
389 return SSL_TLSEXT_ERR_ALERT_FATAL;
394 return SSL_TLSEXT_ERR_NOACK;
396 #endif // FOLLY_OPENSSL_HAS_SNI
398 #if FOLLY_OPENSSL_HAS_ALPN
// Server-side ALPN selection callback. It is registered by
// setRandomizedAdvertisedNextProtocols with the owning SSLContext as `data`.
// It picks one weighted protocol list with pickNextProtocols(), then lets
// SSL_select_next_proto choose a protocol common to that list and the
// client's offer in `in`.
// Returns SSL_TLSEXT_ERR_NOACK when there is no overlap, and
// SSL_TLSEXT_ERR_OK otherwise. The no-protocols case is checked first
// (presumably also NOACK — TODO confirm).
// NOTE(review): `context` is dereferenced without the null check that
// advertisedNextProtocolCallback performs.
399 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
400 const unsigned char** out,
401 unsigned char* outlen,
402 const unsigned char* in,
405 SSLContext* context = (SSLContext*)data;
407 if (context->advertisedNextProtocols_.empty()) {
411 auto i = context->pickNextProtocols();
412 const auto& item = context->advertisedNextProtocols_[i];
// The cast only drops const to match SSL_select_next_proto's signature. On
// success, *out points into one of the two input lists.
413 if (SSL_select_next_proto((unsigned char**)out,
418 inlen) != OPENSSL_NPN_NEGOTIATED) {
419 return SSL_TLSEXT_ERR_NOACK;
422 return SSL_TLSEXT_ERR_OK;
424 #endif // FOLLY_OPENSSL_HAS_ALPN
426 #ifdef OPENSSL_NPN_NEGOTIATED
// Advertises a single, fixed protocol list (given weight 1) for the
// negotiation mechanisms selected by protocolType. This is a thin wrapper
// over setRandomizedAdvertisedNextProtocols and returns its result.
428 bool SSLContext::setAdvertisedNextProtocols(
429 const std::list<std::string>& protocols, NextProtocolType protocolType) {
430 return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
// Replaces the configured NPN/ALPN protocols with `items`. Each item is a
// protocol list plus a selection weight.
// - Each list is encoded in the TLS wire format: every protocol name is
//   prefixed by a one-byte length. The encoding goes into a new[]-allocated
//   buffer owned by advertisedNextProtocols_ and released by
//   deleteNextProtocolsStrings().
// - A discrete distribution over the weights lets pickNextProtocols() choose
//   a list per connection.
// - Depending on protocolType, the NPN advertise/select callbacks and/or the
//   ALPN select callback are installed. The client-side ALPN list is always
//   the first stored list.
// If a protocol name is 256 bytes or longer, or all weights are zero, every
// buffer built so far is freed (and the call presumably returns false — TODO
// confirm). `items` and each item's protocol list are also checked for
// emptiness (presumably returning false / skipping the item — TODO confirm).
433 bool SSLContext::setRandomizedAdvertisedNextProtocols(
434 const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
435 unsetNextProtocols();
436 if (items.size() == 0) {
439 int total_weight = 0;
440 for (const auto &item : items) {
441 if (item.protocols.size() == 0) {
444 AdvertisedNextProtocolsItem advertised_item;
445 advertised_item.length = 0;
// First pass: size the wire-format buffer (one length byte plus the name
// bytes for each protocol).
446 for (const auto& proto : item.protocols) {
447 ++advertised_item.length;
448 auto protoLength = proto.length();
// A name must fit in the one-byte length prefix.
// NOTE(review): zero-length names are accepted here, but TLS ALPN/NPN
// protocol names must be at least one byte long.
449 if (protoLength >= 256) {
450 deleteNextProtocolsStrings();
453 advertised_item.length += unsigned(protoLength);
// NOTE(review): array new throws std::bad_alloc rather than returning
// nullptr, so the null check below can never fire.
455 advertised_item.protocols = new unsigned char[advertised_item.length];
456 if (!advertised_item.protocols) {
457 throw std::runtime_error("alloc failure");
// Second pass: write <length><name> pairs back to back.
459 unsigned char* dst = advertised_item.protocols;
460 for (auto& proto : item.protocols) {
461 uint8_t protoLength = uint8_t(proto.length());
462 *dst++ = (unsigned char)protoLength;
463 memcpy(dst, proto.data(), protoLength);
466 total_weight += item.weight;
467 advertisedNextProtocols_.push_back(advertised_item);
468 advertisedNextProtocolWeights_.push_back(item.weight);
470 if (total_weight == 0) {
471 deleteNextProtocolsStrings();
474 nextProtocolDistribution_ =
475 std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
476 advertisedNextProtocolWeights_.end());
477 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
478 SSL_CTX_set_next_protos_advertised_cb(
479 ctx_, advertisedNextProtocolCallback, this);
480 SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
482 #if FOLLY_OPENSSL_HAS_ALPN
483 if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
484 SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
485 // Client cannot really use randomized alpn
// NOTE(review): SSL_CTX_set_alpn_protos returns 0 on success (the opposite of
// most OpenSSL calls), and its result is ignored here.
486 SSL_CTX_set_alpn_protos(ctx_,
487 advertisedNextProtocols_[0].protocols,
488 advertisedNextProtocols_[0].length);
// Frees every wire-format buffer in advertisedNextProtocols_ and clears both
// the protocol and weight vectors. OpenSSL callbacks are left installed
// (unsetNextProtocols() removes those).
// NOTE(review): nextProtocolDistribution_ is not reset here, so it can still
// describe the old weights. pickNextProtocols() CHECKs for an empty protocol
// list before using it.
494 void SSLContext::deleteNextProtocolsStrings() {
495 for (auto protocols : advertisedNextProtocols_) {
496 delete[] protocols.protocols;
498 advertisedNextProtocols_.clear();
499 advertisedNextProtocolWeights_.clear();
// Removes all advertised protocols. It frees the buffers and uninstalls the
// NPN advertise/select callbacks. When ALPN is compiled in, it also removes
// the ALPN select callback and clears the client-side ALPN list.
502 void SSLContext::unsetNextProtocols() {
503 deleteNextProtocolsStrings();
504 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
505 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
506 #if FOLLY_OPENSSL_HAS_ALPN
507 SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
508 SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
// Returns a random index into advertisedNextProtocols_, distributed according
// to the configured weights and drawn with folly's thread-local PRNG.
// CHECK-fails (aborts) when no protocols are configured.
512 size_t SSLContext::pickNextProtocols() {
513 CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
514 auto rng = ThreadLocalPRNG();
515 return size_t(nextProtocolDistribution_(rng));
// Server-side NPN callback that supplies the protocol list to advertise.
// - No context or no protocols: presumably advertises nothing (TODO confirm).
// - A single list: always advertise it.
// - Several weighted lists: the choice is made once per connection and cached
//   in the SSL's ex_data slot as index + 1. That way 0, which SSL_get_ex_data
//   returns for an unset slot, means "not chosen yet", and later invocations
//   on the same connection reuse the cached choice.
// The ex_data index is allocated once, through the function-local static
// below.
518 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
519 const unsigned char** out, unsigned int* outlen, void* data) {
520 static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
521 0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
523 SSLContext* context = (SSLContext*)data;
524 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
527 } else if (context->advertisedNextProtocols_.size() == 1) {
528 *out = context->advertisedNextProtocols_[0].protocols;
529 *outlen = context->advertisedNextProtocols_[0].length;
531 uintptr_t selected_index = reinterpret_cast<uintptr_t>(
532 SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
533 if (selected_index) {
// NOTE(review): the cached value is index + 1, so it must be decremented
// before indexing advertisedNextProtocols_ here. Verify that happens;
// otherwise this reads one past the chosen entry (possibly out of bounds).
535 *out = context->advertisedNextProtocols_[selected_index].protocols;
536 *outlen = context->advertisedNextProtocols_[selected_index].length;
538 auto i = context->pickNextProtocols();
539 uintptr_t selected = i + 1;
540 SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
541 *out = context->advertisedNextProtocols_[i].protocols;
542 *outlen = context->advertisedNextProtocols_[i].length;
545 return SSL_TLSEXT_ERR_OK;
// Client-side NPN callback that chooses one of the server's advertised
// protocols (`server`, in TLS wire format).
// The client's own list comes from the optional client protocol filter
// callback, if one is set. Otherwise (presumably when nothing was filtered)
// it is the first advertised list, or an empty list when none is configured.
// SSL_select_next_proto then picks the protocol. A failed match is only
// logged; the callback still returns SSL_TLSEXT_ERR_OK.
// NOTE(review): `ctx` is dereferenced without a null check.
// NOTE(review): both log messages misspell the function name
// ("selectNextProcolCallback").
548 int SSLContext::selectNextProtocolCallback(SSL* ssl,
550 unsigned char* outlen,
551 const unsigned char* server,
552 unsigned int server_len,
554 (void)ssl; // Make -Wunused-parameter happy
555 SSLContext* ctx = (SSLContext*)data;
556 if (ctx->advertisedNextProtocols_.size() > 1) {
557 VLOG(3) << "SSLContext::selectNextProcolCallback() "
558 << "client should be deterministic in selecting protocols.";
561 unsigned char* client = nullptr;
562 unsigned int client_len = 0;
563 bool filtered = false;
564 auto cpf = ctx->getClientProtocolFilterCallback();
566 filtered = (*cpf)(&client, &client_len, server, server_len);
570 if (ctx->advertisedNextProtocols_.empty()) {
// NOTE(review): if this empty list (client_len presumably 0) reaches
// SSL_select_next_proto as the client list, it triggers CVE-2024-5535
// (out-of-bounds read / memory disclosure) on unpatched OpenSSL releases.
// Consider failing the negotiation instead.
571 client = (unsigned char *) "";
574 client = ctx->advertisedNextProtocols_[0].protocols;
575 client_len = ctx->advertisedNextProtocols_[0].length;
579 int retval = SSL_select_next_proto(out, outlen, server, server_len,
581 if (retval != OPENSSL_NPN_NEGOTIATED) {
582 VLOG(3) << "SSLContext::selectNextProcolCallback() "
583 << "unable to pick a next protocol.";
585 return SSL_TLSEXT_ERR_OK;
587 #endif // OPENSSL_NPN_NEGOTIATED
// Creates a new SSL connection object configured from this context.
// Throws std::runtime_error if SSL_new fails. The returned SSL* is presumably
// owned by the caller, who must SSL_free it (TODO confirm against the
// header).
589 SSL* SSLContext::createSSL() const {
590 SSL* ssl = SSL_new(ctx_);
591 if (ssl == nullptr) {
592 throw std::runtime_error("SSL_new: " + getErrors());
// Sets the session-id context, which scopes the cached sessions that may be
// resumed with this context.
// Input longer than SSL_MAX_SID_CTX_LENGTH is silently truncated to that
// length, which also keeps the call from failing. The OpenSSL return value is
// not checked.
597 void SSLContext::setSessionCacheContext(const std::string& context) {
598 SSL_CTX_set_session_id_context(
600 reinterpret_cast<const unsigned char*>(context.data()),
601 std::min<unsigned int>(
602 static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH));
606 * Match a name with a pattern. The pattern may include wildcard. A single
607 * wildcard "*" can match up to one component in the domain name.
609 * @param host Host name, typically the name of the remote host
610 * @param pattern Name retrieved from certificate
611 * @param size Size of "pattern"
612 * @return True, if "host" matches "pattern". False otherwise.
614 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
// Implementation notes:
// - Characters are compared case-insensitively with toupper().
// - A '*' in the pattern skips host characters until the next '.' or the end
//   of the host, so it never spans labels.
// - The final check requires both the whole pattern (`size` chars) and the
//   whole host to have been consumed.
// NOTE(review): toupper() is passed plain `char`. Bytes >= 0x80 are negative
// on signed-char platforms, which is undefined behavior; cast to
// unsigned char first.
617 while (i < size && host[j] != '\0') {
618 if (toupper(pattern[i]) == toupper(host[j])) {
623 if (pattern[i] == '*') {
624 while (host[j] != '.' && host[j] != '\0') {
632 if (i == size && host[j] == '\0') {
// OpenSSL password callback (pem_password_cb) registered by
// passwordCollector(); `data` is the owning SSLContext.
// It asks the collector for the password and copies at most `size` bytes of
// it into `password`. No NUL terminator is written; the callback's return
// value (presumably `length`) tells OpenSSL how many bytes are valid.
// A missing context or collector is rejected first (presumably returning 0,
// which OpenSSL treats as failure — TODO confirm).
// NOTE(review): the plaintext password in `userPassword` is not scrubbed
// (e.g. with OPENSSL_cleanse) before the string is destroyed.
638 int SSLContext::passwordCallback(char* password,
642 SSLContext* context = (SSLContext*)data;
643 if (context == nullptr || context->passwordCollector() == nullptr) {
646 std::string userPassword;
647 // call user defined password collector to get password
648 context->passwordCollector()->getPassword(userPassword, size);
649 auto const length = std::min(userPassword.size(), size_t(size));
650 std::memcpy(password, userPassword.data(), length);
654 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
// Enables TLS False Start (handshake cut-through) on this context. Only
// compiled when the OpenSSL headers define SSL_MODE_HANDSHAKE_CUTTHROUGH.
655 void SSLContext::enableFalseStart() {
656 SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
// Initializes the OpenSSL library for use by SSLContext.
// NOTE(review): presumably delegates to folly::ssl::init() (declared in
// <folly/ssl/Init.h>) — confirm.
660 void SSLContext::initializeOpenSSL() {
// Adds `options` (SSL_OP_* bits) to this context's option mask.
// SSL_CTX_set_options only sets bits; it never clears existing ones.
// Throws std::runtime_error if any requested bit is not set afterwards. The
// constructor performs the equivalent check with DCHECK instead.
664 void SSLContext::setOptions(long options) {
665 long newOpt = SSL_CTX_set_options(ctx_, options);
666 if ((newOpt & options) != options) {
667 throw std::runtime_error("SSL_CTX_set_options failed");
// Builds a human-readable description of the calling thread's pending OpenSSL
// errors.
// Each ERR_get_error() call removes an entry, so the queue is drained and
// stale errors can't leak into a later report. An entry with no reason
// string is rendered as "SSL error # <code>".
// If the queue was empty, the result falls back to "error code: <errnoCopy>".
// Callers pass an errno value saved right after the failing call (see
// loadCertificate).
// NOTE(review): snprintf already reserves room for the terminating NUL, so
// `sizeof(message) - 1` just wastes one byte.
671 std::string SSLContext::getErrors(int errnoCopy) {
673 unsigned long errorCode;
677 while ((errorCode = ERR_get_error()) != 0) {
678 if (!errors.empty()) {
681 const char* reason = ERR_reason_error_string(errorCode);
682 if (reason == nullptr) {
683 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
// Nothing was queued by OpenSSL: report the errno the caller saved instead.
688 if (errors.empty()) {
689 errors = "error code: " + folly::to<std::string>(errnoCopy);
695 operator<<(std::ostream& os, const PasswordCollector& collector) {
// Streams a PasswordCollector as its describe() text (the stream is
// presumably returned for chaining — TODO confirm).
696 os << collector.describe();