2 * Copyright 2014 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "SSLContext.h"
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
24 #include <folly/SmallLocks.h>
25 #include <folly/Format.h>
26 #include <folly/io/PortableSpinLock.h>
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
// Payload type for OpenSSL "dynamic locks": dyn_create() allocates one with
// plain `new`, and dyn_lock() operates on its `mutex` member.
// NOTE(review): the member declarations are not visible in this view.
32 struct CRYPTO_dynlock_value {
// Class-wide state shared by all SSLContext instances.
// count_: starts at 0. NOTE(review): its update sites are not visible here;
//   presumably it counts live contexts under mutex_ — confirm.
// mutex_: taken by both the constructor and the destructor.
38 uint64_t SSLContext::count_ = 0;
39 std::mutex SSLContext::mutex_;
// SSL ex_data slot (-1 until the constructor allocates it). The NPN advertise
// callback stores the chosen protocol-list index here for each connection.
40 #ifdef OPENSSL_NPN_NEGOTIATED
41 int SSLContext::sNextProtocolsExDataIndex_ = -1;
44 // SSLContext implementation
// Constructs a context backed by a fresh SSL_CTX created from SSLv23_method().
// It restricts the enabled protocol versions according to `version`.
//
// The class-wide mutex_ is taken first. Under it, when NPN is compiled in, the
// constructor allocates the SSL ex_data index used by
// advertisedNextProtocolCallback().
// NOTE(review): any guard limiting that setup to the first instance is not
// visible in this view — confirm.
//
// @throws std::runtime_error if SSL_CTX_new() fails. The message includes the
//         OpenSSL error queue as drained by getErrors().
45 SSLContext::SSLContext(SSLVersion version) {
47 std::lock_guard<std::mutex> g(mutex_);
51 #ifdef OPENSSL_NPN_NEGOTIATED
52 sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
53 (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
58 ctx_ = SSL_CTX_new(SSLv23_method());
59 if (ctx_ == nullptr) {
60 throw std::runtime_error("SSL_CTX_new: " + getErrors());
// Option sets chosen per `version` (the selecting switch is not visible in
// this view). Both disable SSLv2; the first also disables SSLv3.
66 opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
69 opt = SSL_OP_NO_SSLv2;
// NOTE(review): SSL_CTX_set_options() returns long, and setOptions() keeps the
// result in a long. Storing it in an int here may drop high option bits.
// Unlike setOptions(), this is only a DCHECK, so release builds never verify
// that the options took effect.
75 int newOpt = SSL_CTX_set_options(ctx_, opt);
76 DCHECK((newOpt & opt) == opt);
78 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
// Peer-name checking stays off until authenticate() turns it on.
80 checkPeerName_ = false;
// With SNI support compiled in, the servername extension is handled by
// baseServerNameOpenSSLCallback(), which receives this context as its argument.
82 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
83 SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
84 SSL_CTX_set_tlsext_servername_arg(ctx_, this);
// Tears down the context in three steps:
// 1. Checks ctx_ for null. NOTE(review): the release call inside that branch
//    is not visible in this view.
// 2. When NPN is compiled in, frees the advertised-protocol buffers via
//    deleteNextProtocolsStrings().
// 3. Takes the class-wide mutex_. The guarded code is not visible here;
//    presumably it undoes the constructor's shared bookkeeping — confirm.
88 SSLContext::~SSLContext() {
89 if (ctx_ != nullptr) {
94 #ifdef OPENSSL_NPN_NEGOTIATED
95 deleteNextProtocolsStrings();
98 std::lock_guard<std::mutex> g(mutex_);
// Sets the context-wide cipher list.
// The string is remembered in providedCiphersString_; switchCiphersIfTLS11()
// falls back to it.
// @throws std::runtime_error (from setCiphersOrThrow) if OpenSSL rejects it.
// NOTE(review): providedCiphersString_ is assigned before validation, so after
// a throw it holds the rejected string.
104 void SSLContext::ciphers(const std::string& ciphers) {
105 providedCiphersString_ = ciphers;
106 setCiphersOrThrow(ciphers);
// Applies `ciphers` to the SSL_CTX with SSL_CTX_set_cipher_list().
// Throws std::runtime_error with the drained error queue if OpenSSL queued an
// error. Otherwise throws "None of specified ciphers are supported".
// NOTE(review): the condition guarding that second throw is not visible here;
// presumably it is `rc == 0`.
// NOTE(review): ERR_peek_error() reports any queued error, including one left
// by an earlier, unrelated OpenSSL call on this thread.
109 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
110 int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
111 if (ERR_peek_error() != 0) {
112 throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
115 throw std::runtime_error("None of specified ciphers are supported");
// Stores this context's default peer-verification policy in verifyPeer_.
// USE_CTX is rejected with CHECK. It presumably means "use the context's
// setting", so storing it here would make that lookup recurse.
119 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
121 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
122 verifyPeer_ = verifyPeer;
// Maps an SSLVerifyPeerEnum value to OpenSSL SSL_VERIFY_* flags:
//   VERIFY                 -> SSL_VERIFY_PEER
//   VERIFY_REQ_CLIENT_CERT -> SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
//   NO_VERIFY              -> SSL_VERIFY_NONE (also the initial value of `mode`)
// The caller must resolve USE_CTX first; this is enforced by CHECK.
125 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
127 CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
128 int mode = SSL_VERIFY_NONE;
130 // case SSLVerifyPeerEnum::USE_CTX: // can't happen
133 case SSLVerifyPeerEnum::VERIFY:
134 mode = SSL_VERIFY_PEER;
137 case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
138 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
141 case SSLVerifyPeerEnum::NO_VERIFY:
142 mode = SSL_VERIFY_NONE;
// Returns the SSL_VERIFY_* flags for this context's stored policy (verifyPeer_).
151 int SSLContext::getVerificationMode() {
152 return getVerificationMode(verifyPeer_);
// Configures certificate verification directly on the SSL_CTX, with no custom
// verify callback.
// NOTE(review): the branch condition is not visible; presumably it tests
// `checkPeerCert`.
// - Verifying branch: requires a peer certificate and uses
//   SSL_VERIFY_CLIENT_ONCE. It also records whether the peer name should be
//   checked and against which fixed name.
// - Other branch: disables verification and clears name checking, since a
//   name cannot be checked without a certificate.
155 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
156 const std::string& peerName) {
159 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
160 checkPeerName_ = checkPeerName;
161 peerFixedName_ = peerName;
163 mode = SSL_VERIFY_NONE;
164 checkPeerName_ = false; // can't check name without cert!
165 peerFixedName_.clear();
167 SSL_CTX_set_verify(ctx_, mode, nullptr);
// Loads this context's certificate chain from `path`.
// Only "PEM" is supported, via SSL_CTX_use_certificate_chain_file().
// @throws std::invalid_argument if `path` or `format` is nullptr.
// @throws std::runtime_error if loading fails or `format` is not "PEM".
// NOTE(review): the invalid_argument message says "loadCertificateChain" although
// this function is loadCertificate.
170 void SSLContext::loadCertificate(const char* path, const char* format) {
171 if (path == nullptr || format == nullptr) {
172 throw std::invalid_argument(
173 "loadCertificateChain: either <path> or <format> is nullptr");
175 if (strcmp(format, "PEM") == 0) {
176 if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
// errno is copied right away so that later calls cannot overwrite it before it
// is passed to getErrors().
177 int errnoCopy = errno;
178 std::string reason("SSL_CTX_use_certificate_chain_file: ");
181 reason.append(getErrors(errnoCopy));
182 throw std::runtime_error(reason);
185 throw std::runtime_error("Unsupported certificate format: " + std::string(format));
// Loads this context's private key from `path`.
// Only "PEM" is supported, via SSL_CTX_use_PrivateKey_file().
// If passwordCollector() installed a callback on the context, OpenSSL may
// invoke it while decrypting the key.
// @throws std::invalid_argument if `path` or `format` is nullptr.
// @throws std::runtime_error if loading fails or `format` is not "PEM".
189 void SSLContext::loadPrivateKey(const char* path, const char* format) {
190 if (path == nullptr || format == nullptr) {
191 throw std::invalid_argument(
192 "loadPrivateKey: either <path> or <format> is nullptr");
194 if (strcmp(format, "PEM") == 0) {
195 if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
196 throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
199 throw std::runtime_error("Unsupported private key format: " + std::string(format));
// Adds the CA certificates in the file at `path` to the context's verify
// locations. No CA directory is passed (second location is nullptr).
// @throws std::invalid_argument if `path` is nullptr.
// @throws std::runtime_error if SSL_CTX_load_verify_locations() fails.
203 void SSLContext::loadTrustedCertificates(const char* path) {
204 if (path == nullptr) {
205 throw std::invalid_argument(
206 "loadTrustedCertificates: <path> is nullptr");
208 if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
209 throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
// Replaces the context's certificate store with `store`.
// NOTE(review): per the OpenSSL docs, SSL_CTX_set_cert_store() takes ownership
// of `store` and frees the previous store, so callers must not free it
// themselves — confirm callers respect this.
213 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
214 SSL_CTX_set_cert_store(ctx_, store);
// Loads the CA names from the file at `path` and installs them as the list of
// acceptable CAs sent in client-certificate requests.
// On failure it logs an error rather than throwing.
// NOTE(review): the lines after the log are not visible; presumably it returns.
// NOTE(review): unlike the sibling loaders, `path` is not checked for nullptr,
// and streaming a null char* into LOG is undefined behavior.
217 void SSLContext::loadClientCAList(const char* path) {
218 auto clientCAs = SSL_load_client_CA_file(path);
219 if (clientCAs == nullptr) {
220 LOG(ERROR) << "Unable to load ca file: " << path;
223 SSL_CTX_set_client_CA_list(ctx_, clientCAs);
// NOTE(review): the body is not visible in this view. Presumably it (re)seeds
// OpenSSL's random number generator — confirm.
226 void SSLContext::randomize() {
// Installs `collector` as the source of passwords for encrypted keys.
// - A null collector is logged and ignored. NOTE(review): the early return is
//   not visible here.
// - Otherwise the shared_ptr is kept in collector_, and passwordCallback() is
//   registered as the default password callback with this context as userdata.
230 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
231 if (collector == nullptr) {
232 LOG(ERROR) << "passwordCollector: ignore invalid password collector";
235 collector_ = collector;
236 SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
237 SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
240 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
// Sets the SNI handler for this context.
// NOTE(review): the body is not visible. Presumably it stores `cb` in
// serverNameCb_, which baseServerNameOpenSSLCallback() invokes — confirm.
242 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
// Appends `cb` to clientHelloCbs_. baseServerNameOpenSSLCallback() iterates
// these in insertion order before doing SNI handling.
246 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
247 clientHelloCbs_.push_back(cb);
// OpenSSL servername (SNI) callback, registered by the constructor with the
// owning SSLContext as `data`.
// 1. Runs every client-hello callback.
// 2. Delegates to serverNameCb_ and translates its result:
//    SERVER_NAME_FOUND                 -> SSL_TLSEXT_ERR_OK
//    SERVER_NAME_NOT_FOUND             -> SSL_TLSEXT_ERR_NOACK
//    SERVER_NAME_NOT_FOUND_ALERT_FATAL -> SSL_TLSEXT_ERR_ALERT_FATAL,
//      with *al set to TLS1_AD_UNRECOGNIZED_NAME.
// Returns NOACK when `data` is null, when no serverNameCb_ is set, and on the
// final fallthrough path.
250 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
251 SSLContext* context = (SSLContext*)data;
253 if (context == nullptr) {
254 return SSL_TLSEXT_ERR_NOACK;
257 for (auto& cb : context->clientHelloCbs_) {
258 // Generic callbacks to happen after we receive the Client Hello.
259 // For example, we use one to switch which cipher we use depending
260 // on the user's TLS version. Because the primary purpose of
261 // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
262 // are side-uses, we ignore any possible failures other than just logging
267 if (!context->serverNameCb_) {
268 return SSL_TLSEXT_ERR_NOACK;
271 ServerNameCallbackResult ret = context->serverNameCb_(ssl);
273 case SERVER_NAME_FOUND:
274 return SSL_TLSEXT_ERR_OK;
275 case SERVER_NAME_NOT_FOUND:
276 return SSL_TLSEXT_ERR_NOACK;
277 case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
278 *al = TLS1_AD_UNRECOGNIZED_NAME;
279 return SSL_TLSEXT_ERR_ALERT_FATAL;
284 return SSL_TLSEXT_ERR_NOACK;
// Per-connection cipher override, meant to run as a client-hello callback.
// - TLS 1.0 or older (client version <= TLS1_VERSION): the connection is left
//   alone. NOTE(review): the early return is not visible here.
// - Newer clients: the SSL object's cipher list is set to `tls11CipherString`.
// - If that is rejected, a warning is logged and the list falls back to
//   providedCiphersString_.
// NOTE(review): the fallback SSL_set_cipher_list() result is ignored. No
// ERR_clear_error() is visible either, so the queued error may be picked up by
// later ERR_peek_error() checks on this thread.
287 void SSLContext::switchCiphersIfTLS11(
289 const std::string& tls11CipherString) {
291 CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
293 if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
294 // We only do this for TLS v 1.1 and later
298 // Prefer AES for TLS versions 1.1 and later since these are not
299 // vulnerable to BEAST attacks on AES. Note that we're setting the
300 // cipher list on the SSL object, not the SSL_CTX object, so it will
301 // only last for this request.
302 int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
303 if ((rc == 0) || ERR_peek_error() != 0) {
304 // This shouldn't happen since we checked for this when proxygen
306 LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
307 SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
312 #ifdef OPENSSL_NPN_NEGOTIATED
// Advertises a single NPN protocol list. This is the one-item case of
// setRandomizedAdvertisedNextProtocols(), using weight 1.
// Returns whatever setRandomizedAdvertisedNextProtocols() returns.
313 bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
314 return setRandomizedAdvertisedNextProtocols({{1, protocols}});
// Replaces the advertised NPN configuration with weighted protocol lists.
// advertisedNextProtocolCallback() later picks one list per connection with
// probability weight / total_weight.
//
// Each list is encoded in NPN wire format: for every protocol, a length byte
// followed by the protocol bytes. The buffer is allocated with new[] and owned
// by advertisedNextProtocols_ until deleteNextProtocolsStrings() frees it.
//
// Any previous configuration and callbacks are cleared first (unsetNextProtocols).
// Early exits (the return statements are not visible here; presumably they
// return false):
// - `items` is empty;
// - an item has an empty protocol list;
// - a protocol is 256 bytes or longer (it cannot fit the length byte);
// - all weights sum to 0.
// On success the advertise and select callbacks are installed with this
// context as their argument.
317 bool SSLContext::setRandomizedAdvertisedNextProtocols(
318 const std::list<NextProtocolsItem>& items) {
319 unsetNextProtocols();
320 if (items.size() == 0) {
323 int total_weight = 0;
324 for (const auto &item : items) {
325 if (item.protocols.size() == 0) {
328 AdvertisedNextProtocolsItem advertised_item;
329 advertised_item.length = 0;
// First pass: compute the encoded size (one length byte per protocol plus its bytes).
330 for (const auto& proto : item.protocols) {
331 ++advertised_item.length;
332 unsigned protoLength = proto.length();
333 if (protoLength >= 256) {
334 deleteNextProtocolsStrings();
337 advertised_item.length += protoLength;
339 advertised_item.protocols = new unsigned char[advertised_item.length];
// NOTE(review): plain new[] throws std::bad_alloc instead of returning nullptr,
// so this check never fires.
340 if (!advertised_item.protocols) {
341 throw std::runtime_error("alloc failure");
343 unsigned char* dst = advertised_item.protocols;
// Second pass: write each length byte followed by the protocol bytes.
344 for (auto& proto : item.protocols) {
345 unsigned protoLength = proto.length();
346 *dst++ = (unsigned char)protoLength;
347 memcpy(dst, proto.data(), protoLength);
350 total_weight += item.weight;
// Holds the raw weight for now; normalized by total_weight after the loop.
351 advertised_item.probability = item.weight;
352 advertisedNextProtocols_.push_back(advertised_item);
354 if (total_weight == 0) {
355 deleteNextProtocolsStrings();
358 for (auto &advertised_item : advertisedNextProtocols_) {
359 advertised_item.probability /= total_weight;
361 SSL_CTX_set_next_protos_advertised_cb(
362 ctx_, advertisedNextProtocolCallback, this);
363 SSL_CTX_set_next_proto_select_cb(
364 ctx_, selectNextProtocolCallback, this);
// Frees every wire-format protocol buffer (allocated with new[]) and empties
// advertisedNextProtocols_.
// The loop copies each item by value. That is harmless here because only the
// pointer is used.
368 void SSLContext::deleteNextProtocolsStrings() {
369 for (auto protocols : advertisedNextProtocols_) {
370 delete[] protocols.protocols;
372 advertisedNextProtocols_.clear();
// Removes all NPN configuration: frees the advertised lists and unregisters
// both the server-side advertise callback and the client-side select callback.
375 void SSLContext::unsetNextProtocols() {
376 deleteNextProtocolsStrings();
377 SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
378 SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
// Server-side NPN callback: returns the wire-format protocol list to advertise.
// - No context or no lists: the handling is not visible here.
// - Exactly one list: that list is returned.
// - Otherwise: a list is chosen at random by cumulative probability and cached
//   in the connection's ex_data slot, so later calls for the same connection
//   return the same list.
// Always reports SSL_TLSEXT_ERR_OK on the visible path.
381 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
382 const unsigned char** out, unsigned int* outlen, void* data) {
383 SSLContext* context = (SSLContext*)data;
384 if (context == nullptr || context->advertisedNextProtocols_.empty()) {
387 } else if (context->advertisedNextProtocols_.size() == 1) {
388 *out = context->advertisedNextProtocols_[0].protocols;
389 *outlen = context->advertisedNextProtocols_[0].length;
// The index is cached as i + 1, so 0 (unset ex_data) means "not chosen yet".
// NOTE(review): the decrement back to a 0-based index is not visible in this
// view — confirm it happens before the indexing below.
391 uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
392 sNextProtocolsExDataIndex_));
393 if (selected_index) {
395 *out = context->advertisedNextProtocols_[selected_index].protocols;
396 *outlen = context->advertisedNextProtocols_[selected_index].length;
// A single random byte scaled to [0, 1] gives 256 distinct values. The last
// list is chosen whenever the cumulative sum never reaches random_value.
// NOTE(review): RAND_bytes() failure is not checked, in which case random_byte
// is read uninitialized.
398 unsigned char random_byte;
399 RAND_bytes(&random_byte, 1);
400 double random_value = random_byte / 255.0;
402 for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
403 sum += context->advertisedNextProtocols_[i].probability;
404 if (sum < random_value &&
405 i + 1 < context->advertisedNextProtocols_.size()) {
408 uintptr_t selected = i + 1;
409 SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
410 *out = context->advertisedNextProtocols_[i].protocols;
411 *outlen = context->advertisedNextProtocols_[i].length;
416 return SSL_TLSEXT_ERR_OK;
// Client-side NPN callback: chooses a protocol from the server's list `server`
// with SSL_select_next_proto().
// - The client's preference list is the first configured list, or an empty
//   list if none is configured.
// - With more than one configured list, a VLOG note says the client should be
//   deterministic.
// - A failed negotiation is only VLOGged; SSL_TLSEXT_ERR_OK is returned either way.
// NOTE(review): unlike advertisedNextProtocolCallback(), `data` is not checked
// for null before use.
// NOTE(review): the log strings misspell the name as "selectNextProcolCallback".
419 int SSLContext::selectNextProtocolCallback(
420 SSL* ssl, unsigned char **out, unsigned char *outlen,
421 const unsigned char *server, unsigned int server_len, void *data) {
423 SSLContext* ctx = (SSLContext*)data;
424 if (ctx->advertisedNextProtocols_.size() > 1) {
425 VLOG(3) << "SSLContext::selectNextProcolCallback() "
426 << "client should be deterministic in selecting protocols.";
429 unsigned char *client;
// NOTE(review): this casts away the constness of a string literal; a
// `const unsigned char*` would avoid that.
431 if (ctx->advertisedNextProtocols_.empty()) {
432 client = (unsigned char *) "";
435 client = ctx->advertisedNextProtocols_[0].protocols;
436 client_len = ctx->advertisedNextProtocols_[0].length;
439 int retval = SSL_select_next_proto(out, outlen, server, server_len,
441 if (retval != OPENSSL_NPN_NEGOTIATED) {
442 VLOG(3) << "SSLContext::selectNextProcolCallback() "
443 << "unable to pick a next protocol.";
445 return SSL_TLSEXT_ERR_OK;
447 #endif // OPENSSL_NPN_NEGOTIATED
// Creates a new SSL connection object from this context's SSL_CTX.
// @throws std::runtime_error (with the drained OpenSSL error queue) if
//         SSL_new() fails.
449 SSL* SSLContext::createSSL() const {
450 SSL* ssl = SSL_new(ctx_);
451 if (ssl == nullptr) {
452 throw std::runtime_error("SSL_new: " + getErrors());
458 * Match a name with a pattern. The pattern may include wildcard. A single
459 * wildcard "*" can match up to one component in the domain name.
461 * @param host Host name, typically the name of the remote host
462 * @param pattern Name retrieved from certificate
463 * @param size Size of "pattern"
464 * @return True, if "host" matches "pattern". False otherwise.
466 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
// Walk both strings together. `pattern` is bounded by `size`; `host` is
// NUL-terminated.
// NOTE(review): toupper() on a plain char is undefined for negative values; a
// cast to unsigned char would make it safe for non-ASCII input.
469 while (i < size && host[j] != '\0') {
470 if (toupper(pattern[i]) == toupper(host[j])) {
// A '*' in the pattern consumes host characters up to the next '.' or the end,
// so it never spans more than one label.
475 if (pattern[i] == '*') {
476 while (host[j] != '.' && host[j] != '\0') {
// A match requires both strings to be consumed completely.
484 if (i == size && host[j] == '\0') {
// OpenSSL password callback, installed by passwordCollector() with the owning
// SSLContext as `data`.
// It asks the collector for a password and copies it into OpenSSL's `password`
// buffer of `size` bytes.
// NOTE(review): the early return for a missing context/collector, the clamping
// of `length` to `size`, and the return value are not visible here.
// strncpy() does not NUL-terminate when `length` equals the string length, so
// callers presumably rely on the returned length — confirm.
490 int SSLContext::passwordCallback(char* password,
494 SSLContext* context = (SSLContext*)data;
495 if (context == nullptr || context->passwordCollector() == nullptr) {
498 std::string userPassword;
499 // call user defined password collector to get password
500 context->passwordCollector()->getPassword(userPassword, size);
// NOTE(review): size_t -> int narrowing.
501 int length = userPassword.size();
505 strncpy(password, userPassword.c_str(), length);
// Fragment of a lock wrapper whose behavior is chosen by `lockType`
// (default LOCK_MUTEX):
// - LOCK_MUTEX: uses a mutex;
// - LOCK_SPINLOCK: uses `spinLock`;
// - LOCK_NONE: lock and unlock are no-ops.
// NOTE(review): the struct header, the method names, and the mutex member are
// not visible in this view.
511 SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
512 lockType(inLockType) {
516 if (lockType == SSLContext::LOCK_MUTEX) {
518 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
521 // lockType == LOCK_NONE, no-op
525 if (lockType == SSLContext::LOCK_MUTEX) {
527 } else if (lockType == SSLContext::LOCK_SPINLOCK) {
530 // lockType == LOCK_NONE, no-op
533 SSLContext::SSLLockType lockType;
534 folly::io::PortableSpinLock spinLock{};
// lockTypes: per-index lock type overrides. setSSLLockTypes() fills it, and
//   initializeOpenSSL() applies it.
// locks: the lock array, one SSLLock per OpenSSL static lock index, allocated
//   by initializeOpenSSL() with CRYPTO_num_locks() entries.
538 static std::map<int, SSLContext::SSLLockType> lockTypes;
539 static std::unique_ptr<SSLLock[]> locks;
// OpenSSL static locking callback, installed by initializeOpenSSL(). It
// branches on whether `mode` requests CRYPTO_LOCK.
// NOTE(review): the branch bodies are not visible; presumably they lock or
// unlock locks[n].
541 static void callbackLocking(int mode, int n, const char*, int) {
542 if (mode & CRYPTO_LOCK) {
// OpenSSL thread-id callback, installed by initializeOpenSSL().
// NOTE(review): static_cast<unsigned long>(pthread_self()) only compiles where
// pthread_t is an integer type (e.g. glibc). It is not portable to platforms
// where pthread_t is a pointer or a struct.
549 static unsigned long callbackThreadID() {
550 return static_cast<unsigned long>(pthread_self());
// OpenSSL dynamic-lock create callback. Returns a new heap-allocated lock that
// dyn_destroy() is expected to free.
553 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
554 return new CRYPTO_dynlock_value;
// OpenSSL dynamic-lock lock/unlock callback. A null `lock` is ignored.
// When `mode` lacks CRYPTO_LOCK, the lock's mutex is unlocked.
// NOTE(review): the locking branch body is not visible here.
557 static void dyn_lock(int mode,
558 struct CRYPTO_dynlock_value* lock,
560 if (lock != nullptr) {
561 if (mode & CRYPTO_LOCK) {
564 lock->mutex.unlock();
// OpenSSL dynamic-lock destroy callback.
// NOTE(review): the body is not visible; presumably it deletes the lock
// allocated by dyn_create().
569 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
// Replaces the per-index lock type overrides. They take effect only when
// initializeOpenSSL() next runs, because that is where lockTypes is read.
// NOTE(review): no synchronization is visible around the write.
573 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
574 lockTypes = inLockTypes;
// Global OpenSSL setup. In order it:
// 1. loads error strings;
// 2. allocates one SSLLock per OpenSSL static lock and applies the lockTypes
//    overrides;
// 3. installs the thread-id, static-locking and dynamic-lock callbacks.
// NOTE(review): `locks[it.first]` is not bounds-checked. A lockTypes key
// outside [0, CRYPTO_num_locks()) writes out of range.
577 void SSLContext::initializeOpenSSL() {
579 SSL_load_error_strings();
580 ERR_load_crypto_strings();
582 locks.reset(new SSLLock[::CRYPTO_num_locks()]);
583 for (auto it: lockTypes) {
584 locks[it.first].lockType = it.second;
586 CRYPTO_set_id_callback(callbackThreadID);
587 CRYPTO_set_locking_callback(callbackLocking);
589 CRYPTO_set_dynlock_create_callback(dyn_create);
590 CRYPTO_set_dynlock_lock_callback(dyn_lock);
591 CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
// Global OpenSSL teardown: unregisters every callback installed by
// initializeOpenSSL() and releases ex_data state.
// NOTE(review): the remaining cleanup (and whether `locks` is released) is not
// visible in this view.
594 void SSLContext::cleanupOpenSSL() {
595 CRYPTO_set_id_callback(nullptr);
596 CRYPTO_set_locking_callback(nullptr);
597 CRYPTO_set_dynlock_create_callback(nullptr);
598 CRYPTO_set_dynlock_lock_callback(nullptr);
599 CRYPTO_set_dynlock_destroy_callback(nullptr);
600 CRYPTO_cleanup_all_ex_data();
// ORs `options` into the context's SSL_OP_* options.
// @throws std::runtime_error if any requested bit is missing from the
//         resulting option mask.
607 void SSLContext::setOptions(long options) {
608 long newOpt = SSL_CTX_set_options(ctx_, options);
609 if ((newOpt & options) != options) {
610 throw std::runtime_error("SSL_CTX_set_options failed");
// Drains the calling thread's OpenSSL error queue (ERR_get_error) into one
// string, one entry per queued error.
// - An unknown code is rendered as "SSL error # <code>".
// - If the queue was empty, the result is "error code: <errnoCopy>".
// NOTE(review): the separator appended between entries and the declarations
// of `errors`/`message` are not visible here.
614 std::string SSLContext::getErrors(int errnoCopy) {
616 unsigned long errorCode;
620 while ((errorCode = ERR_get_error()) != 0) {
621 if (!errors.empty()) {
624 const char* reason = ERR_reason_error_string(errorCode);
625 if (reason == nullptr) {
626 snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
631 if (errors.empty()) {
632 errors = "error code: " + folly::to<std::string>(errnoCopy);
// Streams a PasswordCollector by writing its describe() text to `os`.
638 operator<<(std::ostream& os, const PasswordCollector& collector) {
639 os << collector.describe();