Adding OpenSSLPtrTypes.h.
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
27 #include <folly/io/async/OpenSSLPtrTypes.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
49 using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
50
51 } // anonymous namespace
52
53 #ifdef OPENSSL_NPN_NEGOTIATED
54 int SSLContext::sNextProtocolsExDataIndex_ = -1;
55 #endif
56
57 // SSLContext implementation
58 SSLContext::SSLContext(SSLVersion version) {
59   {
60     std::lock_guard<std::mutex> g(initMutex());
61     initializeOpenSSLLocked();
62   }
63
64   ctx_ = SSL_CTX_new(SSLv23_method());
65   if (ctx_ == nullptr) {
66     throw std::runtime_error("SSL_CTX_new: " + getErrors());
67   }
68
69   int opt = 0;
70   switch (version) {
71     case TLSv1:
72       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
73       break;
74     case SSLv3:
75       opt = SSL_OP_NO_SSLv2;
76       break;
77     default:
78       // do nothing
79       break;
80   }
81   int newOpt = SSL_CTX_set_options(ctx_, opt);
82   DCHECK((newOpt & opt) == opt);
83
84   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
85
86   checkPeerName_ = false;
87
88 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
89   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
90   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
91 #endif
92
93 #ifdef OPENSSL_NPN_NEGOTIATED
94   Random::seed(nextProtocolPicker_);
95 #endif
96 }
97
98 SSLContext::~SSLContext() {
99   if (ctx_ != nullptr) {
100     SSL_CTX_free(ctx_);
101     ctx_ = nullptr;
102   }
103
104 #ifdef OPENSSL_NPN_NEGOTIATED
105   deleteNextProtocolsStrings();
106 #endif
107 }
108
109 void SSLContext::ciphers(const std::string& ciphers) {
110   providedCiphersString_ = ciphers;
111   setCiphersOrThrow(ciphers);
112 }
113
114 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
115   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
116   if (ERR_peek_error() != 0) {
117     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
118   }
119   if (rc == 0) {
120     throw std::runtime_error("None of specified ciphers are supported");
121   }
122 }
123
124 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
125     verifyPeer) {
126   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
127   verifyPeer_ = verifyPeer;
128 }
129
130 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
131     verifyPeer) {
132   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
133   int mode = SSL_VERIFY_NONE;
134   switch(verifyPeer) {
135     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
136     // break;
137
138     case SSLVerifyPeerEnum::VERIFY:
139       mode = SSL_VERIFY_PEER;
140       break;
141
142     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
143       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
144       break;
145
146     case SSLVerifyPeerEnum::NO_VERIFY:
147       mode = SSL_VERIFY_NONE;
148       break;
149
150     default:
151       break;
152   }
153   return mode;
154 }
155
156 int SSLContext::getVerificationMode() {
157   return getVerificationMode(verifyPeer_);
158 }
159
160 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
161                               const std::string& peerName) {
162   int mode;
163   if (checkPeerCert) {
164     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
165     checkPeerName_ = checkPeerName;
166     peerFixedName_ = peerName;
167   } else {
168     mode = SSL_VERIFY_NONE;
169     checkPeerName_ = false; // can't check name without cert!
170     peerFixedName_.clear();
171   }
172   SSL_CTX_set_verify(ctx_, mode, nullptr);
173 }
174
175 void SSLContext::loadCertificate(const char* path, const char* format) {
176   if (path == nullptr || format == nullptr) {
177     throw std::invalid_argument(
178          "loadCertificateChain: either <path> or <format> is nullptr");
179   }
180   if (strcmp(format, "PEM") == 0) {
181     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
182       int errnoCopy = errno;
183       std::string reason("SSL_CTX_use_certificate_chain_file: ");
184       reason.append(path);
185       reason.append(": ");
186       reason.append(getErrors(errnoCopy));
187       throw std::runtime_error(reason);
188     }
189   } else {
190     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
191   }
192 }
193
194 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
195   if (cert.data() == nullptr) {
196     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
197   }
198
199   std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
200   if (bio == nullptr) {
201     throw std::runtime_error("BIO_new: " + getErrors());
202   }
203
204   int written = BIO_write(bio.get(), cert.data(), cert.size());
205   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
206     throw std::runtime_error("BIO_write: " + getErrors());
207   }
208
209   X509_UniquePtr x509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
210   if (x509 == nullptr) {
211     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
212   }
213
214   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
215     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
216   }
217 }
218
219 void SSLContext::loadPrivateKey(const char* path, const char* format) {
220   if (path == nullptr || format == nullptr) {
221     throw std::invalid_argument(
222         "loadPrivateKey: either <path> or <format> is nullptr");
223   }
224   if (strcmp(format, "PEM") == 0) {
225     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
226       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
227     }
228   } else {
229     throw std::runtime_error("Unsupported private key format: " + std::string(format));
230   }
231 }
232
233 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
234   if (pkey.data() == nullptr) {
235     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
236   }
237
238   std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
239   if (bio == nullptr) {
240     throw std::runtime_error("BIO_new: " + getErrors());
241   }
242
243   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
244   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
245     throw std::runtime_error("BIO_write: " + getErrors());
246   }
247
248   EVP_PKEY_UniquePtr key(
249       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
250   if (key == nullptr) {
251     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
252   }
253
254   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
255     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
256   }
257 }
258
259 void SSLContext::loadTrustedCertificates(const char* path) {
260   if (path == nullptr) {
261     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
262   }
263   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
264     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
265   }
266 }
267
268 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
269   SSL_CTX_set_cert_store(ctx_, store);
270 }
271
272 void SSLContext::loadClientCAList(const char* path) {
273   auto clientCAs = SSL_load_client_CA_file(path);
274   if (clientCAs == nullptr) {
275     LOG(ERROR) << "Unable to load ca file: " << path;
276     return;
277   }
278   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
279 }
280
281 void SSLContext::randomize() {
282   RAND_poll();
283 }
284
285 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
286   if (collector == nullptr) {
287     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
288     return;
289   }
290   collector_ = collector;
291   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
292   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
293 }
294
295 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
296
297 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
298   serverNameCb_ = cb;
299 }
300
301 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
302   clientHelloCbs_.push_back(cb);
303 }
304
305 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
306   SSLContext* context = (SSLContext*)data;
307
308   if (context == nullptr) {
309     return SSL_TLSEXT_ERR_NOACK;
310   }
311
312   for (auto& cb : context->clientHelloCbs_) {
313     // Generic callbacks to happen after we receive the Client Hello.
314     // For example, we use one to switch which cipher we use depending
315     // on the user's TLS version.  Because the primary purpose of
316     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
317     // are side-uses, we ignore any possible failures other than just logging
318     // them.
319     cb(ssl);
320   }
321
322   if (!context->serverNameCb_) {
323     return SSL_TLSEXT_ERR_NOACK;
324   }
325
326   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
327   switch (ret) {
328     case SERVER_NAME_FOUND:
329       return SSL_TLSEXT_ERR_OK;
330     case SERVER_NAME_NOT_FOUND:
331       return SSL_TLSEXT_ERR_NOACK;
332     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
333       *al = TLS1_AD_UNRECOGNIZED_NAME;
334       return SSL_TLSEXT_ERR_ALERT_FATAL;
335     default:
336       CHECK(false);
337   }
338
339   return SSL_TLSEXT_ERR_NOACK;
340 }
341
342 void SSLContext::switchCiphersIfTLS11(
343     SSL* ssl,
344     const std::string& tls11CipherString) {
345
346   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
347
348   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
349     // We only do this for TLS v 1.1 and later
350     return;
351   }
352
353   // Prefer AES for TLS versions 1.1 and later since these are not
354   // vulnerable to BEAST attacks on AES.  Note that we're setting the
355   // cipher list on the SSL object, not the SSL_CTX object, so it will
356   // only last for this request.
357   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
358   if ((rc == 0) || ERR_peek_error() != 0) {
359     // This shouldn't happen since we checked for this when proxygen
360     // started up.
361     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
362     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
363   }
364 }
365 #endif
366
367 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
368 int SSLContext::alpnSelectCallback(SSL* ssl,
369                                    const unsigned char** out,
370                                    unsigned char* outlen,
371                                    const unsigned char* in,
372                                    unsigned int inlen,
373                                    void* data) {
374   SSLContext* context = (SSLContext*)data;
375   CHECK(context);
376   if (context->advertisedNextProtocols_.empty()) {
377     *out = nullptr;
378     *outlen = 0;
379   } else {
380     auto i = context->pickNextProtocols();
381     const auto& item = context->advertisedNextProtocols_[i];
382     if (SSL_select_next_proto((unsigned char**)out,
383                               outlen,
384                               item.protocols,
385                               item.length,
386                               in,
387                               inlen) != OPENSSL_NPN_NEGOTIATED) {
388       return SSL_TLSEXT_ERR_NOACK;
389     }
390   }
391   return SSL_TLSEXT_ERR_OK;
392 }
393 #endif
394
395 #ifdef OPENSSL_NPN_NEGOTIATED
396
397 bool SSLContext::setAdvertisedNextProtocols(
398     const std::list<std::string>& protocols, NextProtocolType protocolType) {
399   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
400 }
401
402 bool SSLContext::setRandomizedAdvertisedNextProtocols(
403     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
404   unsetNextProtocols();
405   if (items.size() == 0) {
406     return false;
407   }
408   int total_weight = 0;
409   for (const auto &item : items) {
410     if (item.protocols.size() == 0) {
411       continue;
412     }
413     AdvertisedNextProtocolsItem advertised_item;
414     advertised_item.length = 0;
415     for (const auto& proto : item.protocols) {
416       ++advertised_item.length;
417       unsigned protoLength = proto.length();
418       if (protoLength >= 256) {
419         deleteNextProtocolsStrings();
420         return false;
421       }
422       advertised_item.length += protoLength;
423     }
424     advertised_item.protocols = new unsigned char[advertised_item.length];
425     if (!advertised_item.protocols) {
426       throw std::runtime_error("alloc failure");
427     }
428     unsigned char* dst = advertised_item.protocols;
429     for (auto& proto : item.protocols) {
430       unsigned protoLength = proto.length();
431       *dst++ = (unsigned char)protoLength;
432       memcpy(dst, proto.data(), protoLength);
433       dst += protoLength;
434     }
435     total_weight += item.weight;
436     advertisedNextProtocols_.push_back(advertised_item);
437     advertisedNextProtocolWeights_.push_back(item.weight);
438   }
439   if (total_weight == 0) {
440     deleteNextProtocolsStrings();
441     return false;
442   }
443   nextProtocolDistribution_ =
444       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
445                                    advertisedNextProtocolWeights_.end());
446   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
447     SSL_CTX_set_next_protos_advertised_cb(
448         ctx_, advertisedNextProtocolCallback, this);
449     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
450   }
451 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
452   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
453     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
454     // Client cannot really use randomized alpn
455     SSL_CTX_set_alpn_protos(ctx_,
456                             advertisedNextProtocols_[0].protocols,
457                             advertisedNextProtocols_[0].length);
458   }
459 #endif
460   return true;
461 }
462
463 void SSLContext::deleteNextProtocolsStrings() {
464   for (auto protocols : advertisedNextProtocols_) {
465     delete[] protocols.protocols;
466   }
467   advertisedNextProtocols_.clear();
468   advertisedNextProtocolWeights_.clear();
469 }
470
471 void SSLContext::unsetNextProtocols() {
472   deleteNextProtocolsStrings();
473   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
474   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
475 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
476   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
477   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
478 #endif
479 }
480
481 size_t SSLContext::pickNextProtocols() {
482   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
483   return nextProtocolDistribution_(nextProtocolPicker_);
484 }
485
486 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
487       const unsigned char** out, unsigned int* outlen, void* data) {
488   SSLContext* context = (SSLContext*)data;
489   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
490     *out = nullptr;
491     *outlen = 0;
492   } else if (context->advertisedNextProtocols_.size() == 1) {
493     *out = context->advertisedNextProtocols_[0].protocols;
494     *outlen = context->advertisedNextProtocols_[0].length;
495   } else {
496     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
497           sNextProtocolsExDataIndex_));
498     if (selected_index) {
499       --selected_index;
500       *out = context->advertisedNextProtocols_[selected_index].protocols;
501       *outlen = context->advertisedNextProtocols_[selected_index].length;
502     } else {
503       auto i = context->pickNextProtocols();
504       uintptr_t selected = i + 1;
505       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
506       *out = context->advertisedNextProtocols_[i].protocols;
507       *outlen = context->advertisedNextProtocols_[i].length;
508     }
509   }
510   return SSL_TLSEXT_ERR_OK;
511 }
512
513 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
514   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
515 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
516   ciphers_{
517     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
518     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
519     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
520     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
521     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
522     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
523     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
524     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
525     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
526     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
527     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
528     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
529     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
530     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
531     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
532     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
533     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
534     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
535     TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
536     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
537     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
538     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
539     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
540     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
541     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
542     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
543   } {
544   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
545   width_ = sizeof(ciphers_[0]);
546   qsort(ciphers_, length_, width_, compare_ulong);
547 }
548
549 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
550   const SSL_CIPHER *cipher) {
551   unsigned long cid = cipher->id;
552   unsigned long *r =
553     (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
554   return r != nullptr;
555 }
556
557 int
558 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
559   if (*(unsigned long *)x < *(unsigned long *)y) {
560     return -1;
561   }
562   if (*(unsigned long *)x > *(unsigned long *)y) {
563     return 1;
564   }
565   return 0;
566 };
567
568 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
569   return falseStartChecker_.canUseFalseStartWithCipher(cipher);
570 }
571 #endif
572
573 int SSLContext::selectNextProtocolCallback(
574   SSL* ssl, unsigned char **out, unsigned char *outlen,
575   const unsigned char *server, unsigned int server_len, void *data) {
576
577   SSLContext* ctx = (SSLContext*)data;
578   if (ctx->advertisedNextProtocols_.size() > 1) {
579     VLOG(3) << "SSLContext::selectNextProcolCallback() "
580             << "client should be deterministic in selecting protocols.";
581   }
582
583   unsigned char *client;
584   unsigned int client_len;
585   bool filtered = false;
586   auto cpf = ctx->getClientProtocolFilterCallback();
587   if (cpf) {
588     filtered = (*cpf)(&client, &client_len, server, server_len);
589   }
590
591   if (!filtered) {
592     if (ctx->advertisedNextProtocols_.empty()) {
593       client = (unsigned char *) "";
594       client_len = 0;
595     } else {
596       client = ctx->advertisedNextProtocols_[0].protocols;
597       client_len = ctx->advertisedNextProtocols_[0].length;
598     }
599   }
600
601   int retval = SSL_select_next_proto(out, outlen, server, server_len,
602                                      client, client_len);
603   if (retval != OPENSSL_NPN_NEGOTIATED) {
604     VLOG(3) << "SSLContext::selectNextProcolCallback() "
605             << "unable to pick a next protocol.";
606 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
607   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
608   } else {
609     const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
610     if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
611       SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
612     }
613 #endif
614   }
615   return SSL_TLSEXT_ERR_OK;
616 }
617 #endif // OPENSSL_NPN_NEGOTIATED
618
619 SSL* SSLContext::createSSL() const {
620   SSL* ssl = SSL_new(ctx_);
621   if (ssl == nullptr) {
622     throw std::runtime_error("SSL_new: " + getErrors());
623   }
624   return ssl;
625 }
626
627 /**
628  * Match a name with a pattern. The pattern may include wildcard. A single
629  * wildcard "*" can match up to one component in the domain name.
630  *
631  * @param  host    Host name, typically the name of the remote host
632  * @param  pattern Name retrieved from certificate
633  * @param  size    Size of "pattern"
634  * @return True, if "host" matches "pattern". False otherwise.
635  */
636 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
637   bool match = false;
638   int i = 0, j = 0;
639   while (i < size && host[j] != '\0') {
640     if (toupper(pattern[i]) == toupper(host[j])) {
641       i++;
642       j++;
643       continue;
644     }
645     if (pattern[i] == '*') {
646       while (host[j] != '.' && host[j] != '\0') {
647         j++;
648       }
649       i++;
650       continue;
651     }
652     break;
653   }
654   if (i == size && host[j] == '\0') {
655     match = true;
656   }
657   return match;
658 }
659
660 int SSLContext::passwordCallback(char* password,
661                                  int size,
662                                  int,
663                                  void* data) {
664   SSLContext* context = (SSLContext*)data;
665   if (context == nullptr || context->passwordCollector() == nullptr) {
666     return 0;
667   }
668   std::string userPassword;
669   // call user defined password collector to get password
670   context->passwordCollector()->getPassword(userPassword, size);
671   int length = userPassword.size();
672   if (length > size) {
673     length = size;
674   }
675   strncpy(password, userPassword.c_str(), length);
676   return length;
677 }
678
679 struct SSLLock {
680   explicit SSLLock(
681     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
682       lockType(inLockType) {
683   }
684
685   void lock() {
686     if (lockType == SSLContext::LOCK_MUTEX) {
687       mutex.lock();
688     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
689       spinLock.lock();
690     }
691     // lockType == LOCK_NONE, no-op
692   }
693
694   void unlock() {
695     if (lockType == SSLContext::LOCK_MUTEX) {
696       mutex.unlock();
697     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
698       spinLock.unlock();
699     }
700     // lockType == LOCK_NONE, no-op
701   }
702
703   SSLContext::SSLLockType lockType;
704   folly::SpinLock spinLock{};
705   std::mutex mutex;
706 };
707
708 // Statics are unsafe in environments that call exit().
709 // If one thread calls exit() while another thread is
710 // references a member of SSLContext, bad things can happen.
711 // SSLContext runs in such environments.
712 // Instead of declaring a static member we "new" the static
713 // member so that it won't be destructed on exit().
714 static std::unique_ptr<SSLLock[]>& locks() {
715   static auto locksInst = new std::unique_ptr<SSLLock[]>();
716   return *locksInst;
717 }
718
719 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
720   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
721   return *lockTypesInst;
722 }
723
724 static void callbackLocking(int mode, int n, const char*, int) {
725   if (mode & CRYPTO_LOCK) {
726     locks()[n].lock();
727   } else {
728     locks()[n].unlock();
729   }
730 }
731
732 static unsigned long callbackThreadID() {
733   return static_cast<unsigned long>(
734 #ifdef __APPLE__
735     pthread_mach_thread_np(pthread_self())
736 #else
737     pthread_self()
738 #endif
739   );
740 }
741
742 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
743   return new CRYPTO_dynlock_value;
744 }
745
746 static void dyn_lock(int mode,
747                      struct CRYPTO_dynlock_value* lock,
748                      const char*, int) {
749   if (lock != nullptr) {
750     if (mode & CRYPTO_LOCK) {
751       lock->mutex.lock();
752     } else {
753       lock->mutex.unlock();
754     }
755   }
756 }
757
758 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
759   delete lock;
760 }
761
762 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
763   lockTypes() = inLockTypes;
764 }
765
766 void SSLContext::markInitialized() {
767   std::lock_guard<std::mutex> g(initMutex());
768   initialized_ = true;
769 }
770
771 void SSLContext::initializeOpenSSL() {
772   std::lock_guard<std::mutex> g(initMutex());
773   initializeOpenSSLLocked();
774 }
775
776 void SSLContext::initializeOpenSSLLocked() {
777   if (initialized_) {
778     return;
779   }
780   SSL_library_init();
781   SSL_load_error_strings();
782   ERR_load_crypto_strings();
783   // static locking
784   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
785   for (auto it: lockTypes()) {
786     locks()[it.first].lockType = it.second;
787   }
788   CRYPTO_set_id_callback(callbackThreadID);
789   CRYPTO_set_locking_callback(callbackLocking);
790   // dynamic locking
791   CRYPTO_set_dynlock_create_callback(dyn_create);
792   CRYPTO_set_dynlock_lock_callback(dyn_lock);
793   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
794   randomize();
795 #ifdef OPENSSL_NPN_NEGOTIATED
796   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
797       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
798 #endif
799   initialized_ = true;
800 }
801
802 void SSLContext::cleanupOpenSSL() {
803   std::lock_guard<std::mutex> g(initMutex());
804   cleanupOpenSSLLocked();
805 }
806
807 void SSLContext::cleanupOpenSSLLocked() {
808   if (!initialized_) {
809     return;
810   }
811
812   CRYPTO_set_id_callback(nullptr);
813   CRYPTO_set_locking_callback(nullptr);
814   CRYPTO_set_dynlock_create_callback(nullptr);
815   CRYPTO_set_dynlock_lock_callback(nullptr);
816   CRYPTO_set_dynlock_destroy_callback(nullptr);
817   CRYPTO_cleanup_all_ex_data();
818   ERR_free_strings();
819   EVP_cleanup();
820   ERR_remove_state(0);
821   locks().reset();
822   initialized_ = false;
823 }
824
825 void SSLContext::setOptions(long options) {
826   long newOpt = SSL_CTX_set_options(ctx_, options);
827   if ((newOpt & options) != options) {
828     throw std::runtime_error("SSL_CTX_set_options failed");
829   }
830 }
831
832 std::string SSLContext::getErrors(int errnoCopy) {
833   std::string errors;
834   unsigned long  errorCode;
835   char   message[256];
836
837   errors.reserve(512);
838   while ((errorCode = ERR_get_error()) != 0) {
839     if (!errors.empty()) {
840       errors += "; ";
841     }
842     const char* reason = ERR_reason_error_string(errorCode);
843     if (reason == nullptr) {
844       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
845       reason = message;
846     }
847     errors += reason;
848   }
849   if (errors.empty()) {
850     errors = "error code: " + folly::to<std::string>(errnoCopy);
851   }
852   return errors;
853 }
854
855 std::ostream&
856 operator<<(std::ostream& os, const PasswordCollector& collector) {
857   os << collector.describe();
858   return os;
859 }
860
861 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
862                                                   sockaddr_storage* addrStorage,
863                                                   socklen_t* addrLen) {
864   // Grab the ssl idx and then the ssl object so that we can get the peer
865   // name to compare against the ips in the subjectAltName
866   auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
867   auto ssl =
868     reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
869   int fd = SSL_get_fd(ssl);
870   if (fd < 0) {
871     LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
872     return false;
873   }
874
875   *addrLen = sizeof(*addrStorage);
876   if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
877     PLOG(ERROR) << "Unable to get peer name";
878     return false;
879   }
880   CHECK(*addrLen <= sizeof(*addrStorage));
881   return true;
882 }
883
884 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
885                                          const sockaddr* addr,
886                                          socklen_t addrLen) {
887   // Try to extract the names within the SAN extension from the certificate
888   auto altNames =
889     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
890         X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
891   SCOPE_EXIT {
892     if (altNames != nullptr) {
893       sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
894     }
895   };
896   if (altNames == nullptr) {
897     LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
898     return false;
899   }
900
901   const sockaddr_in* addr4 = nullptr;
902   const sockaddr_in6* addr6 = nullptr;
903   if (addr != nullptr) {
904     if (addr->sa_family == AF_INET) {
905       addr4 = reinterpret_cast<const sockaddr_in*>(addr);
906     } else if (addr->sa_family == AF_INET6) {
907       addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
908     } else {
909       LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
910     }
911   }
912
913
914   for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
915     auto name = sk_GENERAL_NAME_value(altNames, i);
916     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
917       // Extra const-ness for paranoia
918       unsigned char const * const rawIpStr = name->d.iPAddress->data;
919       int const rawIpLen = name->d.iPAddress->length;
920
921       if (rawIpLen == 4 && addr4 != nullptr) {
922         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
923           return true;
924         }
925       } else if (rawIpLen == 16 && addr6 != nullptr) {
926         if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
927           return true;
928         }
929       } else if (rawIpLen != 4 && rawIpLen != 16) {
930         LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
931       }
932     }
933   }
934
935   LOG(WARNING) << "Unable to match client cert against alt name ip";
936   return false;
937 }
938
939
940 } // folly