Enable -Wunreachable-code-return
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/Random.h>
27 #include <folly/SpinLock.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 } // anonymous namespace
49
50 #ifdef OPENSSL_NPN_NEGOTIATED
51 int SSLContext::sNextProtocolsExDataIndex_ = -1;
52 #endif
53
54 // SSLContext implementation
55 SSLContext::SSLContext(SSLVersion version) {
56   {
57     std::lock_guard<std::mutex> g(initMutex());
58     initializeOpenSSLLocked();
59   }
60
61   ctx_ = SSL_CTX_new(SSLv23_method());
62   if (ctx_ == nullptr) {
63     throw std::runtime_error("SSL_CTX_new: " + getErrors());
64   }
65
66   int opt = 0;
67   switch (version) {
68     case TLSv1:
69       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
70       break;
71     case SSLv3:
72       opt = SSL_OP_NO_SSLv2;
73       break;
74     default:
75       // do nothing
76       break;
77   }
78   int newOpt = SSL_CTX_set_options(ctx_, opt);
79   DCHECK((newOpt & opt) == opt);
80
81   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82
83   checkPeerName_ = false;
84
85   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
86
87 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
88   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
89   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
90 #endif
91 }
92
93 SSLContext::~SSLContext() {
94   if (ctx_ != nullptr) {
95     SSL_CTX_free(ctx_);
96     ctx_ = nullptr;
97   }
98
99 #ifdef OPENSSL_NPN_NEGOTIATED
100   deleteNextProtocolsStrings();
101 #endif
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   providedCiphersString_ = ciphers;
106   setCiphersOrThrow(ciphers);
107 }
108
109 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
110   if (ciphers.size() == 0) {
111     return;
112   }
113   std::string opensslCipherList;
114   join(":", ciphers, opensslCipherList);
115   setCiphersOrThrow(opensslCipherList);
116 }
117
118 void SSLContext::setSignatureAlgorithms(
119     const std::vector<std::string>& sigalgs) {
120   if (sigalgs.size() == 0) {
121     return;
122   }
123 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
124   std::string opensslSigAlgsList;
125   join(":", sigalgs, opensslSigAlgsList);
126   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
127   if (rc == 0) {
128     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
129   }
130 #endif
131 }
132
133 void SSLContext::setClientECCurvesList(
134     const std::vector<std::string>& ecCurves) {
135   if (ecCurves.size() == 0) {
136     return;
137   }
138 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
139   std::string ecCurvesList;
140   join(":", ecCurves, ecCurvesList);
141   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
142   if (rc == 0) {
143     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
144   }
145 #endif
146 }
147
148 void SSLContext::setServerECCurve(const std::string& curveName) {
149   bool validCall = false;
150 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL
151 #ifndef OPENSSL_NO_ECDH
152   validCall = true;
153 #endif
154 #endif
155   if (!validCall) {
156     throw std::runtime_error("Elliptic curve encryption not allowed");
157   }
158
159   EC_KEY* ecdh = nullptr;
160   int nid;
161
162   /*
163    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
164    * from RFC 4492 section 5.1.1, or explicitly described curves over
165    * binary fields. OpenSSL only supports the "named curves", which provide
166    * maximum interoperability.
167    */
168
169   nid = OBJ_sn2nid(curveName.c_str());
170   if (nid == 0) {
171     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
172   }
173   ecdh = EC_KEY_new_by_curve_name(nid);
174   if (ecdh == nullptr) {
175     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
176   }
177
178   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
179   EC_KEY_free(ecdh);
180 }
181
182 void SSLContext::setX509VerifyParam(
183     const ssl::X509VerifyParam& x509VerifyParam) {
184   if (!x509VerifyParam) {
185     return;
186   }
187   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
188     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
189   }
190 }
191
192 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
193   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
194   if (rc == 0) {
195     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
196   }
197 }
198
199 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
200     verifyPeer) {
201   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
202   verifyPeer_ = verifyPeer;
203 }
204
205 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
206     verifyPeer) {
207   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
208   int mode = SSL_VERIFY_NONE;
209   switch(verifyPeer) {
210     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
211     // break;
212
213     case SSLVerifyPeerEnum::VERIFY:
214       mode = SSL_VERIFY_PEER;
215       break;
216
217     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
218       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
219       break;
220
221     case SSLVerifyPeerEnum::NO_VERIFY:
222       mode = SSL_VERIFY_NONE;
223       break;
224
225     default:
226       break;
227   }
228   return mode;
229 }
230
231 int SSLContext::getVerificationMode() {
232   return getVerificationMode(verifyPeer_);
233 }
234
235 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
236                               const std::string& peerName) {
237   int mode;
238   if (checkPeerCert) {
239     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
240     checkPeerName_ = checkPeerName;
241     peerFixedName_ = peerName;
242   } else {
243     mode = SSL_VERIFY_NONE;
244     checkPeerName_ = false; // can't check name without cert!
245     peerFixedName_.clear();
246   }
247   SSL_CTX_set_verify(ctx_, mode, nullptr);
248 }
249
250 void SSLContext::loadCertificate(const char* path, const char* format) {
251   if (path == nullptr || format == nullptr) {
252     throw std::invalid_argument(
253          "loadCertificateChain: either <path> or <format> is nullptr");
254   }
255   if (strcmp(format, "PEM") == 0) {
256     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
257       int errnoCopy = errno;
258       std::string reason("SSL_CTX_use_certificate_chain_file: ");
259       reason.append(path);
260       reason.append(": ");
261       reason.append(getErrors(errnoCopy));
262       throw std::runtime_error(reason);
263     }
264   } else {
265     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
266   }
267 }
268
269 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
270   if (cert.data() == nullptr) {
271     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
272   }
273
274   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
275   if (bio == nullptr) {
276     throw std::runtime_error("BIO_new: " + getErrors());
277   }
278
279   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
280   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
281     throw std::runtime_error("BIO_write: " + getErrors());
282   }
283
284   ssl::X509UniquePtr x509(
285       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
286   if (x509 == nullptr) {
287     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
288   }
289
290   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
291     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
292   }
293 }
294
295 void SSLContext::loadPrivateKey(const char* path, const char* format) {
296   if (path == nullptr || format == nullptr) {
297     throw std::invalid_argument(
298         "loadPrivateKey: either <path> or <format> is nullptr");
299   }
300   if (strcmp(format, "PEM") == 0) {
301     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
302       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
303     }
304   } else {
305     throw std::runtime_error("Unsupported private key format: " + std::string(format));
306   }
307 }
308
309 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
310   if (pkey.data() == nullptr) {
311     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
312   }
313
314   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
315   if (bio == nullptr) {
316     throw std::runtime_error("BIO_new: " + getErrors());
317   }
318
319   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
320   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
321     throw std::runtime_error("BIO_write: " + getErrors());
322   }
323
324   ssl::EvpPkeyUniquePtr key(
325       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
326   if (key == nullptr) {
327     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
328   }
329
330   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
331     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
332   }
333 }
334
335 void SSLContext::loadTrustedCertificates(const char* path) {
336   if (path == nullptr) {
337     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
338   }
339   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
340     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
341   }
342 }
343
344 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
345   SSL_CTX_set_cert_store(ctx_, store);
346 }
347
348 void SSLContext::loadClientCAList(const char* path) {
349   auto clientCAs = SSL_load_client_CA_file(path);
350   if (clientCAs == nullptr) {
351     LOG(ERROR) << "Unable to load ca file: " << path;
352     return;
353   }
354   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
355 }
356
357 void SSLContext::randomize() {
358   RAND_poll();
359 }
360
361 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
362   if (collector == nullptr) {
363     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
364     return;
365   }
366   collector_ = collector;
367   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
368   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
369 }
370
371 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
372
373 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
374   serverNameCb_ = cb;
375 }
376
377 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
378   clientHelloCbs_.push_back(cb);
379 }
380
381 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
382   SSLContext* context = (SSLContext*)data;
383
384   if (context == nullptr) {
385     return SSL_TLSEXT_ERR_NOACK;
386   }
387
388   for (auto& cb : context->clientHelloCbs_) {
389     // Generic callbacks to happen after we receive the Client Hello.
390     // For example, we use one to switch which cipher we use depending
391     // on the user's TLS version.  Because the primary purpose of
392     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
393     // are side-uses, we ignore any possible failures other than just logging
394     // them.
395     cb(ssl);
396   }
397
398   if (!context->serverNameCb_) {
399     return SSL_TLSEXT_ERR_NOACK;
400   }
401
402   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
403   switch (ret) {
404     case SERVER_NAME_FOUND:
405       return SSL_TLSEXT_ERR_OK;
406     case SERVER_NAME_NOT_FOUND:
407       return SSL_TLSEXT_ERR_NOACK;
408     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
409       *al = TLS1_AD_UNRECOGNIZED_NAME;
410       return SSL_TLSEXT_ERR_ALERT_FATAL;
411     default:
412       CHECK(false);
413   }
414
415   return SSL_TLSEXT_ERR_NOACK;
416 }
417
418 void SSLContext::switchCiphersIfTLS11(
419     SSL* ssl,
420     const std::string& tls11CipherString,
421     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
422   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
423       << "Shouldn't call if empty ciphers / alt ciphers";
424
425   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
426     // We only do this for TLS v 1.1 and later
427     return;
428   }
429
430   const std::string* ciphers = &tls11CipherString;
431   if (!tls11AltCipherlist.empty()) {
432     if (!cipherListPicker_) {
433       std::vector<int> weights;
434       std::for_each(
435           tls11AltCipherlist.begin(),
436           tls11AltCipherlist.end(),
437           [&](const std::pair<std::string, int>& e) {
438             weights.push_back(e.second);
439           });
440       cipherListPicker_.reset(
441           new std::discrete_distribution<int>(weights.begin(), weights.end()));
442     }
443     auto rng = ThreadLocalPRNG();
444     auto index = (*cipherListPicker_)(rng);
445     if ((size_t)index >= tls11AltCipherlist.size()) {
446       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
447                  << ", but tls11AltCipherlist is of length "
448                  << tls11AltCipherlist.size();
449     } else {
450       ciphers = &tls11AltCipherlist[index].first;
451     }
452   }
453
454   // Prefer AES for TLS versions 1.1 and later since these are not
455   // vulnerable to BEAST attacks on AES.  Note that we're setting the
456   // cipher list on the SSL object, not the SSL_CTX object, so it will
457   // only last for this request.
458   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
459   if ((rc == 0) || ERR_peek_error() != 0) {
460     // This shouldn't happen since we checked for this when proxygen
461     // started up.
462     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
463     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
464   }
465 }
466 #endif
467
468 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
469 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
470                                    const unsigned char** out,
471                                    unsigned char* outlen,
472                                    const unsigned char* in,
473                                    unsigned int inlen,
474                                    void* data) {
475   SSLContext* context = (SSLContext*)data;
476   CHECK(context);
477   if (context->advertisedNextProtocols_.empty()) {
478     *out = nullptr;
479     *outlen = 0;
480   } else {
481     auto i = context->pickNextProtocols();
482     const auto& item = context->advertisedNextProtocols_[i];
483     if (SSL_select_next_proto((unsigned char**)out,
484                               outlen,
485                               item.protocols,
486                               item.length,
487                               in,
488                               inlen) != OPENSSL_NPN_NEGOTIATED) {
489       return SSL_TLSEXT_ERR_NOACK;
490     }
491   }
492   return SSL_TLSEXT_ERR_OK;
493 }
494 #endif
495
496 #ifdef OPENSSL_NPN_NEGOTIATED
497
498 bool SSLContext::setAdvertisedNextProtocols(
499     const std::list<std::string>& protocols, NextProtocolType protocolType) {
500   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
501 }
502
503 bool SSLContext::setRandomizedAdvertisedNextProtocols(
504     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
505   unsetNextProtocols();
506   if (items.size() == 0) {
507     return false;
508   }
509   int total_weight = 0;
510   for (const auto &item : items) {
511     if (item.protocols.size() == 0) {
512       continue;
513     }
514     AdvertisedNextProtocolsItem advertised_item;
515     advertised_item.length = 0;
516     for (const auto& proto : item.protocols) {
517       ++advertised_item.length;
518       auto protoLength = proto.length();
519       if (protoLength >= 256) {
520         deleteNextProtocolsStrings();
521         return false;
522       }
523       advertised_item.length += unsigned(protoLength);
524     }
525     advertised_item.protocols = new unsigned char[advertised_item.length];
526     if (!advertised_item.protocols) {
527       throw std::runtime_error("alloc failure");
528     }
529     unsigned char* dst = advertised_item.protocols;
530     for (auto& proto : item.protocols) {
531       uint8_t protoLength = uint8_t(proto.length());
532       *dst++ = (unsigned char)protoLength;
533       memcpy(dst, proto.data(), protoLength);
534       dst += protoLength;
535     }
536     total_weight += item.weight;
537     advertisedNextProtocols_.push_back(advertised_item);
538     advertisedNextProtocolWeights_.push_back(item.weight);
539   }
540   if (total_weight == 0) {
541     deleteNextProtocolsStrings();
542     return false;
543   }
544   nextProtocolDistribution_ =
545       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
546                                    advertisedNextProtocolWeights_.end());
547   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
548     SSL_CTX_set_next_protos_advertised_cb(
549         ctx_, advertisedNextProtocolCallback, this);
550     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
551   }
552 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
553   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
554     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
555     // Client cannot really use randomized alpn
556     SSL_CTX_set_alpn_protos(ctx_,
557                             advertisedNextProtocols_[0].protocols,
558                             advertisedNextProtocols_[0].length);
559   }
560 #endif
561   return true;
562 }
563
564 void SSLContext::deleteNextProtocolsStrings() {
565   for (auto protocols : advertisedNextProtocols_) {
566     delete[] protocols.protocols;
567   }
568   advertisedNextProtocols_.clear();
569   advertisedNextProtocolWeights_.clear();
570 }
571
572 void SSLContext::unsetNextProtocols() {
573   deleteNextProtocolsStrings();
574   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
575   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
576 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
577   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
578   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
579 #endif
580 }
581
582 size_t SSLContext::pickNextProtocols() {
583   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
584   auto rng = ThreadLocalPRNG();
585   return nextProtocolDistribution_(rng);
586 }
587
588 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
589       const unsigned char** out, unsigned int* outlen, void* data) {
590   SSLContext* context = (SSLContext*)data;
591   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
592     *out = nullptr;
593     *outlen = 0;
594   } else if (context->advertisedNextProtocols_.size() == 1) {
595     *out = context->advertisedNextProtocols_[0].protocols;
596     *outlen = context->advertisedNextProtocols_[0].length;
597   } else {
598     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
599           sNextProtocolsExDataIndex_));
600     if (selected_index) {
601       --selected_index;
602       *out = context->advertisedNextProtocols_[selected_index].protocols;
603       *outlen = context->advertisedNextProtocols_[selected_index].length;
604     } else {
605       auto i = context->pickNextProtocols();
606       uintptr_t selected = i + 1;
607       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
608       *out = context->advertisedNextProtocols_[i].protocols;
609       *outlen = context->advertisedNextProtocols_[i].length;
610     }
611   }
612   return SSL_TLSEXT_ERR_OK;
613 }
614
615 int SSLContext::selectNextProtocolCallback(SSL* ssl,
616                                            unsigned char** out,
617                                            unsigned char* outlen,
618                                            const unsigned char* server,
619                                            unsigned int server_len,
620                                            void* data) {
621   (void)ssl; // Make -Wunused-parameters happy
622   SSLContext* ctx = (SSLContext*)data;
623   if (ctx->advertisedNextProtocols_.size() > 1) {
624     VLOG(3) << "SSLContext::selectNextProcolCallback() "
625             << "client should be deterministic in selecting protocols.";
626   }
627
628   unsigned char* client = nullptr;
629   unsigned int client_len = 0;
630   bool filtered = false;
631   auto cpf = ctx->getClientProtocolFilterCallback();
632   if (cpf) {
633     filtered = (*cpf)(&client, &client_len, server, server_len);
634   }
635
636   if (!filtered) {
637     if (ctx->advertisedNextProtocols_.empty()) {
638       client = (unsigned char *) "";
639       client_len = 0;
640     } else {
641       client = ctx->advertisedNextProtocols_[0].protocols;
642       client_len = ctx->advertisedNextProtocols_[0].length;
643     }
644   }
645
646   int retval = SSL_select_next_proto(out, outlen, server, server_len,
647                                      client, client_len);
648   if (retval != OPENSSL_NPN_NEGOTIATED) {
649     VLOG(3) << "SSLContext::selectNextProcolCallback() "
650             << "unable to pick a next protocol.";
651   }
652   return SSL_TLSEXT_ERR_OK;
653 }
654 #endif // OPENSSL_NPN_NEGOTIATED
655
656 SSL* SSLContext::createSSL() const {
657   SSL* ssl = SSL_new(ctx_);
658   if (ssl == nullptr) {
659     throw std::runtime_error("SSL_new: " + getErrors());
660   }
661   return ssl;
662 }
663
664 void SSLContext::setSessionCacheContext(const std::string& context) {
665   SSL_CTX_set_session_id_context(
666       ctx_,
667       reinterpret_cast<const unsigned char*>(context.data()),
668       std::min(
669           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
670 }
671
672 /**
673  * Match a name with a pattern. The pattern may include wildcard. A single
674  * wildcard "*" can match up to one component in the domain name.
675  *
676  * @param  host    Host name, typically the name of the remote host
677  * @param  pattern Name retrieved from certificate
678  * @param  size    Size of "pattern"
679  * @return True, if "host" matches "pattern". False otherwise.
680  */
681 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
682   bool match = false;
683   int i = 0, j = 0;
684   while (i < size && host[j] != '\0') {
685     if (toupper(pattern[i]) == toupper(host[j])) {
686       i++;
687       j++;
688       continue;
689     }
690     if (pattern[i] == '*') {
691       while (host[j] != '.' && host[j] != '\0') {
692         j++;
693       }
694       i++;
695       continue;
696     }
697     break;
698   }
699   if (i == size && host[j] == '\0') {
700     match = true;
701   }
702   return match;
703 }
704
705 int SSLContext::passwordCallback(char* password,
706                                  int size,
707                                  int,
708                                  void* data) {
709   SSLContext* context = (SSLContext*)data;
710   if (context == nullptr || context->passwordCollector() == nullptr) {
711     return 0;
712   }
713   std::string userPassword;
714   // call user defined password collector to get password
715   context->passwordCollector()->getPassword(userPassword, size);
716   auto length = int(userPassword.size());
717   if (length > size) {
718     length = size;
719   }
720   strncpy(password, userPassword.c_str(), length);
721   return length;
722 }
723
724 struct SSLLock {
725   explicit SSLLock(
726     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
727       lockType(inLockType) {
728   }
729
730   void lock() {
731     if (lockType == SSLContext::LOCK_MUTEX) {
732       mutex.lock();
733     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
734       spinLock.lock();
735     }
736     // lockType == LOCK_NONE, no-op
737   }
738
739   void unlock() {
740     if (lockType == SSLContext::LOCK_MUTEX) {
741       mutex.unlock();
742     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
743       spinLock.unlock();
744     }
745     // lockType == LOCK_NONE, no-op
746   }
747
748   SSLContext::SSLLockType lockType;
749   folly::SpinLock spinLock{};
750   std::mutex mutex;
751 };
752
753 // Statics are unsafe in environments that call exit().
754 // If one thread calls exit() while another thread is
755 // references a member of SSLContext, bad things can happen.
756 // SSLContext runs in such environments.
757 // Instead of declaring a static member we "new" the static
758 // member so that it won't be destructed on exit().
759 static std::unique_ptr<SSLLock[]>& locks() {
760   static auto locksInst = new std::unique_ptr<SSLLock[]>();
761   return *locksInst;
762 }
763
764 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
765   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
766   return *lockTypesInst;
767 }
768
769 static void callbackLocking(int mode, int n, const char*, int) {
770   if (mode & CRYPTO_LOCK) {
771     locks()[n].lock();
772   } else {
773     locks()[n].unlock();
774   }
775 }
776
777 static unsigned long callbackThreadID() {
778   return static_cast<unsigned long>(
779 #ifdef __APPLE__
780     pthread_mach_thread_np(pthread_self())
781 #elif _MSC_VER
782     pthread_getw32threadid_np(pthread_self())
783 #else
784     pthread_self()
785 #endif
786   );
787 }
788
789 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
790   return new CRYPTO_dynlock_value;
791 }
792
793 static void dyn_lock(int mode,
794                      struct CRYPTO_dynlock_value* lock,
795                      const char*, int) {
796   if (lock != nullptr) {
797     if (mode & CRYPTO_LOCK) {
798       lock->mutex.lock();
799     } else {
800       lock->mutex.unlock();
801     }
802   }
803 }
804
805 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
806   delete lock;
807 }
808
809 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
810   lockTypes() = inLockTypes;
811 }
812
813 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
814 void SSLContext::enableFalseStart() {
815   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
816 }
817 #endif
818
819 void SSLContext::markInitialized() {
820   std::lock_guard<std::mutex> g(initMutex());
821   initialized_ = true;
822 }
823
824 void SSLContext::initializeOpenSSL() {
825   std::lock_guard<std::mutex> g(initMutex());
826   initializeOpenSSLLocked();
827 }
828
829 void SSLContext::initializeOpenSSLLocked() {
830   if (initialized_) {
831     return;
832   }
833   SSL_library_init();
834   SSL_load_error_strings();
835   ERR_load_crypto_strings();
836   // static locking
837   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
838   for (auto it: lockTypes()) {
839     locks()[it.first].lockType = it.second;
840   }
841   CRYPTO_set_id_callback(callbackThreadID);
842   CRYPTO_set_locking_callback(callbackLocking);
843   // dynamic locking
844   CRYPTO_set_dynlock_create_callback(dyn_create);
845   CRYPTO_set_dynlock_lock_callback(dyn_lock);
846   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
847   randomize();
848 #ifdef OPENSSL_NPN_NEGOTIATED
849   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
850       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
851 #endif
852   initialized_ = true;
853 }
854
855 void SSLContext::cleanupOpenSSL() {
856   std::lock_guard<std::mutex> g(initMutex());
857   cleanupOpenSSLLocked();
858 }
859
860 void SSLContext::cleanupOpenSSLLocked() {
861   if (!initialized_) {
862     return;
863   }
864
865   CRYPTO_set_id_callback(nullptr);
866   CRYPTO_set_locking_callback(nullptr);
867   CRYPTO_set_dynlock_create_callback(nullptr);
868   CRYPTO_set_dynlock_lock_callback(nullptr);
869   CRYPTO_set_dynlock_destroy_callback(nullptr);
870   CRYPTO_cleanup_all_ex_data();
871   ERR_free_strings();
872   EVP_cleanup();
873   ERR_remove_state(0);
874   locks().reset();
875   initialized_ = false;
876 }
877
878 void SSLContext::setOptions(long options) {
879   long newOpt = SSL_CTX_set_options(ctx_, options);
880   if ((newOpt & options) != options) {
881     throw std::runtime_error("SSL_CTX_set_options failed");
882   }
883 }
884
885 std::string SSLContext::getErrors(int errnoCopy) {
886   std::string errors;
887   unsigned long  errorCode;
888   char   message[256];
889
890   errors.reserve(512);
891   while ((errorCode = ERR_get_error()) != 0) {
892     if (!errors.empty()) {
893       errors += "; ";
894     }
895     const char* reason = ERR_reason_error_string(errorCode);
896     if (reason == nullptr) {
897       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
898       reason = message;
899     }
900     errors += reason;
901   }
902   if (errors.empty()) {
903     errors = "error code: " + folly::to<std::string>(errnoCopy);
904   }
905   return errors;
906 }
907
908 std::ostream&
909 operator<<(std::ostream& os, const PasswordCollector& collector) {
910   os << collector.describe();
911   return os;
912 }
913
914 } // folly