Clear OpenSSL error stack after loading certificate file.
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/Random.h>
27 #include <folly/SpinLock.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38 //
39 // For OpenSSL portability API
40 using namespace folly::ssl;
41
42 bool SSLContext::initialized_ = false;
43
44 namespace {
45
46 std::mutex& initMutex() {
47   static std::mutex m;
48   return m;
49 }
50
51 } // anonymous namespace
52
53 #ifdef OPENSSL_NPN_NEGOTIATED
54 int SSLContext::sNextProtocolsExDataIndex_ = -1;
55 #endif
56
57 // SSLContext implementation
58 SSLContext::SSLContext(SSLVersion version) {
59   {
60     std::lock_guard<std::mutex> g(initMutex());
61     initializeOpenSSLLocked();
62   }
63
64   ctx_ = SSL_CTX_new(SSLv23_method());
65   if (ctx_ == nullptr) {
66     throw std::runtime_error("SSL_CTX_new: " + getErrors());
67   }
68
69   int opt = 0;
70   switch (version) {
71     case TLSv1:
72       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
73       break;
74     case SSLv3:
75       opt = SSL_OP_NO_SSLv2;
76       break;
77     default:
78       // do nothing
79       break;
80   }
81   int newOpt = SSL_CTX_set_options(ctx_, opt);
82   DCHECK((newOpt & opt) == opt);
83
84   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
85
86   checkPeerName_ = false;
87
88   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
89
90 #if FOLLY_OPENSSL_HAS_SNI
91   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
92   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
93 #endif
94 }
95
96 SSLContext::~SSLContext() {
97   if (ctx_ != nullptr) {
98     SSL_CTX_free(ctx_);
99     ctx_ = nullptr;
100   }
101
102 #ifdef OPENSSL_NPN_NEGOTIATED
103   deleteNextProtocolsStrings();
104 #endif
105 }
106
107 void SSLContext::ciphers(const std::string& ciphers) {
108   providedCiphersString_ = ciphers;
109   setCiphersOrThrow(ciphers);
110 }
111
112 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
113   if (ciphers.size() == 0) {
114     return;
115   }
116   std::string opensslCipherList;
117   join(":", ciphers, opensslCipherList);
118   setCiphersOrThrow(opensslCipherList);
119 }
120
121 void SSLContext::setSignatureAlgorithms(
122     const std::vector<std::string>& sigalgs) {
123   if (sigalgs.size() == 0) {
124     return;
125   }
126 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
127   std::string opensslSigAlgsList;
128   join(":", sigalgs, opensslSigAlgsList);
129   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
130   if (rc == 0) {
131     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
132   }
133 #endif
134 }
135
136 void SSLContext::setClientECCurvesList(
137     const std::vector<std::string>& ecCurves) {
138   if (ecCurves.size() == 0) {
139     return;
140   }
141 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
142   std::string ecCurvesList;
143   join(":", ecCurves, ecCurvesList);
144   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
145   if (rc == 0) {
146     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
147   }
148 #endif
149 }
150
151 void SSLContext::setServerECCurve(const std::string& curveName) {
152   bool validCall = false;
153 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL
154 #ifndef OPENSSL_NO_ECDH
155   validCall = true;
156 #endif
157 #endif
158   if (!validCall) {
159     throw std::runtime_error("Elliptic curve encryption not allowed");
160   }
161
162   EC_KEY* ecdh = nullptr;
163   int nid;
164
165   /*
166    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
167    * from RFC 4492 section 5.1.1, or explicitly described curves over
168    * binary fields. OpenSSL only supports the "named curves", which provide
169    * maximum interoperability.
170    */
171
172   nid = OBJ_sn2nid(curveName.c_str());
173   if (nid == 0) {
174     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
175   }
176   ecdh = EC_KEY_new_by_curve_name(nid);
177   if (ecdh == nullptr) {
178     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
179   }
180
181   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
182   EC_KEY_free(ecdh);
183 }
184
185 void SSLContext::setX509VerifyParam(
186     const ssl::X509VerifyParam& x509VerifyParam) {
187   if (!x509VerifyParam) {
188     return;
189   }
190   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
191     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
192   }
193 }
194
195 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
196   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
197   if (rc == 0) {
198     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
199   }
200 }
201
202 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
203     verifyPeer) {
204   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
205   verifyPeer_ = verifyPeer;
206 }
207
208 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
209     verifyPeer) {
210   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
211   int mode = SSL_VERIFY_NONE;
212   switch(verifyPeer) {
213     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
214     // break;
215
216     case SSLVerifyPeerEnum::VERIFY:
217       mode = SSL_VERIFY_PEER;
218       break;
219
220     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
221       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
222       break;
223
224     case SSLVerifyPeerEnum::NO_VERIFY:
225       mode = SSL_VERIFY_NONE;
226       break;
227
228     default:
229       break;
230   }
231   return mode;
232 }
233
234 int SSLContext::getVerificationMode() {
235   return getVerificationMode(verifyPeer_);
236 }
237
238 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
239                               const std::string& peerName) {
240   int mode;
241   if (checkPeerCert) {
242     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
243     checkPeerName_ = checkPeerName;
244     peerFixedName_ = peerName;
245   } else {
246     mode = SSL_VERIFY_NONE;
247     checkPeerName_ = false; // can't check name without cert!
248     peerFixedName_.clear();
249   }
250   SSL_CTX_set_verify(ctx_, mode, nullptr);
251 }
252
253 void SSLContext::loadCertificate(const char* path, const char* format) {
254   if (path == nullptr || format == nullptr) {
255     throw std::invalid_argument(
256          "loadCertificateChain: either <path> or <format> is nullptr");
257   }
258   if (strcmp(format, "PEM") == 0) {
259     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
260       int errnoCopy = errno;
261       std::string reason("SSL_CTX_use_certificate_chain_file: ");
262       reason.append(path);
263       reason.append(": ");
264       reason.append(getErrors(errnoCopy));
265       throw std::runtime_error(reason);
266     }
267   } else {
268     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
269   }
270 }
271
272 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
273   if (cert.data() == nullptr) {
274     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
275   }
276
277   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
278   if (bio == nullptr) {
279     throw std::runtime_error("BIO_new: " + getErrors());
280   }
281
282   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
283   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
284     throw std::runtime_error("BIO_write: " + getErrors());
285   }
286
287   ssl::X509UniquePtr x509(
288       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
289   if (x509 == nullptr) {
290     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
291   }
292
293   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
294     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
295   }
296 }
297
298 void SSLContext::loadPrivateKey(const char* path, const char* format) {
299   if (path == nullptr || format == nullptr) {
300     throw std::invalid_argument(
301         "loadPrivateKey: either <path> or <format> is nullptr");
302   }
303   if (strcmp(format, "PEM") == 0) {
304     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
305       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
306     }
307   } else {
308     throw std::runtime_error("Unsupported private key format: " + std::string(format));
309   }
310 }
311
312 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
313   if (pkey.data() == nullptr) {
314     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
315   }
316
317   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
318   if (bio == nullptr) {
319     throw std::runtime_error("BIO_new: " + getErrors());
320   }
321
322   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
323   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
324     throw std::runtime_error("BIO_write: " + getErrors());
325   }
326
327   ssl::EvpPkeyUniquePtr key(
328       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
329   if (key == nullptr) {
330     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
331   }
332
333   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
334     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
335   }
336 }
337
338 void SSLContext::loadTrustedCertificates(const char* path) {
339   if (path == nullptr) {
340     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
341   }
342   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
343     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
344   }
345   ERR_clear_error();
346 }
347
348 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
349   SSL_CTX_set_cert_store(ctx_, store);
350 }
351
352 void SSLContext::loadClientCAList(const char* path) {
353   auto clientCAs = SSL_load_client_CA_file(path);
354   if (clientCAs == nullptr) {
355     LOG(ERROR) << "Unable to load ca file: " << path;
356     return;
357   }
358   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
359 }
360
361 void SSLContext::randomize() {
362   RAND_poll();
363 }
364
365 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
366   if (collector == nullptr) {
367     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
368     return;
369   }
370   collector_ = collector;
371   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
372   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
373 }
374
375 #if FOLLY_OPENSSL_HAS_SNI
376
377 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
378   serverNameCb_ = cb;
379 }
380
381 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
382   clientHelloCbs_.push_back(cb);
383 }
384
385 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
386   SSLContext* context = (SSLContext*)data;
387
388   if (context == nullptr) {
389     return SSL_TLSEXT_ERR_NOACK;
390   }
391
392   for (auto& cb : context->clientHelloCbs_) {
393     // Generic callbacks to happen after we receive the Client Hello.
394     // For example, we use one to switch which cipher we use depending
395     // on the user's TLS version.  Because the primary purpose of
396     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
397     // are side-uses, we ignore any possible failures other than just logging
398     // them.
399     cb(ssl);
400   }
401
402   if (!context->serverNameCb_) {
403     return SSL_TLSEXT_ERR_NOACK;
404   }
405
406   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
407   switch (ret) {
408     case SERVER_NAME_FOUND:
409       return SSL_TLSEXT_ERR_OK;
410     case SERVER_NAME_NOT_FOUND:
411       return SSL_TLSEXT_ERR_NOACK;
412     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
413       *al = TLS1_AD_UNRECOGNIZED_NAME;
414       return SSL_TLSEXT_ERR_ALERT_FATAL;
415     default:
416       CHECK(false);
417   }
418
419   return SSL_TLSEXT_ERR_NOACK;
420 }
421
422 void SSLContext::switchCiphersIfTLS11(
423     SSL* ssl,
424     const std::string& tls11CipherString,
425     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
426   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
427       << "Shouldn't call if empty ciphers / alt ciphers";
428
429   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
430     // We only do this for TLS v 1.1 and later
431     return;
432   }
433
434   const std::string* ciphers = &tls11CipherString;
435   if (!tls11AltCipherlist.empty()) {
436     if (!cipherListPicker_) {
437       std::vector<int> weights;
438       std::for_each(
439           tls11AltCipherlist.begin(),
440           tls11AltCipherlist.end(),
441           [&](const std::pair<std::string, int>& e) {
442             weights.push_back(e.second);
443           });
444       cipherListPicker_.reset(
445           new std::discrete_distribution<int>(weights.begin(), weights.end()));
446     }
447     auto rng = ThreadLocalPRNG();
448     auto index = (*cipherListPicker_)(rng);
449     if ((size_t)index >= tls11AltCipherlist.size()) {
450       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
451                  << ", but tls11AltCipherlist is of length "
452                  << tls11AltCipherlist.size();
453     } else {
454       ciphers = &tls11AltCipherlist[index].first;
455     }
456   }
457
458   // Prefer AES for TLS versions 1.1 and later since these are not
459   // vulnerable to BEAST attacks on AES.  Note that we're setting the
460   // cipher list on the SSL object, not the SSL_CTX object, so it will
461   // only last for this request.
462   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
463   if ((rc == 0) || ERR_peek_error() != 0) {
464     // This shouldn't happen since we checked for this when proxygen
465     // started up.
466     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
467     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
468   }
469 }
470 #endif // FOLLY_OPENSSL_HAS_SNI
471
472 #if FOLLY_OPENSSL_HAS_ALPN
473 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
474                                    const unsigned char** out,
475                                    unsigned char* outlen,
476                                    const unsigned char* in,
477                                    unsigned int inlen,
478                                    void* data) {
479   SSLContext* context = (SSLContext*)data;
480   CHECK(context);
481   if (context->advertisedNextProtocols_.empty()) {
482     *out = nullptr;
483     *outlen = 0;
484   } else {
485     auto i = context->pickNextProtocols();
486     const auto& item = context->advertisedNextProtocols_[i];
487     if (SSL_select_next_proto((unsigned char**)out,
488                               outlen,
489                               item.protocols,
490                               item.length,
491                               in,
492                               inlen) != OPENSSL_NPN_NEGOTIATED) {
493       return SSL_TLSEXT_ERR_NOACK;
494     }
495   }
496   return SSL_TLSEXT_ERR_OK;
497 }
498 #endif // FOLLY_OPENSSL_HAS_ALPN
499
500 #ifdef OPENSSL_NPN_NEGOTIATED
501
502 bool SSLContext::setAdvertisedNextProtocols(
503     const std::list<std::string>& protocols, NextProtocolType protocolType) {
504   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
505 }
506
507 bool SSLContext::setRandomizedAdvertisedNextProtocols(
508     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
509   unsetNextProtocols();
510   if (items.size() == 0) {
511     return false;
512   }
513   int total_weight = 0;
514   for (const auto &item : items) {
515     if (item.protocols.size() == 0) {
516       continue;
517     }
518     AdvertisedNextProtocolsItem advertised_item;
519     advertised_item.length = 0;
520     for (const auto& proto : item.protocols) {
521       ++advertised_item.length;
522       auto protoLength = proto.length();
523       if (protoLength >= 256) {
524         deleteNextProtocolsStrings();
525         return false;
526       }
527       advertised_item.length += unsigned(protoLength);
528     }
529     advertised_item.protocols = new unsigned char[advertised_item.length];
530     if (!advertised_item.protocols) {
531       throw std::runtime_error("alloc failure");
532     }
533     unsigned char* dst = advertised_item.protocols;
534     for (auto& proto : item.protocols) {
535       uint8_t protoLength = uint8_t(proto.length());
536       *dst++ = (unsigned char)protoLength;
537       memcpy(dst, proto.data(), protoLength);
538       dst += protoLength;
539     }
540     total_weight += item.weight;
541     advertisedNextProtocols_.push_back(advertised_item);
542     advertisedNextProtocolWeights_.push_back(item.weight);
543   }
544   if (total_weight == 0) {
545     deleteNextProtocolsStrings();
546     return false;
547   }
548   nextProtocolDistribution_ =
549       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
550                                    advertisedNextProtocolWeights_.end());
551   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
552     SSL_CTX_set_next_protos_advertised_cb(
553         ctx_, advertisedNextProtocolCallback, this);
554     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
555   }
556 #if FOLLY_OPENSSL_HAS_ALPN
557   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
558     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
559     // Client cannot really use randomized alpn
560     SSL_CTX_set_alpn_protos(ctx_,
561                             advertisedNextProtocols_[0].protocols,
562                             advertisedNextProtocols_[0].length);
563   }
564 #endif
565   return true;
566 }
567
568 void SSLContext::deleteNextProtocolsStrings() {
569   for (auto protocols : advertisedNextProtocols_) {
570     delete[] protocols.protocols;
571   }
572   advertisedNextProtocols_.clear();
573   advertisedNextProtocolWeights_.clear();
574 }
575
576 void SSLContext::unsetNextProtocols() {
577   deleteNextProtocolsStrings();
578   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
579   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
580 #if FOLLY_OPENSSL_HAS_ALPN
581   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
582   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
583 #endif
584 }
585
586 size_t SSLContext::pickNextProtocols() {
587   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
588   auto rng = ThreadLocalPRNG();
589   return nextProtocolDistribution_(rng);
590 }
591
592 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
593       const unsigned char** out, unsigned int* outlen, void* data) {
594   SSLContext* context = (SSLContext*)data;
595   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
596     *out = nullptr;
597     *outlen = 0;
598   } else if (context->advertisedNextProtocols_.size() == 1) {
599     *out = context->advertisedNextProtocols_[0].protocols;
600     *outlen = context->advertisedNextProtocols_[0].length;
601   } else {
602     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
603           sNextProtocolsExDataIndex_));
604     if (selected_index) {
605       --selected_index;
606       *out = context->advertisedNextProtocols_[selected_index].protocols;
607       *outlen = context->advertisedNextProtocols_[selected_index].length;
608     } else {
609       auto i = context->pickNextProtocols();
610       uintptr_t selected = i + 1;
611       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
612       *out = context->advertisedNextProtocols_[i].protocols;
613       *outlen = context->advertisedNextProtocols_[i].length;
614     }
615   }
616   return SSL_TLSEXT_ERR_OK;
617 }
618
619 int SSLContext::selectNextProtocolCallback(SSL* ssl,
620                                            unsigned char** out,
621                                            unsigned char* outlen,
622                                            const unsigned char* server,
623                                            unsigned int server_len,
624                                            void* data) {
625   (void)ssl; // Make -Wunused-parameters happy
626   SSLContext* ctx = (SSLContext*)data;
627   if (ctx->advertisedNextProtocols_.size() > 1) {
628     VLOG(3) << "SSLContext::selectNextProcolCallback() "
629             << "client should be deterministic in selecting protocols.";
630   }
631
632   unsigned char* client = nullptr;
633   unsigned int client_len = 0;
634   bool filtered = false;
635   auto cpf = ctx->getClientProtocolFilterCallback();
636   if (cpf) {
637     filtered = (*cpf)(&client, &client_len, server, server_len);
638   }
639
640   if (!filtered) {
641     if (ctx->advertisedNextProtocols_.empty()) {
642       client = (unsigned char *) "";
643       client_len = 0;
644     } else {
645       client = ctx->advertisedNextProtocols_[0].protocols;
646       client_len = ctx->advertisedNextProtocols_[0].length;
647     }
648   }
649
650   int retval = SSL_select_next_proto(out, outlen, server, server_len,
651                                      client, client_len);
652   if (retval != OPENSSL_NPN_NEGOTIATED) {
653     VLOG(3) << "SSLContext::selectNextProcolCallback() "
654             << "unable to pick a next protocol.";
655   }
656   return SSL_TLSEXT_ERR_OK;
657 }
658 #endif // OPENSSL_NPN_NEGOTIATED
659
660 SSL* SSLContext::createSSL() const {
661   SSL* ssl = SSL_new(ctx_);
662   if (ssl == nullptr) {
663     throw std::runtime_error("SSL_new: " + getErrors());
664   }
665   return ssl;
666 }
667
668 void SSLContext::setSessionCacheContext(const std::string& context) {
669   SSL_CTX_set_session_id_context(
670       ctx_,
671       reinterpret_cast<const unsigned char*>(context.data()),
672       std::min(
673           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
674 }
675
676 /**
677  * Match a name with a pattern. The pattern may include wildcard. A single
678  * wildcard "*" can match up to one component in the domain name.
679  *
680  * @param  host    Host name, typically the name of the remote host
681  * @param  pattern Name retrieved from certificate
682  * @param  size    Size of "pattern"
683  * @return True, if "host" matches "pattern". False otherwise.
684  */
685 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
686   bool match = false;
687   int i = 0, j = 0;
688   while (i < size && host[j] != '\0') {
689     if (toupper(pattern[i]) == toupper(host[j])) {
690       i++;
691       j++;
692       continue;
693     }
694     if (pattern[i] == '*') {
695       while (host[j] != '.' && host[j] != '\0') {
696         j++;
697       }
698       i++;
699       continue;
700     }
701     break;
702   }
703   if (i == size && host[j] == '\0') {
704     match = true;
705   }
706   return match;
707 }
708
709 int SSLContext::passwordCallback(char* password,
710                                  int size,
711                                  int,
712                                  void* data) {
713   SSLContext* context = (SSLContext*)data;
714   if (context == nullptr || context->passwordCollector() == nullptr) {
715     return 0;
716   }
717   std::string userPassword;
718   // call user defined password collector to get password
719   context->passwordCollector()->getPassword(userPassword, size);
720   auto length = int(userPassword.size());
721   if (length > size) {
722     length = size;
723   }
724   strncpy(password, userPassword.c_str(), length);
725   return length;
726 }
727
728 struct SSLLock {
729   explicit SSLLock(
730     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
731       lockType(inLockType) {
732   }
733
734   void lock() {
735     if (lockType == SSLContext::LOCK_MUTEX) {
736       mutex.lock();
737     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
738       spinLock.lock();
739     }
740     // lockType == LOCK_NONE, no-op
741   }
742
743   void unlock() {
744     if (lockType == SSLContext::LOCK_MUTEX) {
745       mutex.unlock();
746     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
747       spinLock.unlock();
748     }
749     // lockType == LOCK_NONE, no-op
750   }
751
752   SSLContext::SSLLockType lockType;
753   folly::SpinLock spinLock{};
754   std::mutex mutex;
755 };
756
757 // Statics are unsafe in environments that call exit().
758 // If one thread calls exit() while another thread is
759 // references a member of SSLContext, bad things can happen.
760 // SSLContext runs in such environments.
761 // Instead of declaring a static member we "new" the static
762 // member so that it won't be destructed on exit().
763 static std::unique_ptr<SSLLock[]>& locks() {
764   static auto locksInst = new std::unique_ptr<SSLLock[]>();
765   return *locksInst;
766 }
767
768 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
769   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
770   return *lockTypesInst;
771 }
772
773 static void callbackLocking(int mode, int n, const char*, int) {
774   if (mode & CRYPTO_LOCK) {
775     locks()[n].lock();
776   } else {
777     locks()[n].unlock();
778   }
779 }
780
781 static unsigned long callbackThreadID() {
782   return static_cast<unsigned long>(
783 #ifdef __APPLE__
784     pthread_mach_thread_np(pthread_self())
785 #elif _MSC_VER
786     pthread_getw32threadid_np(pthread_self())
787 #else
788     pthread_self()
789 #endif
790   );
791 }
792
793 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
794   return new CRYPTO_dynlock_value;
795 }
796
797 static void dyn_lock(int mode,
798                      struct CRYPTO_dynlock_value* lock,
799                      const char*, int) {
800   if (lock != nullptr) {
801     if (mode & CRYPTO_LOCK) {
802       lock->mutex.lock();
803     } else {
804       lock->mutex.unlock();
805     }
806   }
807 }
808
809 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
810   delete lock;
811 }
812
813 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
814   lockTypes() = inLockTypes;
815 }
816
817 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
818 void SSLContext::enableFalseStart() {
819   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
820 }
821 #endif
822
823 void SSLContext::markInitialized() {
824   std::lock_guard<std::mutex> g(initMutex());
825   initialized_ = true;
826 }
827
828 void SSLContext::initializeOpenSSL() {
829   std::lock_guard<std::mutex> g(initMutex());
830   initializeOpenSSLLocked();
831 }
832
833 void SSLContext::initializeOpenSSLLocked() {
834   if (initialized_) {
835     return;
836   }
837   SSL_library_init();
838   SSL_load_error_strings();
839   ERR_load_crypto_strings();
840   // static locking
841   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
842   for (auto it: lockTypes()) {
843     locks()[it.first].lockType = it.second;
844   }
845   CRYPTO_set_id_callback(callbackThreadID);
846   CRYPTO_set_locking_callback(callbackLocking);
847   // dynamic locking
848   CRYPTO_set_dynlock_create_callback(dyn_create);
849   CRYPTO_set_dynlock_lock_callback(dyn_lock);
850   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
851   randomize();
852 #ifdef OPENSSL_NPN_NEGOTIATED
853   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
854       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
855 #endif
856   initialized_ = true;
857 }
858
859 void SSLContext::cleanupOpenSSL() {
860   std::lock_guard<std::mutex> g(initMutex());
861   cleanupOpenSSLLocked();
862 }
863
864 void SSLContext::cleanupOpenSSLLocked() {
865   if (!initialized_) {
866     return;
867   }
868
869   CRYPTO_set_id_callback(nullptr);
870   CRYPTO_set_locking_callback(nullptr);
871   CRYPTO_set_dynlock_create_callback(nullptr);
872   CRYPTO_set_dynlock_lock_callback(nullptr);
873   CRYPTO_set_dynlock_destroy_callback(nullptr);
874   CRYPTO_cleanup_all_ex_data();
875   ERR_free_strings();
876   EVP_cleanup();
877   ERR_remove_state(0);
878   locks().reset();
879   initialized_ = false;
880 }
881
882 void SSLContext::setOptions(long options) {
883   long newOpt = SSL_CTX_set_options(ctx_, options);
884   if ((newOpt & options) != options) {
885     throw std::runtime_error("SSL_CTX_set_options failed");
886   }
887 }
888
889 std::string SSLContext::getErrors(int errnoCopy) {
890   std::string errors;
891   unsigned long  errorCode;
892   char   message[256];
893
894   errors.reserve(512);
895   while ((errorCode = ERR_get_error()) != 0) {
896     if (!errors.empty()) {
897       errors += "; ";
898     }
899     const char* reason = ERR_reason_error_string(errorCode);
900     if (reason == nullptr) {
901       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
902       reason = message;
903     }
904     errors += reason;
905   }
906   if (errors.empty()) {
907     errors = "error code: " + folly::to<std::string>(errnoCopy);
908   }
909   return errors;
910 }
911
912 std::ostream&
913 operator<<(std::ostream& os, const PasswordCollector& collector) {
914   os << collector.describe();
915   return os;
916 }
917
918 } // folly