fix flaky ConnectTFOTimeout and ConnectTFOFallbackTimeout tests
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/Random.h>
27 #include <folly/SpinLock.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 } // anonymous namespace
49
50 #ifdef OPENSSL_NPN_NEGOTIATED
51 int SSLContext::sNextProtocolsExDataIndex_ = -1;
52 #endif
53
54 // SSLContext implementation
55 SSLContext::SSLContext(SSLVersion version) {
56   {
57     std::lock_guard<std::mutex> g(initMutex());
58     initializeOpenSSLLocked();
59   }
60
61   ctx_ = SSL_CTX_new(SSLv23_method());
62   if (ctx_ == nullptr) {
63     throw std::runtime_error("SSL_CTX_new: " + getErrors());
64   }
65
66   int opt = 0;
67   switch (version) {
68     case TLSv1:
69       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
70       break;
71     case SSLv3:
72       opt = SSL_OP_NO_SSLv2;
73       break;
74     default:
75       // do nothing
76       break;
77   }
78   int newOpt = SSL_CTX_set_options(ctx_, opt);
79   DCHECK((newOpt & opt) == opt);
80
81   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82
83   checkPeerName_ = false;
84
85   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
86
87 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
88   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
89   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
90 #endif
91 }
92
93 SSLContext::~SSLContext() {
94   if (ctx_ != nullptr) {
95     SSL_CTX_free(ctx_);
96     ctx_ = nullptr;
97   }
98
99 #ifdef OPENSSL_NPN_NEGOTIATED
100   deleteNextProtocolsStrings();
101 #endif
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   providedCiphersString_ = ciphers;
106   setCiphersOrThrow(ciphers);
107 }
108
109 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
110   if (ciphers.size() == 0) {
111     return;
112   }
113   std::string opensslCipherList;
114   join(":", ciphers, opensslCipherList);
115   setCiphersOrThrow(opensslCipherList);
116 }
117
118 void SSLContext::setSignatureAlgorithms(
119     const std::vector<std::string>& sigalgs) {
120   if (sigalgs.size() == 0) {
121     return;
122   }
123 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
124   std::string opensslSigAlgsList;
125   join(":", sigalgs, opensslSigAlgsList);
126   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
127   if (rc == 0) {
128     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
129   }
130 #endif
131 }
132
133 void SSLContext::setClientECCurvesList(
134     const std::vector<std::string>& ecCurves) {
135   if (ecCurves.size() == 0) {
136     return;
137   }
138 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
139   std::string ecCurvesList;
140   join(":", ecCurves, ecCurvesList);
141   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
142   if (rc == 0) {
143     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
144   }
145 #endif
146 }
147
148 void SSLContext::setServerECCurve(const std::string& curveName) {
149   bool validCall = false;
150 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL
151 #ifndef OPENSSL_NO_ECDH
152   validCall = true;
153 #endif
154 #endif
155   if (!validCall) {
156     throw std::runtime_error("Elliptic curve encryption not allowed");
157   }
158
159   EC_KEY* ecdh = nullptr;
160   int nid;
161
162   /*
163    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
164    * from RFC 4492 section 5.1.1, or explicitly described curves over
165    * binary fields. OpenSSL only supports the "named curves", which provide
166    * maximum interoperability.
167    */
168
169   nid = OBJ_sn2nid(curveName.c_str());
170   if (nid == 0) {
171     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
172     return;
173   }
174   ecdh = EC_KEY_new_by_curve_name(nid);
175   if (ecdh == nullptr) {
176     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
177     return;
178   }
179
180   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
181   EC_KEY_free(ecdh);
182 }
183
184 void SSLContext::setX509VerifyParam(
185     const ssl::X509VerifyParam& x509VerifyParam) {
186   if (!x509VerifyParam) {
187     return;
188   }
189   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
190     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
191   }
192 }
193
194 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
195   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
196   if (rc == 0) {
197     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
198   }
199 }
200
201 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
202     verifyPeer) {
203   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
204   verifyPeer_ = verifyPeer;
205 }
206
207 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
208     verifyPeer) {
209   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
210   int mode = SSL_VERIFY_NONE;
211   switch(verifyPeer) {
212     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
213     // break;
214
215     case SSLVerifyPeerEnum::VERIFY:
216       mode = SSL_VERIFY_PEER;
217       break;
218
219     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
220       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
221       break;
222
223     case SSLVerifyPeerEnum::NO_VERIFY:
224       mode = SSL_VERIFY_NONE;
225       break;
226
227     default:
228       break;
229   }
230   return mode;
231 }
232
233 int SSLContext::getVerificationMode() {
234   return getVerificationMode(verifyPeer_);
235 }
236
237 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
238                               const std::string& peerName) {
239   int mode;
240   if (checkPeerCert) {
241     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
242     checkPeerName_ = checkPeerName;
243     peerFixedName_ = peerName;
244   } else {
245     mode = SSL_VERIFY_NONE;
246     checkPeerName_ = false; // can't check name without cert!
247     peerFixedName_.clear();
248   }
249   SSL_CTX_set_verify(ctx_, mode, nullptr);
250 }
251
252 void SSLContext::loadCertificate(const char* path, const char* format) {
253   if (path == nullptr || format == nullptr) {
254     throw std::invalid_argument(
255          "loadCertificateChain: either <path> or <format> is nullptr");
256   }
257   if (strcmp(format, "PEM") == 0) {
258     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
259       int errnoCopy = errno;
260       std::string reason("SSL_CTX_use_certificate_chain_file: ");
261       reason.append(path);
262       reason.append(": ");
263       reason.append(getErrors(errnoCopy));
264       throw std::runtime_error(reason);
265     }
266   } else {
267     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
268   }
269 }
270
271 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
272   if (cert.data() == nullptr) {
273     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
274   }
275
276   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
277   if (bio == nullptr) {
278     throw std::runtime_error("BIO_new: " + getErrors());
279   }
280
281   int written = BIO_write(bio.get(), cert.data(), cert.size());
282   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
283     throw std::runtime_error("BIO_write: " + getErrors());
284   }
285
286   ssl::X509UniquePtr x509(
287       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
288   if (x509 == nullptr) {
289     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
290   }
291
292   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
293     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
294   }
295 }
296
297 void SSLContext::loadPrivateKey(const char* path, const char* format) {
298   if (path == nullptr || format == nullptr) {
299     throw std::invalid_argument(
300         "loadPrivateKey: either <path> or <format> is nullptr");
301   }
302   if (strcmp(format, "PEM") == 0) {
303     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
304       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
305     }
306   } else {
307     throw std::runtime_error("Unsupported private key format: " + std::string(format));
308   }
309 }
310
311 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
312   if (pkey.data() == nullptr) {
313     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
314   }
315
316   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
317   if (bio == nullptr) {
318     throw std::runtime_error("BIO_new: " + getErrors());
319   }
320
321   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
322   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
323     throw std::runtime_error("BIO_write: " + getErrors());
324   }
325
326   ssl::EvpPkeyUniquePtr key(
327       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
328   if (key == nullptr) {
329     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
330   }
331
332   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
333     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
334   }
335 }
336
337 void SSLContext::loadTrustedCertificates(const char* path) {
338   if (path == nullptr) {
339     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
340   }
341   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
342     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
343   }
344 }
345
346 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
347   SSL_CTX_set_cert_store(ctx_, store);
348 }
349
350 void SSLContext::loadClientCAList(const char* path) {
351   auto clientCAs = SSL_load_client_CA_file(path);
352   if (clientCAs == nullptr) {
353     LOG(ERROR) << "Unable to load ca file: " << path;
354     return;
355   }
356   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
357 }
358
359 void SSLContext::randomize() {
360   RAND_poll();
361 }
362
363 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
364   if (collector == nullptr) {
365     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
366     return;
367   }
368   collector_ = collector;
369   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
370   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
371 }
372
373 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
374
375 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
376   serverNameCb_ = cb;
377 }
378
379 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
380   clientHelloCbs_.push_back(cb);
381 }
382
383 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
384   SSLContext* context = (SSLContext*)data;
385
386   if (context == nullptr) {
387     return SSL_TLSEXT_ERR_NOACK;
388   }
389
390   for (auto& cb : context->clientHelloCbs_) {
391     // Generic callbacks to happen after we receive the Client Hello.
392     // For example, we use one to switch which cipher we use depending
393     // on the user's TLS version.  Because the primary purpose of
394     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
395     // are side-uses, we ignore any possible failures other than just logging
396     // them.
397     cb(ssl);
398   }
399
400   if (!context->serverNameCb_) {
401     return SSL_TLSEXT_ERR_NOACK;
402   }
403
404   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
405   switch (ret) {
406     case SERVER_NAME_FOUND:
407       return SSL_TLSEXT_ERR_OK;
408     case SERVER_NAME_NOT_FOUND:
409       return SSL_TLSEXT_ERR_NOACK;
410     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
411       *al = TLS1_AD_UNRECOGNIZED_NAME;
412       return SSL_TLSEXT_ERR_ALERT_FATAL;
413     default:
414       CHECK(false);
415   }
416
417   return SSL_TLSEXT_ERR_NOACK;
418 }
419
420 void SSLContext::switchCiphersIfTLS11(
421     SSL* ssl,
422     const std::string& tls11CipherString,
423     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
424   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
425       << "Shouldn't call if empty ciphers / alt ciphers";
426
427   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
428     // We only do this for TLS v 1.1 and later
429     return;
430   }
431
432   const std::string* ciphers = &tls11CipherString;
433   if (!tls11AltCipherlist.empty()) {
434     if (!cipherListPicker_) {
435       std::vector<int> weights;
436       std::for_each(
437           tls11AltCipherlist.begin(),
438           tls11AltCipherlist.end(),
439           [&](const std::pair<std::string, int>& e) {
440             weights.push_back(e.second);
441           });
442       cipherListPicker_.reset(
443           new std::discrete_distribution<int>(weights.begin(), weights.end()));
444     }
445     auto rng = ThreadLocalPRNG();
446     auto index = (*cipherListPicker_)(rng);
447     if ((size_t)index >= tls11AltCipherlist.size()) {
448       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
449                  << ", but tls11AltCipherlist is of length "
450                  << tls11AltCipherlist.size();
451     } else {
452       ciphers = &tls11AltCipherlist[index].first;
453     }
454   }
455
456   // Prefer AES for TLS versions 1.1 and later since these are not
457   // vulnerable to BEAST attacks on AES.  Note that we're setting the
458   // cipher list on the SSL object, not the SSL_CTX object, so it will
459   // only last for this request.
460   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
461   if ((rc == 0) || ERR_peek_error() != 0) {
462     // This shouldn't happen since we checked for this when proxygen
463     // started up.
464     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
465     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
466   }
467 }
468 #endif
469
470 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
471 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
472                                    const unsigned char** out,
473                                    unsigned char* outlen,
474                                    const unsigned char* in,
475                                    unsigned int inlen,
476                                    void* data) {
477   SSLContext* context = (SSLContext*)data;
478   CHECK(context);
479   if (context->advertisedNextProtocols_.empty()) {
480     *out = nullptr;
481     *outlen = 0;
482   } else {
483     auto i = context->pickNextProtocols();
484     const auto& item = context->advertisedNextProtocols_[i];
485     if (SSL_select_next_proto((unsigned char**)out,
486                               outlen,
487                               item.protocols,
488                               item.length,
489                               in,
490                               inlen) != OPENSSL_NPN_NEGOTIATED) {
491       return SSL_TLSEXT_ERR_NOACK;
492     }
493   }
494   return SSL_TLSEXT_ERR_OK;
495 }
496 #endif
497
498 #ifdef OPENSSL_NPN_NEGOTIATED
499
500 bool SSLContext::setAdvertisedNextProtocols(
501     const std::list<std::string>& protocols, NextProtocolType protocolType) {
502   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
503 }
504
505 bool SSLContext::setRandomizedAdvertisedNextProtocols(
506     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
507   unsetNextProtocols();
508   if (items.size() == 0) {
509     return false;
510   }
511   int total_weight = 0;
512   for (const auto &item : items) {
513     if (item.protocols.size() == 0) {
514       continue;
515     }
516     AdvertisedNextProtocolsItem advertised_item;
517     advertised_item.length = 0;
518     for (const auto& proto : item.protocols) {
519       ++advertised_item.length;
520       unsigned protoLength = proto.length();
521       if (protoLength >= 256) {
522         deleteNextProtocolsStrings();
523         return false;
524       }
525       advertised_item.length += protoLength;
526     }
527     advertised_item.protocols = new unsigned char[advertised_item.length];
528     if (!advertised_item.protocols) {
529       throw std::runtime_error("alloc failure");
530     }
531     unsigned char* dst = advertised_item.protocols;
532     for (auto& proto : item.protocols) {
533       unsigned protoLength = proto.length();
534       *dst++ = (unsigned char)protoLength;
535       memcpy(dst, proto.data(), protoLength);
536       dst += protoLength;
537     }
538     total_weight += item.weight;
539     advertisedNextProtocols_.push_back(advertised_item);
540     advertisedNextProtocolWeights_.push_back(item.weight);
541   }
542   if (total_weight == 0) {
543     deleteNextProtocolsStrings();
544     return false;
545   }
546   nextProtocolDistribution_ =
547       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
548                                    advertisedNextProtocolWeights_.end());
549   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
550     SSL_CTX_set_next_protos_advertised_cb(
551         ctx_, advertisedNextProtocolCallback, this);
552     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
553   }
554 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
555   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
556     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
557     // Client cannot really use randomized alpn
558     SSL_CTX_set_alpn_protos(ctx_,
559                             advertisedNextProtocols_[0].protocols,
560                             advertisedNextProtocols_[0].length);
561   }
562 #endif
563   return true;
564 }
565
566 void SSLContext::deleteNextProtocolsStrings() {
567   for (auto protocols : advertisedNextProtocols_) {
568     delete[] protocols.protocols;
569   }
570   advertisedNextProtocols_.clear();
571   advertisedNextProtocolWeights_.clear();
572 }
573
574 void SSLContext::unsetNextProtocols() {
575   deleteNextProtocolsStrings();
576   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
577   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
578 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
579   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
580   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
581 #endif
582 }
583
584 size_t SSLContext::pickNextProtocols() {
585   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
586   auto rng = ThreadLocalPRNG();
587   return nextProtocolDistribution_(rng);
588 }
589
590 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
591       const unsigned char** out, unsigned int* outlen, void* data) {
592   SSLContext* context = (SSLContext*)data;
593   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
594     *out = nullptr;
595     *outlen = 0;
596   } else if (context->advertisedNextProtocols_.size() == 1) {
597     *out = context->advertisedNextProtocols_[0].protocols;
598     *outlen = context->advertisedNextProtocols_[0].length;
599   } else {
600     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
601           sNextProtocolsExDataIndex_));
602     if (selected_index) {
603       --selected_index;
604       *out = context->advertisedNextProtocols_[selected_index].protocols;
605       *outlen = context->advertisedNextProtocols_[selected_index].length;
606     } else {
607       auto i = context->pickNextProtocols();
608       uintptr_t selected = i + 1;
609       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
610       *out = context->advertisedNextProtocols_[i].protocols;
611       *outlen = context->advertisedNextProtocols_[i].length;
612     }
613   }
614   return SSL_TLSEXT_ERR_OK;
615 }
616
617 int SSLContext::selectNextProtocolCallback(SSL* ssl,
618                                            unsigned char** out,
619                                            unsigned char* outlen,
620                                            const unsigned char* server,
621                                            unsigned int server_len,
622                                            void* data) {
623   (void)ssl; // Make -Wunused-parameters happy
624   SSLContext* ctx = (SSLContext*)data;
625   if (ctx->advertisedNextProtocols_.size() > 1) {
626     VLOG(3) << "SSLContext::selectNextProcolCallback() "
627             << "client should be deterministic in selecting protocols.";
628   }
629
630   unsigned char* client = nullptr;
631   unsigned int client_len = 0;
632   bool filtered = false;
633   auto cpf = ctx->getClientProtocolFilterCallback();
634   if (cpf) {
635     filtered = (*cpf)(&client, &client_len, server, server_len);
636   }
637
638   if (!filtered) {
639     if (ctx->advertisedNextProtocols_.empty()) {
640       client = (unsigned char *) "";
641       client_len = 0;
642     } else {
643       client = ctx->advertisedNextProtocols_[0].protocols;
644       client_len = ctx->advertisedNextProtocols_[0].length;
645     }
646   }
647
648   int retval = SSL_select_next_proto(out, outlen, server, server_len,
649                                      client, client_len);
650   if (retval != OPENSSL_NPN_NEGOTIATED) {
651     VLOG(3) << "SSLContext::selectNextProcolCallback() "
652             << "unable to pick a next protocol.";
653   }
654   return SSL_TLSEXT_ERR_OK;
655 }
656 #endif // OPENSSL_NPN_NEGOTIATED
657
658 SSL* SSLContext::createSSL() const {
659   SSL* ssl = SSL_new(ctx_);
660   if (ssl == nullptr) {
661     throw std::runtime_error("SSL_new: " + getErrors());
662   }
663   return ssl;
664 }
665
666 void SSLContext::setSessionCacheContext(const std::string& context) {
667   SSL_CTX_set_session_id_context(
668       ctx_,
669       reinterpret_cast<const unsigned char*>(context.data()),
670       std::min(
671           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
672 }
673
674 /**
675  * Match a name with a pattern. The pattern may include wildcard. A single
676  * wildcard "*" can match up to one component in the domain name.
677  *
678  * @param  host    Host name, typically the name of the remote host
679  * @param  pattern Name retrieved from certificate
680  * @param  size    Size of "pattern"
681  * @return True, if "host" matches "pattern". False otherwise.
682  */
683 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
684   bool match = false;
685   int i = 0, j = 0;
686   while (i < size && host[j] != '\0') {
687     if (toupper(pattern[i]) == toupper(host[j])) {
688       i++;
689       j++;
690       continue;
691     }
692     if (pattern[i] == '*') {
693       while (host[j] != '.' && host[j] != '\0') {
694         j++;
695       }
696       i++;
697       continue;
698     }
699     break;
700   }
701   if (i == size && host[j] == '\0') {
702     match = true;
703   }
704   return match;
705 }
706
707 int SSLContext::passwordCallback(char* password,
708                                  int size,
709                                  int,
710                                  void* data) {
711   SSLContext* context = (SSLContext*)data;
712   if (context == nullptr || context->passwordCollector() == nullptr) {
713     return 0;
714   }
715   std::string userPassword;
716   // call user defined password collector to get password
717   context->passwordCollector()->getPassword(userPassword, size);
718   int length = userPassword.size();
719   if (length > size) {
720     length = size;
721   }
722   strncpy(password, userPassword.c_str(), length);
723   return length;
724 }
725
726 struct SSLLock {
727   explicit SSLLock(
728     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
729       lockType(inLockType) {
730   }
731
732   void lock() {
733     if (lockType == SSLContext::LOCK_MUTEX) {
734       mutex.lock();
735     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
736       spinLock.lock();
737     }
738     // lockType == LOCK_NONE, no-op
739   }
740
741   void unlock() {
742     if (lockType == SSLContext::LOCK_MUTEX) {
743       mutex.unlock();
744     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
745       spinLock.unlock();
746     }
747     // lockType == LOCK_NONE, no-op
748   }
749
750   SSLContext::SSLLockType lockType;
751   folly::SpinLock spinLock{};
752   std::mutex mutex;
753 };
754
755 // Statics are unsafe in environments that call exit().
756 // If one thread calls exit() while another thread is
757 // references a member of SSLContext, bad things can happen.
758 // SSLContext runs in such environments.
759 // Instead of declaring a static member we "new" the static
760 // member so that it won't be destructed on exit().
761 static std::unique_ptr<SSLLock[]>& locks() {
762   static auto locksInst = new std::unique_ptr<SSLLock[]>();
763   return *locksInst;
764 }
765
766 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
767   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
768   return *lockTypesInst;
769 }
770
771 static void callbackLocking(int mode, int n, const char*, int) {
772   if (mode & CRYPTO_LOCK) {
773     locks()[n].lock();
774   } else {
775     locks()[n].unlock();
776   }
777 }
778
779 static unsigned long callbackThreadID() {
780   return static_cast<unsigned long>(
781 #ifdef __APPLE__
782     pthread_mach_thread_np(pthread_self())
783 #elif _MSC_VER
784     pthread_getw32threadid_np(pthread_self())
785 #else
786     pthread_self()
787 #endif
788   );
789 }
790
791 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
792   return new CRYPTO_dynlock_value;
793 }
794
795 static void dyn_lock(int mode,
796                      struct CRYPTO_dynlock_value* lock,
797                      const char*, int) {
798   if (lock != nullptr) {
799     if (mode & CRYPTO_LOCK) {
800       lock->mutex.lock();
801     } else {
802       lock->mutex.unlock();
803     }
804   }
805 }
806
807 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
808   delete lock;
809 }
810
811 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
812   lockTypes() = inLockTypes;
813 }
814
815 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
816 void SSLContext::enableFalseStart() {
817   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
818 }
819 #endif
820
821 void SSLContext::markInitialized() {
822   std::lock_guard<std::mutex> g(initMutex());
823   initialized_ = true;
824 }
825
826 void SSLContext::initializeOpenSSL() {
827   std::lock_guard<std::mutex> g(initMutex());
828   initializeOpenSSLLocked();
829 }
830
831 void SSLContext::initializeOpenSSLLocked() {
832   if (initialized_) {
833     return;
834   }
835   SSL_library_init();
836   SSL_load_error_strings();
837   ERR_load_crypto_strings();
838   // static locking
839   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
840   for (auto it: lockTypes()) {
841     locks()[it.first].lockType = it.second;
842   }
843   CRYPTO_set_id_callback(callbackThreadID);
844   CRYPTO_set_locking_callback(callbackLocking);
845   // dynamic locking
846   CRYPTO_set_dynlock_create_callback(dyn_create);
847   CRYPTO_set_dynlock_lock_callback(dyn_lock);
848   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
849   randomize();
850 #ifdef OPENSSL_NPN_NEGOTIATED
851   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
852       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
853 #endif
854   initialized_ = true;
855 }
856
857 void SSLContext::cleanupOpenSSL() {
858   std::lock_guard<std::mutex> g(initMutex());
859   cleanupOpenSSLLocked();
860 }
861
862 void SSLContext::cleanupOpenSSLLocked() {
863   if (!initialized_) {
864     return;
865   }
866
867   CRYPTO_set_id_callback(nullptr);
868   CRYPTO_set_locking_callback(nullptr);
869   CRYPTO_set_dynlock_create_callback(nullptr);
870   CRYPTO_set_dynlock_lock_callback(nullptr);
871   CRYPTO_set_dynlock_destroy_callback(nullptr);
872   CRYPTO_cleanup_all_ex_data();
873   ERR_free_strings();
874   EVP_cleanup();
875   ERR_remove_state(0);
876   locks().reset();
877   initialized_ = false;
878 }
879
880 void SSLContext::setOptions(long options) {
881   long newOpt = SSL_CTX_set_options(ctx_, options);
882   if ((newOpt & options) != options) {
883     throw std::runtime_error("SSL_CTX_set_options failed");
884   }
885 }
886
887 std::string SSLContext::getErrors(int errnoCopy) {
888   std::string errors;
889   unsigned long  errorCode;
890   char   message[256];
891
892   errors.reserve(512);
893   while ((errorCode = ERR_get_error()) != 0) {
894     if (!errors.empty()) {
895       errors += "; ";
896     }
897     const char* reason = ERR_reason_error_string(errorCode);
898     if (reason == nullptr) {
899       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
900       reason = message;
901     }
902     errors += reason;
903   }
904   if (errors.empty()) {
905     errors = "error code: " + folly::to<std::string>(errnoCopy);
906   }
907   return errors;
908 }
909
910 std::ostream&
911 operator<<(std::ostream& os, const PasswordCollector& collector) {
912   os << collector.describe();
913   return os;
914 }
915
916 } // folly