Methods to change cipher suite list
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/Random.h>
27 #include <folly/SpinLock.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 } // anonymous namespace
49
50 #ifdef OPENSSL_NPN_NEGOTIATED
51 int SSLContext::sNextProtocolsExDataIndex_ = -1;
52 #endif
53
54 // SSLContext implementation
55 SSLContext::SSLContext(SSLVersion version) {
56   {
57     std::lock_guard<std::mutex> g(initMutex());
58     initializeOpenSSLLocked();
59   }
60
61   ctx_ = SSL_CTX_new(SSLv23_method());
62   if (ctx_ == nullptr) {
63     throw std::runtime_error("SSL_CTX_new: " + getErrors());
64   }
65
66   int opt = 0;
67   switch (version) {
68     case TLSv1:
69       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
70       break;
71     case SSLv3:
72       opt = SSL_OP_NO_SSLv2;
73       break;
74     default:
75       // do nothing
76       break;
77   }
78   int newOpt = SSL_CTX_set_options(ctx_, opt);
79   DCHECK((newOpt & opt) == opt);
80
81   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82
83   checkPeerName_ = false;
84
85   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
86
87 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
88   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
89   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
90 #endif
91 }
92
93 SSLContext::~SSLContext() {
94   if (ctx_ != nullptr) {
95     SSL_CTX_free(ctx_);
96     ctx_ = nullptr;
97   }
98
99 #ifdef OPENSSL_NPN_NEGOTIATED
100   deleteNextProtocolsStrings();
101 #endif
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   providedCiphersString_ = ciphers;
106   setCiphersOrThrow(ciphers);
107 }
108
109 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
110   if (ciphers.size() == 0) {
111     return;
112   }
113   std::string opensslCipherList;
114   join(":", ciphers, opensslCipherList);
115   setCiphersOrThrow(opensslCipherList);
116 }
117
118 void SSLContext::setSignatureAlgorithms(
119     const std::vector<std::string>& sigalgs) {
120   if (sigalgs.size() == 0) {
121     return;
122   }
123 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
124   std::string opensslSigAlgsList;
125   join(":", sigalgs, opensslSigAlgsList);
126   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
127   if (rc == 0) {
128     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
129   }
130 #endif
131 }
132
133 void SSLContext::setClientECCurvesList(
134     const std::vector<std::string>& ecCurves) {
135   if (ecCurves.size() == 0) {
136     return;
137   }
138 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
139   std::string ecCurvesList;
140   join(":", ecCurves, ecCurvesList);
141   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
142   if (rc == 0) {
143     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
144   }
145 #endif
146 }
147
148 void SSLContext::setX509VerifyParam(
149     const ssl::X509VerifyParam& x509VerifyParam) {
150   if (!x509VerifyParam) {
151     return;
152   }
153   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
154     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
155   }
156 }
157
158 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
159   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
160   if (rc == 0) {
161     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
162   }
163 }
164
165 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
166     verifyPeer) {
167   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
168   verifyPeer_ = verifyPeer;
169 }
170
171 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
172     verifyPeer) {
173   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
174   int mode = SSL_VERIFY_NONE;
175   switch(verifyPeer) {
176     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
177     // break;
178
179     case SSLVerifyPeerEnum::VERIFY:
180       mode = SSL_VERIFY_PEER;
181       break;
182
183     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
184       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
185       break;
186
187     case SSLVerifyPeerEnum::NO_VERIFY:
188       mode = SSL_VERIFY_NONE;
189       break;
190
191     default:
192       break;
193   }
194   return mode;
195 }
196
197 int SSLContext::getVerificationMode() {
198   return getVerificationMode(verifyPeer_);
199 }
200
201 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
202                               const std::string& peerName) {
203   int mode;
204   if (checkPeerCert) {
205     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
206     checkPeerName_ = checkPeerName;
207     peerFixedName_ = peerName;
208   } else {
209     mode = SSL_VERIFY_NONE;
210     checkPeerName_ = false; // can't check name without cert!
211     peerFixedName_.clear();
212   }
213   SSL_CTX_set_verify(ctx_, mode, nullptr);
214 }
215
216 void SSLContext::loadCertificate(const char* path, const char* format) {
217   if (path == nullptr || format == nullptr) {
218     throw std::invalid_argument(
219          "loadCertificateChain: either <path> or <format> is nullptr");
220   }
221   if (strcmp(format, "PEM") == 0) {
222     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
223       int errnoCopy = errno;
224       std::string reason("SSL_CTX_use_certificate_chain_file: ");
225       reason.append(path);
226       reason.append(": ");
227       reason.append(getErrors(errnoCopy));
228       throw std::runtime_error(reason);
229     }
230   } else {
231     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
232   }
233 }
234
235 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
236   if (cert.data() == nullptr) {
237     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
238   }
239
240   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
241   if (bio == nullptr) {
242     throw std::runtime_error("BIO_new: " + getErrors());
243   }
244
245   int written = BIO_write(bio.get(), cert.data(), cert.size());
246   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
247     throw std::runtime_error("BIO_write: " + getErrors());
248   }
249
250   ssl::X509UniquePtr x509(
251       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
252   if (x509 == nullptr) {
253     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
254   }
255
256   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
257     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
258   }
259 }
260
261 void SSLContext::loadPrivateKey(const char* path, const char* format) {
262   if (path == nullptr || format == nullptr) {
263     throw std::invalid_argument(
264         "loadPrivateKey: either <path> or <format> is nullptr");
265   }
266   if (strcmp(format, "PEM") == 0) {
267     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
268       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
269     }
270   } else {
271     throw std::runtime_error("Unsupported private key format: " + std::string(format));
272   }
273 }
274
275 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
276   if (pkey.data() == nullptr) {
277     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
278   }
279
280   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
281   if (bio == nullptr) {
282     throw std::runtime_error("BIO_new: " + getErrors());
283   }
284
285   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
286   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
287     throw std::runtime_error("BIO_write: " + getErrors());
288   }
289
290   ssl::EvpPkeyUniquePtr key(
291       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
292   if (key == nullptr) {
293     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
294   }
295
296   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
297     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
298   }
299 }
300
301 void SSLContext::loadTrustedCertificates(const char* path) {
302   if (path == nullptr) {
303     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
304   }
305   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
306     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
307   }
308 }
309
310 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
311   SSL_CTX_set_cert_store(ctx_, store);
312 }
313
314 void SSLContext::loadClientCAList(const char* path) {
315   auto clientCAs = SSL_load_client_CA_file(path);
316   if (clientCAs == nullptr) {
317     LOG(ERROR) << "Unable to load ca file: " << path;
318     return;
319   }
320   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
321 }
322
323 void SSLContext::randomize() {
324   RAND_poll();
325 }
326
327 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
328   if (collector == nullptr) {
329     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
330     return;
331   }
332   collector_ = collector;
333   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
334   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
335 }
336
337 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
338
339 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
340   serverNameCb_ = cb;
341 }
342
343 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
344   clientHelloCbs_.push_back(cb);
345 }
346
347 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
348   SSLContext* context = (SSLContext*)data;
349
350   if (context == nullptr) {
351     return SSL_TLSEXT_ERR_NOACK;
352   }
353
354   for (auto& cb : context->clientHelloCbs_) {
355     // Generic callbacks to happen after we receive the Client Hello.
356     // For example, we use one to switch which cipher we use depending
357     // on the user's TLS version.  Because the primary purpose of
358     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
359     // are side-uses, we ignore any possible failures other than just logging
360     // them.
361     cb(ssl);
362   }
363
364   if (!context->serverNameCb_) {
365     return SSL_TLSEXT_ERR_NOACK;
366   }
367
368   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
369   switch (ret) {
370     case SERVER_NAME_FOUND:
371       return SSL_TLSEXT_ERR_OK;
372     case SERVER_NAME_NOT_FOUND:
373       return SSL_TLSEXT_ERR_NOACK;
374     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
375       *al = TLS1_AD_UNRECOGNIZED_NAME;
376       return SSL_TLSEXT_ERR_ALERT_FATAL;
377     default:
378       CHECK(false);
379   }
380
381   return SSL_TLSEXT_ERR_NOACK;
382 }
383
384 void SSLContext::switchCiphersIfTLS11(
385     SSL* ssl,
386     const std::string& tls11CipherString,
387     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
388   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
389       << "Shouldn't call if empty ciphers / alt ciphers";
390
391   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
392     // We only do this for TLS v 1.1 and later
393     return;
394   }
395
396   const std::string* ciphers = &tls11CipherString;
397   if (!tls11AltCipherlist.empty()) {
398     if (!cipherListPicker_) {
399       std::vector<int> weights;
400       std::for_each(
401           tls11AltCipherlist.begin(),
402           tls11AltCipherlist.end(),
403           [&](const std::pair<std::string, int>& e) {
404             weights.push_back(e.second);
405           });
406       cipherListPicker_.reset(
407           new std::discrete_distribution<int>(weights.begin(), weights.end()));
408     }
409     auto rng = ThreadLocalPRNG();
410     auto index = (*cipherListPicker_)(rng);
411     if ((size_t)index >= tls11AltCipherlist.size()) {
412       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
413                  << ", but tls11AltCipherlist is of length "
414                  << tls11AltCipherlist.size();
415     } else {
416       ciphers = &tls11AltCipherlist[index].first;
417     }
418   }
419
420   // Prefer AES for TLS versions 1.1 and later since these are not
421   // vulnerable to BEAST attacks on AES.  Note that we're setting the
422   // cipher list on the SSL object, not the SSL_CTX object, so it will
423   // only last for this request.
424   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
425   if ((rc == 0) || ERR_peek_error() != 0) {
426     // This shouldn't happen since we checked for this when proxygen
427     // started up.
428     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
429     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
430   }
431 }
432 #endif
433
434 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
435 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
436                                    const unsigned char** out,
437                                    unsigned char* outlen,
438                                    const unsigned char* in,
439                                    unsigned int inlen,
440                                    void* data) {
441   SSLContext* context = (SSLContext*)data;
442   CHECK(context);
443   if (context->advertisedNextProtocols_.empty()) {
444     *out = nullptr;
445     *outlen = 0;
446   } else {
447     auto i = context->pickNextProtocols();
448     const auto& item = context->advertisedNextProtocols_[i];
449     if (SSL_select_next_proto((unsigned char**)out,
450                               outlen,
451                               item.protocols,
452                               item.length,
453                               in,
454                               inlen) != OPENSSL_NPN_NEGOTIATED) {
455       return SSL_TLSEXT_ERR_NOACK;
456     }
457   }
458   return SSL_TLSEXT_ERR_OK;
459 }
460 #endif
461
462 #ifdef OPENSSL_NPN_NEGOTIATED
463
464 bool SSLContext::setAdvertisedNextProtocols(
465     const std::list<std::string>& protocols, NextProtocolType protocolType) {
466   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
467 }
468
469 bool SSLContext::setRandomizedAdvertisedNextProtocols(
470     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
471   unsetNextProtocols();
472   if (items.size() == 0) {
473     return false;
474   }
475   int total_weight = 0;
476   for (const auto &item : items) {
477     if (item.protocols.size() == 0) {
478       continue;
479     }
480     AdvertisedNextProtocolsItem advertised_item;
481     advertised_item.length = 0;
482     for (const auto& proto : item.protocols) {
483       ++advertised_item.length;
484       unsigned protoLength = proto.length();
485       if (protoLength >= 256) {
486         deleteNextProtocolsStrings();
487         return false;
488       }
489       advertised_item.length += protoLength;
490     }
491     advertised_item.protocols = new unsigned char[advertised_item.length];
492     if (!advertised_item.protocols) {
493       throw std::runtime_error("alloc failure");
494     }
495     unsigned char* dst = advertised_item.protocols;
496     for (auto& proto : item.protocols) {
497       unsigned protoLength = proto.length();
498       *dst++ = (unsigned char)protoLength;
499       memcpy(dst, proto.data(), protoLength);
500       dst += protoLength;
501     }
502     total_weight += item.weight;
503     advertisedNextProtocols_.push_back(advertised_item);
504     advertisedNextProtocolWeights_.push_back(item.weight);
505   }
506   if (total_weight == 0) {
507     deleteNextProtocolsStrings();
508     return false;
509   }
510   nextProtocolDistribution_ =
511       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
512                                    advertisedNextProtocolWeights_.end());
513   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
514     SSL_CTX_set_next_protos_advertised_cb(
515         ctx_, advertisedNextProtocolCallback, this);
516     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
517   }
518 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
519   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
520     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
521     // Client cannot really use randomized alpn
522     SSL_CTX_set_alpn_protos(ctx_,
523                             advertisedNextProtocols_[0].protocols,
524                             advertisedNextProtocols_[0].length);
525   }
526 #endif
527   return true;
528 }
529
530 void SSLContext::deleteNextProtocolsStrings() {
531   for (auto protocols : advertisedNextProtocols_) {
532     delete[] protocols.protocols;
533   }
534   advertisedNextProtocols_.clear();
535   advertisedNextProtocolWeights_.clear();
536 }
537
538 void SSLContext::unsetNextProtocols() {
539   deleteNextProtocolsStrings();
540   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
541   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
542 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
543   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
544   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
545 #endif
546 }
547
548 size_t SSLContext::pickNextProtocols() {
549   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
550   auto rng = ThreadLocalPRNG();
551   return nextProtocolDistribution_(rng);
552 }
553
554 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
555       const unsigned char** out, unsigned int* outlen, void* data) {
556   SSLContext* context = (SSLContext*)data;
557   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
558     *out = nullptr;
559     *outlen = 0;
560   } else if (context->advertisedNextProtocols_.size() == 1) {
561     *out = context->advertisedNextProtocols_[0].protocols;
562     *outlen = context->advertisedNextProtocols_[0].length;
563   } else {
564     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
565           sNextProtocolsExDataIndex_));
566     if (selected_index) {
567       --selected_index;
568       *out = context->advertisedNextProtocols_[selected_index].protocols;
569       *outlen = context->advertisedNextProtocols_[selected_index].length;
570     } else {
571       auto i = context->pickNextProtocols();
572       uintptr_t selected = i + 1;
573       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
574       *out = context->advertisedNextProtocols_[i].protocols;
575       *outlen = context->advertisedNextProtocols_[i].length;
576     }
577   }
578   return SSL_TLSEXT_ERR_OK;
579 }
580
581 int SSLContext::selectNextProtocolCallback(SSL* ssl,
582                                            unsigned char** out,
583                                            unsigned char* outlen,
584                                            const unsigned char* server,
585                                            unsigned int server_len,
586                                            void* data) {
587   (void)ssl; // Make -Wunused-parameters happy
588   SSLContext* ctx = (SSLContext*)data;
589   if (ctx->advertisedNextProtocols_.size() > 1) {
590     VLOG(3) << "SSLContext::selectNextProcolCallback() "
591             << "client should be deterministic in selecting protocols.";
592   }
593
594   unsigned char *client;
595   unsigned int client_len;
596   bool filtered = false;
597   auto cpf = ctx->getClientProtocolFilterCallback();
598   if (cpf) {
599     filtered = (*cpf)(&client, &client_len, server, server_len);
600   }
601
602   if (!filtered) {
603     if (ctx->advertisedNextProtocols_.empty()) {
604       client = (unsigned char *) "";
605       client_len = 0;
606     } else {
607       client = ctx->advertisedNextProtocols_[0].protocols;
608       client_len = ctx->advertisedNextProtocols_[0].length;
609     }
610   }
611
612   int retval = SSL_select_next_proto(out, outlen, server, server_len,
613                                      client, client_len);
614   if (retval != OPENSSL_NPN_NEGOTIATED) {
615     VLOG(3) << "SSLContext::selectNextProcolCallback() "
616             << "unable to pick a next protocol.";
617   }
618   return SSL_TLSEXT_ERR_OK;
619 }
620 #endif // OPENSSL_NPN_NEGOTIATED
621
622 SSL* SSLContext::createSSL() const {
623   SSL* ssl = SSL_new(ctx_);
624   if (ssl == nullptr) {
625     throw std::runtime_error("SSL_new: " + getErrors());
626   }
627   return ssl;
628 }
629
630 void SSLContext::setSessionCacheContext(const std::string& context) {
631   SSL_CTX_set_session_id_context(
632       ctx_,
633       reinterpret_cast<const unsigned char*>(context.data()),
634       std::min(
635           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
636 }
637
638 /**
639  * Match a name with a pattern. The pattern may include wildcard. A single
640  * wildcard "*" can match up to one component in the domain name.
641  *
642  * @param  host    Host name, typically the name of the remote host
643  * @param  pattern Name retrieved from certificate
644  * @param  size    Size of "pattern"
645  * @return True, if "host" matches "pattern". False otherwise.
646  */
647 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
648   bool match = false;
649   int i = 0, j = 0;
650   while (i < size && host[j] != '\0') {
651     if (toupper(pattern[i]) == toupper(host[j])) {
652       i++;
653       j++;
654       continue;
655     }
656     if (pattern[i] == '*') {
657       while (host[j] != '.' && host[j] != '\0') {
658         j++;
659       }
660       i++;
661       continue;
662     }
663     break;
664   }
665   if (i == size && host[j] == '\0') {
666     match = true;
667   }
668   return match;
669 }
670
671 int SSLContext::passwordCallback(char* password,
672                                  int size,
673                                  int,
674                                  void* data) {
675   SSLContext* context = (SSLContext*)data;
676   if (context == nullptr || context->passwordCollector() == nullptr) {
677     return 0;
678   }
679   std::string userPassword;
680   // call user defined password collector to get password
681   context->passwordCollector()->getPassword(userPassword, size);
682   int length = userPassword.size();
683   if (length > size) {
684     length = size;
685   }
686   strncpy(password, userPassword.c_str(), length);
687   return length;
688 }
689
690 struct SSLLock {
691   explicit SSLLock(
692     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
693       lockType(inLockType) {
694   }
695
696   void lock() {
697     if (lockType == SSLContext::LOCK_MUTEX) {
698       mutex.lock();
699     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
700       spinLock.lock();
701     }
702     // lockType == LOCK_NONE, no-op
703   }
704
705   void unlock() {
706     if (lockType == SSLContext::LOCK_MUTEX) {
707       mutex.unlock();
708     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
709       spinLock.unlock();
710     }
711     // lockType == LOCK_NONE, no-op
712   }
713
714   SSLContext::SSLLockType lockType;
715   folly::SpinLock spinLock{};
716   std::mutex mutex;
717 };
718
719 // Statics are unsafe in environments that call exit().
720 // If one thread calls exit() while another thread is
721 // references a member of SSLContext, bad things can happen.
722 // SSLContext runs in such environments.
723 // Instead of declaring a static member we "new" the static
724 // member so that it won't be destructed on exit().
725 static std::unique_ptr<SSLLock[]>& locks() {
726   static auto locksInst = new std::unique_ptr<SSLLock[]>();
727   return *locksInst;
728 }
729
730 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
731   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
732   return *lockTypesInst;
733 }
734
735 static void callbackLocking(int mode, int n, const char*, int) {
736   if (mode & CRYPTO_LOCK) {
737     locks()[n].lock();
738   } else {
739     locks()[n].unlock();
740   }
741 }
742
743 static unsigned long callbackThreadID() {
744   return static_cast<unsigned long>(
745 #ifdef __APPLE__
746     pthread_mach_thread_np(pthread_self())
747 #elif _MSC_VER
748     pthread_getw32threadid_np(pthread_self())
749 #else
750     pthread_self()
751 #endif
752   );
753 }
754
755 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
756   return new CRYPTO_dynlock_value;
757 }
758
759 static void dyn_lock(int mode,
760                      struct CRYPTO_dynlock_value* lock,
761                      const char*, int) {
762   if (lock != nullptr) {
763     if (mode & CRYPTO_LOCK) {
764       lock->mutex.lock();
765     } else {
766       lock->mutex.unlock();
767     }
768   }
769 }
770
771 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
772   delete lock;
773 }
774
775 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
776   lockTypes() = inLockTypes;
777 }
778
779 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
780 void SSLContext::enableFalseStart() {
781   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
782 }
783 #endif
784
785 void SSLContext::markInitialized() {
786   std::lock_guard<std::mutex> g(initMutex());
787   initialized_ = true;
788 }
789
790 void SSLContext::initializeOpenSSL() {
791   std::lock_guard<std::mutex> g(initMutex());
792   initializeOpenSSLLocked();
793 }
794
795 void SSLContext::initializeOpenSSLLocked() {
796   if (initialized_) {
797     return;
798   }
799   SSL_library_init();
800   SSL_load_error_strings();
801   ERR_load_crypto_strings();
802   // static locking
803   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
804   for (auto it: lockTypes()) {
805     locks()[it.first].lockType = it.second;
806   }
807   CRYPTO_set_id_callback(callbackThreadID);
808   CRYPTO_set_locking_callback(callbackLocking);
809   // dynamic locking
810   CRYPTO_set_dynlock_create_callback(dyn_create);
811   CRYPTO_set_dynlock_lock_callback(dyn_lock);
812   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
813   randomize();
814 #ifdef OPENSSL_NPN_NEGOTIATED
815   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
816       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
817 #endif
818   initialized_ = true;
819 }
820
821 void SSLContext::cleanupOpenSSL() {
822   std::lock_guard<std::mutex> g(initMutex());
823   cleanupOpenSSLLocked();
824 }
825
826 void SSLContext::cleanupOpenSSLLocked() {
827   if (!initialized_) {
828     return;
829   }
830
831   CRYPTO_set_id_callback(nullptr);
832   CRYPTO_set_locking_callback(nullptr);
833   CRYPTO_set_dynlock_create_callback(nullptr);
834   CRYPTO_set_dynlock_lock_callback(nullptr);
835   CRYPTO_set_dynlock_destroy_callback(nullptr);
836   CRYPTO_cleanup_all_ex_data();
837   ERR_free_strings();
838   EVP_cleanup();
839   ERR_remove_state(0);
840   locks().reset();
841   initialized_ = false;
842 }
843
844 void SSLContext::setOptions(long options) {
845   long newOpt = SSL_CTX_set_options(ctx_, options);
846   if ((newOpt & options) != options) {
847     throw std::runtime_error("SSL_CTX_set_options failed");
848   }
849 }
850
851 std::string SSLContext::getErrors(int errnoCopy) {
852   std::string errors;
853   unsigned long  errorCode;
854   char   message[256];
855
856   errors.reserve(512);
857   while ((errorCode = ERR_get_error()) != 0) {
858     if (!errors.empty()) {
859       errors += "; ";
860     }
861     const char* reason = ERR_reason_error_string(errorCode);
862     if (reason == nullptr) {
863       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
864       reason = message;
865     }
866     errors += reason;
867   }
868   if (errors.empty()) {
869     errors = "error code: " + folly::to<std::string>(errnoCopy);
870   }
871   return errors;
872 }
873
874 std::ostream&
875 operator<<(std::ostream& os, const PasswordCollector& collector) {
876   os << collector.describe();
877   return os;
878 }
879
880 } // folly