X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FSSLContext.cpp;h=a8cf72de2895e514c6a14888c494e2c15c9b8ff9;hb=75e5507cbfe7ef5a448375e9908e10864b506b05;hp=cfd25c7f7c55341b944ec7344afdcf4b7f2a8801;hpb=60331af0cfde1c9642367d4e40dc8cec17159ab4;p=folly.git diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index cfd25c7f..a8cf72de 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2015 Facebook, Inc. + * Copyright 2016 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,8 @@ #include #include +#include +#include #include // --------------------------------------------------------------------- @@ -43,7 +45,7 @@ std::mutex& initMutex() { return m; } -} // anonymous namespace +} // anonymous namespace #ifdef OPENSSL_NPN_NEGOTIATED int SSLContext::sNextProtocolsExDataIndex_ = -1; @@ -80,6 +82,8 @@ SSLContext::SSLContext(SSLVersion version) { checkPeerName_ = false; + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION); + #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_arg(ctx_, this); @@ -102,13 +106,95 @@ void SSLContext::ciphers(const std::string& ciphers) { setCiphersOrThrow(ciphers); } +void SSLContext::setCipherList(const std::vector& ciphers) { + if (ciphers.size() == 0) { + return; + } + std::string opensslCipherList; + join(":", ciphers, opensslCipherList); + setCiphersOrThrow(opensslCipherList); +} + +void SSLContext::setSignatureAlgorithms( + const std::vector& sigalgs) { + if (sigalgs.size() == 0) { + return; + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL + std::string opensslSigAlgsList; + join(":", sigalgs, opensslSigAlgsList); + int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str()); + if (rc == 0) { + throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors()); + } +#endif +} + +void SSLContext::setClientECCurvesList( + const std::vector& ecCurves) { + if (ecCurves.size() == 0) { + return; + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL + std::string ecCurvesList; + join(":", ecCurves, ecCurvesList); + int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str()); + if (rc == 0) { + throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors()); + } +#endif +} + +void SSLContext::setServerECCurve(const std::string& curveName) { + bool validCall = false; +#if OPENSSL_VERSION_NUMBER >= 0x0090800fL +#ifndef OPENSSL_NO_ECDH + validCall = true; +#endif +#endif + if (!validCall) { + throw std::runtime_error("Elliptic curve encryption not allowed"); + } + + EC_KEY* ecdh = nullptr; + int nid; + + /* + * Elliptic-Curve Diffie-Hellman parameters are either "named curves" + * from RFC 4492 section 5.1.1, or explicitly described curves over + * binary fields. OpenSSL only supports the "named curves", which provide + * maximum interoperability. + */ + + nid = OBJ_sn2nid(curveName.c_str()); + if (nid == 0) { + LOG(FATAL) << "Unknown curve name:" << curveName.c_str(); + return; + } + ecdh = EC_KEY_new_by_curve_name(nid); + if (ecdh == nullptr) { + LOG(FATAL) << "Unable to create curve:" << curveName.c_str(); + return; + } + + SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + EC_KEY_free(ecdh); +} + +void SSLContext::setX509VerifyParam( + const ssl::X509VerifyParam& x509VerifyParam) { + if (!x509VerifyParam) { + return; + } + if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) { + throw std::runtime_error("SSL_CTX_set1_param " + getErrors()); + } +} + void SSLContext::setCiphersOrThrow(const std::string& ciphers) { int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str()); - if (ERR_peek_error() != 0) { - throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors()); - } if (rc == 0) { - throw std::runtime_error("None of specified ciphers are supported"); + throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors()); } } @@ -182,10 +268,36 @@ void SSLContext::loadCertificate(const char* path, const char* format) { } } +void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) { + if (cert.data() == nullptr) { + throw std::invalid_argument("loadCertificate: is nullptr"); + } + + ssl::BioUniquePtr bio(BIO_new(BIO_s_mem())); + if (bio == nullptr) { + throw std::runtime_error("BIO_new: " + getErrors()); + } + + int written = BIO_write(bio.get(), cert.data(), cert.size()); + if (written <= 0 || static_cast(written) != cert.size()) { + throw std::runtime_error("BIO_write: " + getErrors()); + } + + ssl::X509UniquePtr x509( + PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); + if (x509 == nullptr) { + throw std::runtime_error("PEM_read_bio_X509: " + getErrors()); + } + + if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) { + throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors()); + } +} + void SSLContext::loadPrivateKey(const char* path, const char* format) { if (path == nullptr || format == nullptr) { throw std::invalid_argument( - "loadPrivateKey: either or is nullptr"); + "loadPrivateKey: either or is nullptr"); } if (strcmp(format, "PEM") == 0) { if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) { @@ -196,10 +308,35 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) { } } +void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) { + if (pkey.data() == nullptr) { + throw std::invalid_argument("loadPrivateKey: is nullptr"); + } + + ssl::BioUniquePtr bio(BIO_new(BIO_s_mem())); + if (bio == nullptr) { + throw std::runtime_error("BIO_new: " + getErrors()); + } + + int written = BIO_write(bio.get(), pkey.data(), pkey.size()); + if (written <= 0 || static_cast(written) != pkey.size()) { + throw std::runtime_error("BIO_write: " + getErrors()); + } + + ssl::EvpPkeyUniquePtr key( + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); + if (key == nullptr) { + throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors()); + } + + if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) { + throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors()); + } +} + void SSLContext::loadTrustedCertificates(const char* path) { if (path == nullptr) { - throw std::invalid_argument( - "loadTrustedCertificates: is nullptr"); + throw std::invalid_argument("loadTrustedCertificates: is nullptr"); } if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) { throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors()); @@ -282,20 +419,45 @@ int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) { void SSLContext::switchCiphersIfTLS11( SSL* ssl, - const std::string& tls11CipherString) { - - CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers"; + const std::string& tls11CipherString, + const std::vector>& tls11AltCipherlist) { + CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty())) + << "Shouldn't call if empty ciphers / alt ciphers"; if (TLS1_get_client_version(ssl) <= TLS1_VERSION) { // We only do this for TLS v 1.1 and later return; } + const std::string* ciphers = &tls11CipherString; + if (!tls11AltCipherlist.empty()) { + if (!cipherListPicker_) { + std::vector weights; + std::for_each( + tls11AltCipherlist.begin(), + tls11AltCipherlist.end(), + [&](const std::pair& e) { + weights.push_back(e.second); + }); + cipherListPicker_.reset( + new std::discrete_distribution(weights.begin(), weights.end())); + } + auto rng = ThreadLocalPRNG(); + auto index = (*cipherListPicker_)(rng); + if ((size_t)index >= tls11AltCipherlist.size()) { + LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index + << ", but tls11AltCipherlist is of length " + << tls11AltCipherlist.size(); + } else { + ciphers = &tls11AltCipherlist[index].first; + } + } + // Prefer AES for TLS versions 1.1 and later since these are not // vulnerable to BEAST attacks on AES. Note that we're setting the // cipher list on the SSL object, not the SSL_CTX object, so it will // only last for this request. - int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str()); + int rc = SSL_set_cipher_list(ssl, ciphers->c_str()); if ((rc == 0) || ERR_peek_error() != 0) { // This shouldn't happen since we checked for this when proxygen // started up. @@ -305,13 +467,43 @@ void SSLContext::switchCiphersIfTLS11( } #endif +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +int SSLContext::alpnSelectCallback(SSL* /* ssl */, + const unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen, + void* data) { + SSLContext* context = (SSLContext*)data; + CHECK(context); + if (context->advertisedNextProtocols_.empty()) { + *out = nullptr; + *outlen = 0; + } else { + auto i = context->pickNextProtocols(); + const auto& item = context->advertisedNextProtocols_[i]; + if (SSL_select_next_proto((unsigned char**)out, + outlen, + item.protocols, + item.length, + in, + inlen) != OPENSSL_NPN_NEGOTIATED) { + return SSL_TLSEXT_ERR_NOACK; + } + } + return SSL_TLSEXT_ERR_OK; +} +#endif + #ifdef OPENSSL_NPN_NEGOTIATED -bool SSLContext::setAdvertisedNextProtocols(const std::list& protocols) { - return setRandomizedAdvertisedNextProtocols({{1, protocols}}); + +bool SSLContext::setAdvertisedNextProtocols( + const std::list& protocols, NextProtocolType protocolType) { + return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType); } bool SSLContext::setRandomizedAdvertisedNextProtocols( - const std::list& items) { + const std::list& items, NextProtocolType protocolType) { unsetNextProtocols(); if (items.size() == 0) { return false; @@ -344,20 +536,30 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( dst += protoLength; } total_weight += item.weight; - advertised_item.probability = item.weight; advertisedNextProtocols_.push_back(advertised_item); + advertisedNextProtocolWeights_.push_back(item.weight); } if (total_weight == 0) { deleteNextProtocolsStrings(); return false; } - for (auto &advertised_item : advertisedNextProtocols_) { - advertised_item.probability /= total_weight; + nextProtocolDistribution_ = + std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(), + advertisedNextProtocolWeights_.end()); + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { + SSL_CTX_set_next_protos_advertised_cb( + ctx_, advertisedNextProtocolCallback, this); + SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this); + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) { + SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this); + // Client cannot really use randomized alpn + SSL_CTX_set_alpn_protos(ctx_, + advertisedNextProtocols_[0].protocols, + advertisedNextProtocols_[0].length); } - SSL_CTX_set_next_protos_advertised_cb( - ctx_, advertisedNextProtocolCallback, this); - SSL_CTX_set_next_proto_select_cb( - ctx_, selectNextProtocolCallback, this); +#endif return true; } @@ -366,12 +568,23 @@ void SSLContext::deleteNextProtocolsStrings() { delete[] protocols.protocols; } advertisedNextProtocols_.clear(); + advertisedNextProtocolWeights_.clear(); } void SSLContext::unsetNextProtocols() { deleteNextProtocolsStrings(); SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr); SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr); +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr); + SSL_CTX_set_alpn_protos(ctx_, nullptr, 0); +#endif +} + +size_t SSLContext::pickNextProtocols() { + CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols"; + auto rng = ThreadLocalPRNG(); + return nextProtocolDistribution_(rng); } int SSLContext::advertisedNextProtocolCallback(SSL* ssl, @@ -391,135 +604,45 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, *out = context->advertisedNextProtocols_[selected_index].protocols; *outlen = context->advertisedNextProtocols_[selected_index].length; } else { - unsigned char random_byte; - RAND_bytes(&random_byte, 1); - double random_value = random_byte / 255.0; - double sum = 0; - for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) { - sum += context->advertisedNextProtocols_[i].probability; - if (sum < random_value && - i + 1 < context->advertisedNextProtocols_.size()) { - continue; - } - uintptr_t selected = i + 1; - SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected); - *out = context->advertisedNextProtocols_[i].protocols; - *outlen = context->advertisedNextProtocols_[i].length; - break; - } + auto i = context->pickNextProtocols(); + uintptr_t selected = i + 1; + SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected); + *out = context->advertisedNextProtocols_[i].protocols; + *outlen = context->advertisedNextProtocols_[i].length; } } return SSL_TLSEXT_ERR_OK; } -#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ - FOLLY_SSLCONTEXT_USE_TLS_FALSE_START -SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() : - /** - * The list was generated as follows: - * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 | - * while read A && read B && read C && read D && read E && read F; do - * echo $A $B $C $D $E; done | - * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\| - * SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES | - * awk -F, '{ print $1"," }' - */ - ciphers_{ - TLS1_CK_DH_DSS_WITH_AES_128_SHA, - TLS1_CK_DH_RSA_WITH_AES_128_SHA, - TLS1_CK_DHE_DSS_WITH_AES_128_SHA, - TLS1_CK_DHE_RSA_WITH_AES_128_SHA, - TLS1_CK_DH_DSS_WITH_AES_256_SHA, - TLS1_CK_DH_RSA_WITH_AES_256_SHA, - TLS1_CK_DHE_DSS_WITH_AES_256_SHA, - TLS1_CK_DHE_RSA_WITH_AES_256_SHA, - TLS1_CK_DH_DSS_WITH_AES_128_SHA256, - TLS1_CK_DH_RSA_WITH_AES_128_SHA256, - TLS1_CK_DHE_DSS_WITH_AES_128_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_128_SHA256, - TLS1_CK_DH_DSS_WITH_AES_256_SHA256, - TLS1_CK_DH_RSA_WITH_AES_256_SHA256, - TLS1_CK_DHE_DSS_WITH_AES_256_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_256_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256, - TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384, - TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256, - TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA, - TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA, - TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256, - TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384, - TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256, - TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384, - TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256, - TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384, - TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256, - TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384, - TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256, - } { - length_ = sizeof(ciphers_)/sizeof(ciphers_[0]); - width_ = sizeof(ciphers_[0]); - qsort(ciphers_, length_, width_, compare_ulong); -} - -bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher( - const SSL_CIPHER *cipher) { - unsigned long cid = cipher->id; - unsigned long *r = - (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong); - return r != nullptr; -} - -int -SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) { - if (*(unsigned long *)x < *(unsigned long *)y) { - return -1; - } - if (*(unsigned long *)x > *(unsigned long *)y) { - return 1; - } - return 0; -}; - -bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) { - return falseStartChecker_.canUseFalseStartWithCipher(cipher); -} -#endif - -int SSLContext::selectNextProtocolCallback( - SSL* ssl, unsigned char **out, unsigned char *outlen, - const unsigned char *server, unsigned int server_len, void *data) { - +int SSLContext::selectNextProtocolCallback(SSL* ssl, + unsigned char** out, + unsigned char* outlen, + const unsigned char* server, + unsigned int server_len, + void* data) { + (void)ssl; // Make -Wunused-parameters happy SSLContext* ctx = (SSLContext*)data; if (ctx->advertisedNextProtocols_.size() > 1) { VLOG(3) << "SSLContext::selectNextProcolCallback() " << "client should be deterministic in selecting protocols."; } - unsigned char *client; - int client_len; - if (ctx->advertisedNextProtocols_.empty()) { - client = (unsigned char *) ""; - client_len = 0; - } else { - client = ctx->advertisedNextProtocols_[0].protocols; - client_len = ctx->advertisedNextProtocols_[0].length; + unsigned char* client = nullptr; + unsigned int client_len = 0; + bool filtered = false; + auto cpf = ctx->getClientProtocolFilterCallback(); + if (cpf) { + filtered = (*cpf)(&client, &client_len, server, server_len); + } + + if (!filtered) { + if (ctx->advertisedNextProtocols_.empty()) { + client = (unsigned char *) ""; + client_len = 0; + } else { + client = ctx->advertisedNextProtocols_[0].protocols; + client_len = ctx->advertisedNextProtocols_[0].length; + } } int retval = SSL_select_next_proto(out, outlen, server, server_len, @@ -527,14 +650,6 @@ int SSLContext::selectNextProtocolCallback( if (retval != OPENSSL_NPN_NEGOTIATED) { VLOG(3) << "SSLContext::selectNextProcolCallback() " << "unable to pick a next protocol."; -#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ - FOLLY_SSLCONTEXT_USE_TLS_FALSE_START - } else { - const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher; - if (cipher && ctx->canUseFalseStartWithCipher(cipher)) { - SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH); - } -#endif } return SSL_TLSEXT_ERR_OK; } @@ -548,6 +663,14 @@ SSL* SSLContext::createSSL() const { return ssl; } +void SSLContext::setSessionCacheContext(const std::string& context) { + SSL_CTX_set_session_id_context( + ctx_, + reinterpret_cast(context.data()), + std::min( + static_cast(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH)); +} + /** * Match a name with a pattern. The pattern may include wildcard. A single * wildcard "*" can match up to one component in the domain name. @@ -657,6 +780,8 @@ static unsigned long callbackThreadID() { return static_cast( #ifdef __APPLE__ pthread_mach_thread_np(pthread_self()) +#elif _MSC_VER + pthread_getw32threadid_np(pthread_self()) #else pthread_self() #endif @@ -687,6 +812,12 @@ void SSLContext::setSSLLockTypes(std::map inLockTypes) { lockTypes() = inLockTypes; } +#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) +void SSLContext::enableFalseStart() { + SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH); +} +#endif + void SSLContext::markInitialized() { std::lock_guard g(initMutex()); initialized_ = true; @@ -782,83 +913,4 @@ operator<<(std::ostream& os, const PasswordCollector& collector) { return os; } -bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx, - sockaddr_storage* addrStorage, - socklen_t* addrLen) { - // Grab the ssl idx and then the ssl object so that we can get the peer - // name to compare against the ips in the subjectAltName - auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx(); - auto ssl = - reinterpret_cast(X509_STORE_CTX_get_ex_data(ctx, sslIdx)); - int fd = SSL_get_fd(ssl); - if (fd < 0) { - LOG(ERROR) << "Inexplicably couldn't get fd from SSL"; - return false; - } - - *addrLen = sizeof(*addrStorage); - if (getpeername(fd, reinterpret_cast(addrStorage), addrLen) != 0) { - PLOG(ERROR) << "Unable to get peer name"; - return false; - } - CHECK(*addrLen <= sizeof(*addrStorage)); - return true; -} - -bool OpenSSLUtils::validatePeerCertNames(X509* cert, - const sockaddr* addr, - socklen_t addrLen) { - // Try to extract the names within the SAN extension from the certificate - auto altNames = - reinterpret_cast( - X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr)); - SCOPE_EXIT { - if (altNames != nullptr) { - sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free); - } - }; - if (altNames == nullptr) { - LOG(WARNING) << "No subjectAltName provided and we only support ip auth"; - return false; - } - - const sockaddr_in* addr4 = nullptr; - const sockaddr_in6* addr6 = nullptr; - if (addr != nullptr) { - if (addr->sa_family == AF_INET) { - addr4 = reinterpret_cast(addr); - } else if (addr->sa_family == AF_INET6) { - addr6 = reinterpret_cast(addr); - } else { - LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family; - } - } - - - for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) { - auto name = sk_GENERAL_NAME_value(altNames, i); - if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) { - // Extra const-ness for paranoia - unsigned char const * const rawIpStr = name->d.iPAddress->data; - int const rawIpLen = name->d.iPAddress->length; - - if (rawIpLen == 4 && addr4 != nullptr) { - if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) { - return true; - } - } else if (rawIpLen == 16 && addr6 != nullptr) { - if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) { - return true; - } - } else if (rawIpLen != 4 && rawIpLen != 16) { - LOG(WARNING) << "Unexpected IP length: " << rawIpLen; - } - } - } - - LOG(WARNING) << "Unable to match client cert against alt name ip"; - return false; -} - - } // folly