f3ca89f5b4b2af57adf793083d9b030009543933
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/Random.h>
27 #include <folly/SpinLock.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38 //
39 // For OpenSSL portability API
40 using namespace folly::ssl;
41
42 bool SSLContext::initialized_ = false;
43
44 namespace {
45
46 std::mutex& initMutex() {
47   static std::mutex m;
48   return m;
49 }
50
51 } // anonymous namespace
52
53 #ifdef OPENSSL_NPN_NEGOTIATED
54 int SSLContext::sNextProtocolsExDataIndex_ = -1;
55 #endif
56
57 // SSLContext implementation
58 SSLContext::SSLContext(SSLVersion version) {
59   {
60     std::lock_guard<std::mutex> g(initMutex());
61     initializeOpenSSLLocked();
62   }
63
64   ctx_ = SSL_CTX_new(SSLv23_method());
65   if (ctx_ == nullptr) {
66     throw std::runtime_error("SSL_CTX_new: " + getErrors());
67   }
68
69   int opt = 0;
70   switch (version) {
71     case TLSv1:
72       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
73       break;
74     case SSLv3:
75       opt = SSL_OP_NO_SSLv2;
76       break;
77     default:
78       // do nothing
79       break;
80   }
81   int newOpt = SSL_CTX_set_options(ctx_, opt);
82   DCHECK((newOpt & opt) == opt);
83
84   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
85
86   checkPeerName_ = false;
87
88   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
89
90 #if FOLLY_OPENSSL_HAS_SNI
91   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
92   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
93 #endif
94 }
95
96 SSLContext::~SSLContext() {
97   if (ctx_ != nullptr) {
98     SSL_CTX_free(ctx_);
99     ctx_ = nullptr;
100   }
101
102 #ifdef OPENSSL_NPN_NEGOTIATED
103   deleteNextProtocolsStrings();
104 #endif
105 }
106
107 void SSLContext::ciphers(const std::string& ciphers) {
108   providedCiphersString_ = ciphers;
109   setCiphersOrThrow(ciphers);
110 }
111
112 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
113   if (ciphers.size() == 0) {
114     return;
115   }
116   std::string opensslCipherList;
117   join(":", ciphers, opensslCipherList);
118   setCiphersOrThrow(opensslCipherList);
119 }
120
121 void SSLContext::setSignatureAlgorithms(
122     const std::vector<std::string>& sigalgs) {
123   if (sigalgs.size() == 0) {
124     return;
125   }
126 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
127   std::string opensslSigAlgsList;
128   join(":", sigalgs, opensslSigAlgsList);
129   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
130   if (rc == 0) {
131     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
132   }
133 #endif
134 }
135
136 void SSLContext::setClientECCurvesList(
137     const std::vector<std::string>& ecCurves) {
138   if (ecCurves.size() == 0) {
139     return;
140   }
141 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
142   std::string ecCurvesList;
143   join(":", ecCurves, ecCurvesList);
144   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
145   if (rc == 0) {
146     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
147   }
148 #endif
149 }
150
151 void SSLContext::setServerECCurve(const std::string& curveName) {
152 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
153   EC_KEY* ecdh = nullptr;
154   int nid;
155
156   /*
157    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
158    * from RFC 4492 section 5.1.1, or explicitly described curves over
159    * binary fields. OpenSSL only supports the "named curves", which provide
160    * maximum interoperability.
161    */
162
163   nid = OBJ_sn2nid(curveName.c_str());
164   if (nid == 0) {
165     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
166   }
167   ecdh = EC_KEY_new_by_curve_name(nid);
168   if (ecdh == nullptr) {
169     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
170   }
171
172   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
173   EC_KEY_free(ecdh);
174 #else
175   throw std::runtime_error("Elliptic curve encryption not allowed");
176 #endif
177 }
178
179 void SSLContext::setX509VerifyParam(
180     const ssl::X509VerifyParam& x509VerifyParam) {
181   if (!x509VerifyParam) {
182     return;
183   }
184   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
185     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
186   }
187 }
188
189 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
190   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
191   if (rc == 0) {
192     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
193   }
194 }
195
196 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
197     verifyPeer) {
198   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
199   verifyPeer_ = verifyPeer;
200 }
201
202 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
203     verifyPeer) {
204   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
205   int mode = SSL_VERIFY_NONE;
206   switch(verifyPeer) {
207     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
208     // break;
209
210     case SSLVerifyPeerEnum::VERIFY:
211       mode = SSL_VERIFY_PEER;
212       break;
213
214     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
215       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
216       break;
217
218     case SSLVerifyPeerEnum::NO_VERIFY:
219       mode = SSL_VERIFY_NONE;
220       break;
221
222     default:
223       break;
224   }
225   return mode;
226 }
227
228 int SSLContext::getVerificationMode() {
229   return getVerificationMode(verifyPeer_);
230 }
231
232 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
233                               const std::string& peerName) {
234   int mode;
235   if (checkPeerCert) {
236     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
237     checkPeerName_ = checkPeerName;
238     peerFixedName_ = peerName;
239   } else {
240     mode = SSL_VERIFY_NONE;
241     checkPeerName_ = false; // can't check name without cert!
242     peerFixedName_.clear();
243   }
244   SSL_CTX_set_verify(ctx_, mode, nullptr);
245 }
246
247 void SSLContext::loadCertificate(const char* path, const char* format) {
248   if (path == nullptr || format == nullptr) {
249     throw std::invalid_argument(
250          "loadCertificateChain: either <path> or <format> is nullptr");
251   }
252   if (strcmp(format, "PEM") == 0) {
253     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
254       int errnoCopy = errno;
255       std::string reason("SSL_CTX_use_certificate_chain_file: ");
256       reason.append(path);
257       reason.append(": ");
258       reason.append(getErrors(errnoCopy));
259       throw std::runtime_error(reason);
260     }
261   } else {
262     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
263   }
264 }
265
266 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
267   if (cert.data() == nullptr) {
268     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
269   }
270
271   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
272   if (bio == nullptr) {
273     throw std::runtime_error("BIO_new: " + getErrors());
274   }
275
276   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
277   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
278     throw std::runtime_error("BIO_write: " + getErrors());
279   }
280
281   ssl::X509UniquePtr x509(
282       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
283   if (x509 == nullptr) {
284     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
285   }
286
287   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
288     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
289   }
290 }
291
292 void SSLContext::loadPrivateKey(const char* path, const char* format) {
293   if (path == nullptr || format == nullptr) {
294     throw std::invalid_argument(
295         "loadPrivateKey: either <path> or <format> is nullptr");
296   }
297   if (strcmp(format, "PEM") == 0) {
298     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
299       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
300     }
301   } else {
302     throw std::runtime_error("Unsupported private key format: " + std::string(format));
303   }
304 }
305
306 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
307   if (pkey.data() == nullptr) {
308     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
309   }
310
311   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
312   if (bio == nullptr) {
313     throw std::runtime_error("BIO_new: " + getErrors());
314   }
315
316   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
317   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
318     throw std::runtime_error("BIO_write: " + getErrors());
319   }
320
321   ssl::EvpPkeyUniquePtr key(
322       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
323   if (key == nullptr) {
324     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
325   }
326
327   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
328     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
329   }
330 }
331
332 void SSLContext::loadTrustedCertificates(const char* path) {
333   if (path == nullptr) {
334     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
335   }
336   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
337     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
338   }
339   ERR_clear_error();
340 }
341
342 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
343   SSL_CTX_set_cert_store(ctx_, store);
344 }
345
346 void SSLContext::loadClientCAList(const char* path) {
347   auto clientCAs = SSL_load_client_CA_file(path);
348   if (clientCAs == nullptr) {
349     LOG(ERROR) << "Unable to load ca file: " << path;
350     return;
351   }
352   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
353 }
354
355 void SSLContext::randomize() {
356   RAND_poll();
357 }
358
359 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
360   if (collector == nullptr) {
361     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
362     return;
363   }
364   collector_ = collector;
365   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
366   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
367 }
368
369 #if FOLLY_OPENSSL_HAS_SNI
370
371 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
372   serverNameCb_ = cb;
373 }
374
375 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
376   clientHelloCbs_.push_back(cb);
377 }
378
379 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
380   SSLContext* context = (SSLContext*)data;
381
382   if (context == nullptr) {
383     return SSL_TLSEXT_ERR_NOACK;
384   }
385
386   for (auto& cb : context->clientHelloCbs_) {
387     // Generic callbacks to happen after we receive the Client Hello.
388     // For example, we use one to switch which cipher we use depending
389     // on the user's TLS version.  Because the primary purpose of
390     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
391     // are side-uses, we ignore any possible failures other than just logging
392     // them.
393     cb(ssl);
394   }
395
396   if (!context->serverNameCb_) {
397     return SSL_TLSEXT_ERR_NOACK;
398   }
399
400   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
401   switch (ret) {
402     case SERVER_NAME_FOUND:
403       return SSL_TLSEXT_ERR_OK;
404     case SERVER_NAME_NOT_FOUND:
405       return SSL_TLSEXT_ERR_NOACK;
406     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
407       *al = TLS1_AD_UNRECOGNIZED_NAME;
408       return SSL_TLSEXT_ERR_ALERT_FATAL;
409     default:
410       CHECK(false);
411   }
412
413   return SSL_TLSEXT_ERR_NOACK;
414 }
415
416 void SSLContext::switchCiphersIfTLS11(
417     SSL* ssl,
418     const std::string& tls11CipherString,
419     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
420   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
421       << "Shouldn't call if empty ciphers / alt ciphers";
422
423   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
424     // We only do this for TLS v 1.1 and later
425     return;
426   }
427
428   const std::string* ciphers = &tls11CipherString;
429   if (!tls11AltCipherlist.empty()) {
430     if (!cipherListPicker_) {
431       std::vector<int> weights;
432       std::for_each(
433           tls11AltCipherlist.begin(),
434           tls11AltCipherlist.end(),
435           [&](const std::pair<std::string, int>& e) {
436             weights.push_back(e.second);
437           });
438       cipherListPicker_.reset(
439           new std::discrete_distribution<int>(weights.begin(), weights.end()));
440     }
441     auto rng = ThreadLocalPRNG();
442     auto index = (*cipherListPicker_)(rng);
443     if ((size_t)index >= tls11AltCipherlist.size()) {
444       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
445                  << ", but tls11AltCipherlist is of length "
446                  << tls11AltCipherlist.size();
447     } else {
448       ciphers = &tls11AltCipherlist[size_t(index)].first;
449     }
450   }
451
452   // Prefer AES for TLS versions 1.1 and later since these are not
453   // vulnerable to BEAST attacks on AES.  Note that we're setting the
454   // cipher list on the SSL object, not the SSL_CTX object, so it will
455   // only last for this request.
456   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
457   if ((rc == 0) || ERR_peek_error() != 0) {
458     // This shouldn't happen since we checked for this when proxygen
459     // started up.
460     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
461     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
462   }
463 }
464 #endif // FOLLY_OPENSSL_HAS_SNI
465
466 #if FOLLY_OPENSSL_HAS_ALPN
467 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
468                                    const unsigned char** out,
469                                    unsigned char* outlen,
470                                    const unsigned char* in,
471                                    unsigned int inlen,
472                                    void* data) {
473   SSLContext* context = (SSLContext*)data;
474   CHECK(context);
475   if (context->advertisedNextProtocols_.empty()) {
476     *out = nullptr;
477     *outlen = 0;
478   } else {
479     auto i = context->pickNextProtocols();
480     const auto& item = context->advertisedNextProtocols_[i];
481     if (SSL_select_next_proto((unsigned char**)out,
482                               outlen,
483                               item.protocols,
484                               item.length,
485                               in,
486                               inlen) != OPENSSL_NPN_NEGOTIATED) {
487       return SSL_TLSEXT_ERR_NOACK;
488     }
489   }
490   return SSL_TLSEXT_ERR_OK;
491 }
492 #endif // FOLLY_OPENSSL_HAS_ALPN
493
494 #ifdef OPENSSL_NPN_NEGOTIATED
495
496 bool SSLContext::setAdvertisedNextProtocols(
497     const std::list<std::string>& protocols, NextProtocolType protocolType) {
498   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
499 }
500
501 bool SSLContext::setRandomizedAdvertisedNextProtocols(
502     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
503   unsetNextProtocols();
504   if (items.size() == 0) {
505     return false;
506   }
507   int total_weight = 0;
508   for (const auto &item : items) {
509     if (item.protocols.size() == 0) {
510       continue;
511     }
512     AdvertisedNextProtocolsItem advertised_item;
513     advertised_item.length = 0;
514     for (const auto& proto : item.protocols) {
515       ++advertised_item.length;
516       auto protoLength = proto.length();
517       if (protoLength >= 256) {
518         deleteNextProtocolsStrings();
519         return false;
520       }
521       advertised_item.length += unsigned(protoLength);
522     }
523     advertised_item.protocols = new unsigned char[advertised_item.length];
524     if (!advertised_item.protocols) {
525       throw std::runtime_error("alloc failure");
526     }
527     unsigned char* dst = advertised_item.protocols;
528     for (auto& proto : item.protocols) {
529       uint8_t protoLength = uint8_t(proto.length());
530       *dst++ = (unsigned char)protoLength;
531       memcpy(dst, proto.data(), protoLength);
532       dst += protoLength;
533     }
534     total_weight += item.weight;
535     advertisedNextProtocols_.push_back(advertised_item);
536     advertisedNextProtocolWeights_.push_back(item.weight);
537   }
538   if (total_weight == 0) {
539     deleteNextProtocolsStrings();
540     return false;
541   }
542   nextProtocolDistribution_ =
543       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
544                                    advertisedNextProtocolWeights_.end());
545   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
546     SSL_CTX_set_next_protos_advertised_cb(
547         ctx_, advertisedNextProtocolCallback, this);
548     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
549   }
550 #if FOLLY_OPENSSL_HAS_ALPN
551   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
552     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
553     // Client cannot really use randomized alpn
554     SSL_CTX_set_alpn_protos(ctx_,
555                             advertisedNextProtocols_[0].protocols,
556                             advertisedNextProtocols_[0].length);
557   }
558 #endif
559   return true;
560 }
561
562 void SSLContext::deleteNextProtocolsStrings() {
563   for (auto protocols : advertisedNextProtocols_) {
564     delete[] protocols.protocols;
565   }
566   advertisedNextProtocols_.clear();
567   advertisedNextProtocolWeights_.clear();
568 }
569
570 void SSLContext::unsetNextProtocols() {
571   deleteNextProtocolsStrings();
572   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
573   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
574 #if FOLLY_OPENSSL_HAS_ALPN
575   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
576   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
577 #endif
578 }
579
580 size_t SSLContext::pickNextProtocols() {
581   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
582   auto rng = ThreadLocalPRNG();
583   return size_t(nextProtocolDistribution_(rng));
584 }
585
586 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
587       const unsigned char** out, unsigned int* outlen, void* data) {
588   SSLContext* context = (SSLContext*)data;
589   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
590     *out = nullptr;
591     *outlen = 0;
592   } else if (context->advertisedNextProtocols_.size() == 1) {
593     *out = context->advertisedNextProtocols_[0].protocols;
594     *outlen = context->advertisedNextProtocols_[0].length;
595   } else {
596     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
597           sNextProtocolsExDataIndex_));
598     if (selected_index) {
599       --selected_index;
600       *out = context->advertisedNextProtocols_[selected_index].protocols;
601       *outlen = context->advertisedNextProtocols_[selected_index].length;
602     } else {
603       auto i = context->pickNextProtocols();
604       uintptr_t selected = i + 1;
605       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
606       *out = context->advertisedNextProtocols_[i].protocols;
607       *outlen = context->advertisedNextProtocols_[i].length;
608     }
609   }
610   return SSL_TLSEXT_ERR_OK;
611 }
612
613 int SSLContext::selectNextProtocolCallback(SSL* ssl,
614                                            unsigned char** out,
615                                            unsigned char* outlen,
616                                            const unsigned char* server,
617                                            unsigned int server_len,
618                                            void* data) {
619   (void)ssl; // Make -Wunused-parameters happy
620   SSLContext* ctx = (SSLContext*)data;
621   if (ctx->advertisedNextProtocols_.size() > 1) {
622     VLOG(3) << "SSLContext::selectNextProcolCallback() "
623             << "client should be deterministic in selecting protocols.";
624   }
625
626   unsigned char* client = nullptr;
627   unsigned int client_len = 0;
628   bool filtered = false;
629   auto cpf = ctx->getClientProtocolFilterCallback();
630   if (cpf) {
631     filtered = (*cpf)(&client, &client_len, server, server_len);
632   }
633
634   if (!filtered) {
635     if (ctx->advertisedNextProtocols_.empty()) {
636       client = (unsigned char *) "";
637       client_len = 0;
638     } else {
639       client = ctx->advertisedNextProtocols_[0].protocols;
640       client_len = ctx->advertisedNextProtocols_[0].length;
641     }
642   }
643
644   int retval = SSL_select_next_proto(out, outlen, server, server_len,
645                                      client, client_len);
646   if (retval != OPENSSL_NPN_NEGOTIATED) {
647     VLOG(3) << "SSLContext::selectNextProcolCallback() "
648             << "unable to pick a next protocol.";
649   }
650   return SSL_TLSEXT_ERR_OK;
651 }
652 #endif // OPENSSL_NPN_NEGOTIATED
653
654 SSL* SSLContext::createSSL() const {
655   SSL* ssl = SSL_new(ctx_);
656   if (ssl == nullptr) {
657     throw std::runtime_error("SSL_new: " + getErrors());
658   }
659   return ssl;
660 }
661
662 void SSLContext::setSessionCacheContext(const std::string& context) {
663   SSL_CTX_set_session_id_context(
664       ctx_,
665       reinterpret_cast<const unsigned char*>(context.data()),
666       std::min<unsigned int>(
667           static_cast<unsigned int>(context.length()),
668           SSL_MAX_SSL_SESSION_ID_LENGTH));
669 }
670
671 /**
672  * Match a name with a pattern. The pattern may include wildcard. A single
673  * wildcard "*" can match up to one component in the domain name.
674  *
675  * @param  host    Host name, typically the name of the remote host
676  * @param  pattern Name retrieved from certificate
677  * @param  size    Size of "pattern"
678  * @return True, if "host" matches "pattern". False otherwise.
679  */
680 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
681   bool match = false;
682   int i = 0, j = 0;
683   while (i < size && host[j] != '\0') {
684     if (toupper(pattern[i]) == toupper(host[j])) {
685       i++;
686       j++;
687       continue;
688     }
689     if (pattern[i] == '*') {
690       while (host[j] != '.' && host[j] != '\0') {
691         j++;
692       }
693       i++;
694       continue;
695     }
696     break;
697   }
698   if (i == size && host[j] == '\0') {
699     match = true;
700   }
701   return match;
702 }
703
704 int SSLContext::passwordCallback(char* password,
705                                  int size,
706                                  int,
707                                  void* data) {
708   SSLContext* context = (SSLContext*)data;
709   if (context == nullptr || context->passwordCollector() == nullptr) {
710     return 0;
711   }
712   std::string userPassword;
713   // call user defined password collector to get password
714   context->passwordCollector()->getPassword(userPassword, size);
715   auto length = int(userPassword.size());
716   if (length > size) {
717     length = size;
718   }
719   strncpy(password, userPassword.c_str(), size_t(length));
720   return length;
721 }
722
723 struct SSLLock {
724   explicit SSLLock(
725     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
726       lockType(inLockType) {
727   }
728
729   void lock() {
730     if (lockType == SSLContext::LOCK_MUTEX) {
731       mutex.lock();
732     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
733       spinLock.lock();
734     }
735     // lockType == LOCK_NONE, no-op
736   }
737
738   void unlock() {
739     if (lockType == SSLContext::LOCK_MUTEX) {
740       mutex.unlock();
741     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
742       spinLock.unlock();
743     }
744     // lockType == LOCK_NONE, no-op
745   }
746
747   SSLContext::SSLLockType lockType;
748   folly::SpinLock spinLock{};
749   std::mutex mutex;
750 };
751
752 // Statics are unsafe in environments that call exit().
753 // If one thread calls exit() while another thread is
754 // references a member of SSLContext, bad things can happen.
755 // SSLContext runs in such environments.
756 // Instead of declaring a static member we "new" the static
757 // member so that it won't be destructed on exit().
758 static std::unique_ptr<SSLLock[]>& locks() {
759   static auto locksInst = new std::unique_ptr<SSLLock[]>();
760   return *locksInst;
761 }
762
763 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
764   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
765   return *lockTypesInst;
766 }
767
768 static void callbackLocking(int mode, int n, const char*, int) {
769   if (mode & CRYPTO_LOCK) {
770     locks()[size_t(n)].lock();
771   } else {
772     locks()[size_t(n)].unlock();
773   }
774 }
775
776 static unsigned long callbackThreadID() {
777   return static_cast<unsigned long>(
778 #ifdef __APPLE__
779     pthread_mach_thread_np(pthread_self())
780 #elif _MSC_VER
781     pthread_getw32threadid_np(pthread_self())
782 #else
783     pthread_self()
784 #endif
785   );
786 }
787
788 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
789   return new CRYPTO_dynlock_value;
790 }
791
792 static void dyn_lock(int mode,
793                      struct CRYPTO_dynlock_value* lock,
794                      const char*, int) {
795   if (lock != nullptr) {
796     if (mode & CRYPTO_LOCK) {
797       lock->mutex.lock();
798     } else {
799       lock->mutex.unlock();
800     }
801   }
802 }
803
804 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
805   delete lock;
806 }
807
808 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
809   lockTypes() = inLockTypes;
810 }
811
812 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
813 void SSLContext::enableFalseStart() {
814   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
815 }
816 #endif
817
818 void SSLContext::markInitialized() {
819   std::lock_guard<std::mutex> g(initMutex());
820   initialized_ = true;
821 }
822
823 void SSLContext::initializeOpenSSL() {
824   std::lock_guard<std::mutex> g(initMutex());
825   initializeOpenSSLLocked();
826 }
827
828 void SSLContext::initializeOpenSSLLocked() {
829   if (initialized_) {
830     return;
831   }
832   SSL_library_init();
833   SSL_load_error_strings();
834   ERR_load_crypto_strings();
835   // static locking
836   locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
837   for (auto it: lockTypes()) {
838     locks()[size_t(it.first)].lockType = it.second;
839   }
840   CRYPTO_set_id_callback(callbackThreadID);
841   CRYPTO_set_locking_callback(callbackLocking);
842   // dynamic locking
843   CRYPTO_set_dynlock_create_callback(dyn_create);
844   CRYPTO_set_dynlock_lock_callback(dyn_lock);
845   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
846   randomize();
847 #ifdef OPENSSL_NPN_NEGOTIATED
848   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
849       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
850 #endif
851   initialized_ = true;
852 }
853
854 void SSLContext::cleanupOpenSSL() {
855   std::lock_guard<std::mutex> g(initMutex());
856   cleanupOpenSSLLocked();
857 }
858
859 void SSLContext::cleanupOpenSSLLocked() {
860   if (!initialized_) {
861     return;
862   }
863
864   CRYPTO_set_id_callback(nullptr);
865   CRYPTO_set_locking_callback(nullptr);
866   CRYPTO_set_dynlock_create_callback(nullptr);
867   CRYPTO_set_dynlock_lock_callback(nullptr);
868   CRYPTO_set_dynlock_destroy_callback(nullptr);
869   CRYPTO_cleanup_all_ex_data();
870   ERR_free_strings();
871   EVP_cleanup();
872   ERR_clear_error();
873   locks().reset();
874   initialized_ = false;
875 }
876
877 void SSLContext::setOptions(long options) {
878   long newOpt = SSL_CTX_set_options(ctx_, options);
879   if ((newOpt & options) != options) {
880     throw std::runtime_error("SSL_CTX_set_options failed");
881   }
882 }
883
884 std::string SSLContext::getErrors(int errnoCopy) {
885   std::string errors;
886   unsigned long  errorCode;
887   char   message[256];
888
889   errors.reserve(512);
890   while ((errorCode = ERR_get_error()) != 0) {
891     if (!errors.empty()) {
892       errors += "; ";
893     }
894     const char* reason = ERR_reason_error_string(errorCode);
895     if (reason == nullptr) {
896       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
897       reason = message;
898     }
899     errors += reason;
900   }
901   if (errors.empty()) {
902     errors = "error code: " + folly::to<std::string>(errnoCopy);
903   }
904   return errors;
905 }
906
907 std::ostream&
908 operator<<(std::ostream& os, const PasswordCollector& collector) {
909   os << collector.describe();
910   return os;
911 }
912
913 } // folly