Move OpenSSL locking code out of SSLContext
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ThreadId.h>
25 #include <folly/ssl/Init.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30 namespace folly {
31 //
32 // For OpenSSL portability API
33 using namespace folly::ssl;
34
35 // SSLContext implementation
36 SSLContext::SSLContext(SSLVersion version) {
37   folly::ssl::init();
38
39   ctx_ = SSL_CTX_new(SSLv23_method());
40   if (ctx_ == nullptr) {
41     throw std::runtime_error("SSL_CTX_new: " + getErrors());
42   }
43
44   int opt = 0;
45   switch (version) {
46     case TLSv1:
47       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
48       break;
49     case SSLv3:
50       opt = SSL_OP_NO_SSLv2;
51       break;
52     default:
53       // do nothing
54       break;
55   }
56   int newOpt = SSL_CTX_set_options(ctx_, opt);
57   DCHECK((newOpt & opt) == opt);
58
59   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
60
61   checkPeerName_ = false;
62
63   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
64
65 #if FOLLY_OPENSSL_HAS_SNI
66   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
67   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
68 #endif
69 }
70
71 SSLContext::~SSLContext() {
72   if (ctx_ != nullptr) {
73     SSL_CTX_free(ctx_);
74     ctx_ = nullptr;
75   }
76
77 #ifdef OPENSSL_NPN_NEGOTIATED
78   deleteNextProtocolsStrings();
79 #endif
80 }
81
82 void SSLContext::ciphers(const std::string& ciphers) {
83   setCiphersOrThrow(ciphers);
84 }
85
86 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
87   if (ciphers.size() == 0) {
88     return;
89   }
90   std::string opensslCipherList;
91   join(":", ciphers, opensslCipherList);
92   setCiphersOrThrow(opensslCipherList);
93 }
94
95 void SSLContext::setSignatureAlgorithms(
96     const std::vector<std::string>& sigalgs) {
97   if (sigalgs.size() == 0) {
98     return;
99   }
100 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
101   std::string opensslSigAlgsList;
102   join(":", sigalgs, opensslSigAlgsList);
103   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
104   if (rc == 0) {
105     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
106   }
107 #endif
108 }
109
110 void SSLContext::setClientECCurvesList(
111     const std::vector<std::string>& ecCurves) {
112   if (ecCurves.size() == 0) {
113     return;
114   }
115 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
116   std::string ecCurvesList;
117   join(":", ecCurves, ecCurvesList);
118   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
119   if (rc == 0) {
120     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
121   }
122 #endif
123 }
124
125 void SSLContext::setServerECCurve(const std::string& curveName) {
126 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
127   EC_KEY* ecdh = nullptr;
128   int nid;
129
130   /*
131    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
132    * from RFC 4492 section 5.1.1, or explicitly described curves over
133    * binary fields. OpenSSL only supports the "named curves", which provide
134    * maximum interoperability.
135    */
136
137   nid = OBJ_sn2nid(curveName.c_str());
138   if (nid == 0) {
139     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
140   }
141   ecdh = EC_KEY_new_by_curve_name(nid);
142   if (ecdh == nullptr) {
143     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
144   }
145
146   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
147   EC_KEY_free(ecdh);
148 #else
149   throw std::runtime_error("Elliptic curve encryption not allowed");
150 #endif
151 }
152
153 void SSLContext::setX509VerifyParam(
154     const ssl::X509VerifyParam& x509VerifyParam) {
155   if (!x509VerifyParam) {
156     return;
157   }
158   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
159     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
160   }
161 }
162
163 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
164   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
165   if (rc == 0) {
166     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
167   }
168   providedCiphersString_ = ciphers;
169 }
170
171 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
172     verifyPeer) {
173   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
174   verifyPeer_ = verifyPeer;
175 }
176
177 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
178     verifyPeer) {
179   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
180   int mode = SSL_VERIFY_NONE;
181   switch(verifyPeer) {
182     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
183     // break;
184
185     case SSLVerifyPeerEnum::VERIFY:
186       mode = SSL_VERIFY_PEER;
187       break;
188
189     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
190       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
191       break;
192
193     case SSLVerifyPeerEnum::NO_VERIFY:
194       mode = SSL_VERIFY_NONE;
195       break;
196
197     default:
198       break;
199   }
200   return mode;
201 }
202
203 int SSLContext::getVerificationMode() {
204   return getVerificationMode(verifyPeer_);
205 }
206
207 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
208                               const std::string& peerName) {
209   int mode;
210   if (checkPeerCert) {
211     mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
212         SSL_VERIFY_CLIENT_ONCE;
213     checkPeerName_ = checkPeerName;
214     peerFixedName_ = peerName;
215   } else {
216     mode = SSL_VERIFY_NONE;
217     checkPeerName_ = false; // can't check name without cert!
218     peerFixedName_.clear();
219   }
220   SSL_CTX_set_verify(ctx_, mode, nullptr);
221 }
222
223 void SSLContext::loadCertificate(const char* path, const char* format) {
224   if (path == nullptr || format == nullptr) {
225     throw std::invalid_argument(
226          "loadCertificateChain: either <path> or <format> is nullptr");
227   }
228   if (strcmp(format, "PEM") == 0) {
229     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
230       int errnoCopy = errno;
231       std::string reason("SSL_CTX_use_certificate_chain_file: ");
232       reason.append(path);
233       reason.append(": ");
234       reason.append(getErrors(errnoCopy));
235       throw std::runtime_error(reason);
236     }
237   } else {
238     throw std::runtime_error(
239         "Unsupported certificate format: " + std::string(format));
240   }
241 }
242
243 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
244   if (cert.data() == nullptr) {
245     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
246   }
247
248   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
249   if (bio == nullptr) {
250     throw std::runtime_error("BIO_new: " + getErrors());
251   }
252
253   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
254   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
255     throw std::runtime_error("BIO_write: " + getErrors());
256   }
257
258   ssl::X509UniquePtr x509(
259       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
260   if (x509 == nullptr) {
261     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
262   }
263
264   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
265     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
266   }
267 }
268
269 void SSLContext::loadPrivateKey(const char* path, const char* format) {
270   if (path == nullptr || format == nullptr) {
271     throw std::invalid_argument(
272         "loadPrivateKey: either <path> or <format> is nullptr");
273   }
274   if (strcmp(format, "PEM") == 0) {
275     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
276       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
277     }
278   } else {
279     throw std::runtime_error(
280         "Unsupported private key format: " + std::string(format));
281   }
282 }
283
284 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
285   if (pkey.data() == nullptr) {
286     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
287   }
288
289   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
290   if (bio == nullptr) {
291     throw std::runtime_error("BIO_new: " + getErrors());
292   }
293
294   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
295   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
296     throw std::runtime_error("BIO_write: " + getErrors());
297   }
298
299   ssl::EvpPkeyUniquePtr key(
300       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
301   if (key == nullptr) {
302     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
303   }
304
305   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
306     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
307   }
308 }
309
310 void SSLContext::loadTrustedCertificates(const char* path) {
311   if (path == nullptr) {
312     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
313   }
314   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
315     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
316   }
317   ERR_clear_error();
318 }
319
320 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
321   SSL_CTX_set_cert_store(ctx_, store);
322 }
323
324 void SSLContext::loadClientCAList(const char* path) {
325   auto clientCAs = SSL_load_client_CA_file(path);
326   if (clientCAs == nullptr) {
327     LOG(ERROR) << "Unable to load ca file: " << path;
328     return;
329   }
330   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
331 }
332
333 void SSLContext::passwordCollector(
334     std::shared_ptr<PasswordCollector> collector) {
335   if (collector == nullptr) {
336     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
337     return;
338   }
339   collector_ = collector;
340   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
341   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
342 }
343
344 #if FOLLY_OPENSSL_HAS_SNI
345
346 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
347   serverNameCb_ = cb;
348 }
349
350 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
351   clientHelloCbs_.push_back(cb);
352 }
353
354 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
355   SSLContext* context = (SSLContext*)data;
356
357   if (context == nullptr) {
358     return SSL_TLSEXT_ERR_NOACK;
359   }
360
361   for (auto& cb : context->clientHelloCbs_) {
362     // Generic callbacks to happen after we receive the Client Hello.
363     // For example, we use one to switch which cipher we use depending
364     // on the user's TLS version.  Because the primary purpose of
365     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
366     // are side-uses, we ignore any possible failures other than just logging
367     // them.
368     cb(ssl);
369   }
370
371   if (!context->serverNameCb_) {
372     return SSL_TLSEXT_ERR_NOACK;
373   }
374
375   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
376   switch (ret) {
377     case SERVER_NAME_FOUND:
378       return SSL_TLSEXT_ERR_OK;
379     case SERVER_NAME_NOT_FOUND:
380       return SSL_TLSEXT_ERR_NOACK;
381     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
382       *al = TLS1_AD_UNRECOGNIZED_NAME;
383       return SSL_TLSEXT_ERR_ALERT_FATAL;
384     default:
385       CHECK(false);
386   }
387
388   return SSL_TLSEXT_ERR_NOACK;
389 }
390
391 void SSLContext::switchCiphersIfTLS11(
392     SSL* ssl,
393     const std::string& tls11CipherString,
394     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
395   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
396       << "Shouldn't call if empty ciphers / alt ciphers";
397
398   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
399     // We only do this for TLS v 1.1 and later
400     return;
401   }
402
403   const std::string* ciphers = &tls11CipherString;
404   if (!tls11AltCipherlist.empty()) {
405     if (!cipherListPicker_) {
406       std::vector<int> weights;
407       std::for_each(
408           tls11AltCipherlist.begin(),
409           tls11AltCipherlist.end(),
410           [&](const std::pair<std::string, int>& e) {
411             weights.push_back(e.second);
412           });
413       cipherListPicker_.reset(
414           new std::discrete_distribution<int>(weights.begin(), weights.end()));
415     }
416     auto rng = ThreadLocalPRNG();
417     auto index = (*cipherListPicker_)(rng);
418     if ((size_t)index >= tls11AltCipherlist.size()) {
419       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
420                  << ", but tls11AltCipherlist is of length "
421                  << tls11AltCipherlist.size();
422     } else {
423       ciphers = &tls11AltCipherlist[size_t(index)].first;
424     }
425   }
426
427   // Prefer AES for TLS versions 1.1 and later since these are not
428   // vulnerable to BEAST attacks on AES.  Note that we're setting the
429   // cipher list on the SSL object, not the SSL_CTX object, so it will
430   // only last for this request.
431   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
432   if ((rc == 0) || ERR_peek_error() != 0) {
433     // This shouldn't happen since we checked for this when proxygen
434     // started up.
435     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
436     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
437   }
438 }
439 #endif // FOLLY_OPENSSL_HAS_SNI
440
441 #if FOLLY_OPENSSL_HAS_ALPN
442 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
443                                    const unsigned char** out,
444                                    unsigned char* outlen,
445                                    const unsigned char* in,
446                                    unsigned int inlen,
447                                    void* data) {
448   SSLContext* context = (SSLContext*)data;
449   CHECK(context);
450   if (context->advertisedNextProtocols_.empty()) {
451     *out = nullptr;
452     *outlen = 0;
453   } else {
454     auto i = context->pickNextProtocols();
455     const auto& item = context->advertisedNextProtocols_[i];
456     if (SSL_select_next_proto((unsigned char**)out,
457                               outlen,
458                               item.protocols,
459                               item.length,
460                               in,
461                               inlen) != OPENSSL_NPN_NEGOTIATED) {
462       return SSL_TLSEXT_ERR_NOACK;
463     }
464   }
465   return SSL_TLSEXT_ERR_OK;
466 }
467 #endif // FOLLY_OPENSSL_HAS_ALPN
468
469 #ifdef OPENSSL_NPN_NEGOTIATED
470
471 bool SSLContext::setAdvertisedNextProtocols(
472     const std::list<std::string>& protocols, NextProtocolType protocolType) {
473   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
474 }
475
476 bool SSLContext::setRandomizedAdvertisedNextProtocols(
477     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
478   unsetNextProtocols();
479   if (items.size() == 0) {
480     return false;
481   }
482   int total_weight = 0;
483   for (const auto &item : items) {
484     if (item.protocols.size() == 0) {
485       continue;
486     }
487     AdvertisedNextProtocolsItem advertised_item;
488     advertised_item.length = 0;
489     for (const auto& proto : item.protocols) {
490       ++advertised_item.length;
491       auto protoLength = proto.length();
492       if (protoLength >= 256) {
493         deleteNextProtocolsStrings();
494         return false;
495       }
496       advertised_item.length += unsigned(protoLength);
497     }
498     advertised_item.protocols = new unsigned char[advertised_item.length];
499     if (!advertised_item.protocols) {
500       throw std::runtime_error("alloc failure");
501     }
502     unsigned char* dst = advertised_item.protocols;
503     for (auto& proto : item.protocols) {
504       uint8_t protoLength = uint8_t(proto.length());
505       *dst++ = (unsigned char)protoLength;
506       memcpy(dst, proto.data(), protoLength);
507       dst += protoLength;
508     }
509     total_weight += item.weight;
510     advertisedNextProtocols_.push_back(advertised_item);
511     advertisedNextProtocolWeights_.push_back(item.weight);
512   }
513   if (total_weight == 0) {
514     deleteNextProtocolsStrings();
515     return false;
516   }
517   nextProtocolDistribution_ =
518       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
519                                    advertisedNextProtocolWeights_.end());
520   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
521     SSL_CTX_set_next_protos_advertised_cb(
522         ctx_, advertisedNextProtocolCallback, this);
523     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
524   }
525 #if FOLLY_OPENSSL_HAS_ALPN
526   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
527     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
528     // Client cannot really use randomized alpn
529     SSL_CTX_set_alpn_protos(ctx_,
530                             advertisedNextProtocols_[0].protocols,
531                             advertisedNextProtocols_[0].length);
532   }
533 #endif
534   return true;
535 }
536
537 void SSLContext::deleteNextProtocolsStrings() {
538   for (auto protocols : advertisedNextProtocols_) {
539     delete[] protocols.protocols;
540   }
541   advertisedNextProtocols_.clear();
542   advertisedNextProtocolWeights_.clear();
543 }
544
545 void SSLContext::unsetNextProtocols() {
546   deleteNextProtocolsStrings();
547   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
548   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
549 #if FOLLY_OPENSSL_HAS_ALPN
550   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
551   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
552 #endif
553 }
554
555 size_t SSLContext::pickNextProtocols() {
556   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
557   auto rng = ThreadLocalPRNG();
558   return size_t(nextProtocolDistribution_(rng));
559 }
560
561 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
562       const unsigned char** out, unsigned int* outlen, void* data) {
563   static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
564       0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
565
566   SSLContext* context = (SSLContext*)data;
567   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
568     *out = nullptr;
569     *outlen = 0;
570   } else if (context->advertisedNextProtocols_.size() == 1) {
571     *out = context->advertisedNextProtocols_[0].protocols;
572     *outlen = context->advertisedNextProtocols_[0].length;
573   } else {
574     uintptr_t selected_index = reinterpret_cast<uintptr_t>(
575         SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
576     if (selected_index) {
577       --selected_index;
578       *out = context->advertisedNextProtocols_[selected_index].protocols;
579       *outlen = context->advertisedNextProtocols_[selected_index].length;
580     } else {
581       auto i = context->pickNextProtocols();
582       uintptr_t selected = i + 1;
583       SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
584       *out = context->advertisedNextProtocols_[i].protocols;
585       *outlen = context->advertisedNextProtocols_[i].length;
586     }
587   }
588   return SSL_TLSEXT_ERR_OK;
589 }
590
591 int SSLContext::selectNextProtocolCallback(SSL* ssl,
592                                            unsigned char** out,
593                                            unsigned char* outlen,
594                                            const unsigned char* server,
595                                            unsigned int server_len,
596                                            void* data) {
597   (void)ssl; // Make -Wunused-parameters happy
598   SSLContext* ctx = (SSLContext*)data;
599   if (ctx->advertisedNextProtocols_.size() > 1) {
600     VLOG(3) << "SSLContext::selectNextProcolCallback() "
601             << "client should be deterministic in selecting protocols.";
602   }
603
604   unsigned char* client = nullptr;
605   unsigned int client_len = 0;
606   bool filtered = false;
607   auto cpf = ctx->getClientProtocolFilterCallback();
608   if (cpf) {
609     filtered = (*cpf)(&client, &client_len, server, server_len);
610   }
611
612   if (!filtered) {
613     if (ctx->advertisedNextProtocols_.empty()) {
614       client = (unsigned char *) "";
615       client_len = 0;
616     } else {
617       client = ctx->advertisedNextProtocols_[0].protocols;
618       client_len = ctx->advertisedNextProtocols_[0].length;
619     }
620   }
621
622   int retval = SSL_select_next_proto(out, outlen, server, server_len,
623                                      client, client_len);
624   if (retval != OPENSSL_NPN_NEGOTIATED) {
625     VLOG(3) << "SSLContext::selectNextProcolCallback() "
626             << "unable to pick a next protocol.";
627   }
628   return SSL_TLSEXT_ERR_OK;
629 }
630 #endif // OPENSSL_NPN_NEGOTIATED
631
632 SSL* SSLContext::createSSL() const {
633   SSL* ssl = SSL_new(ctx_);
634   if (ssl == nullptr) {
635     throw std::runtime_error("SSL_new: " + getErrors());
636   }
637   return ssl;
638 }
639
640 void SSLContext::setSessionCacheContext(const std::string& context) {
641   SSL_CTX_set_session_id_context(
642       ctx_,
643       reinterpret_cast<const unsigned char*>(context.data()),
644       std::min<unsigned int>(
645           static_cast<unsigned int>(context.length()),
646           SSL_MAX_SSL_SESSION_ID_LENGTH));
647 }
648
649 /**
650  * Match a name with a pattern. The pattern may include wildcard. A single
651  * wildcard "*" can match up to one component in the domain name.
652  *
653  * @param  host    Host name, typically the name of the remote host
654  * @param  pattern Name retrieved from certificate
655  * @param  size    Size of "pattern"
656  * @return True, if "host" matches "pattern". False otherwise.
657  */
658 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
659   bool match = false;
660   int i = 0, j = 0;
661   while (i < size && host[j] != '\0') {
662     if (toupper(pattern[i]) == toupper(host[j])) {
663       i++;
664       j++;
665       continue;
666     }
667     if (pattern[i] == '*') {
668       while (host[j] != '.' && host[j] != '\0') {
669         j++;
670       }
671       i++;
672       continue;
673     }
674     break;
675   }
676   if (i == size && host[j] == '\0') {
677     match = true;
678   }
679   return match;
680 }
681
682 int SSLContext::passwordCallback(char* password,
683                                  int size,
684                                  int,
685                                  void* data) {
686   SSLContext* context = (SSLContext*)data;
687   if (context == nullptr || context->passwordCollector() == nullptr) {
688     return 0;
689   }
690   std::string userPassword;
691   // call user defined password collector to get password
692   context->passwordCollector()->getPassword(userPassword, size);
693   auto length = int(userPassword.size());
694   if (length > size) {
695     length = size;
696   }
697   strncpy(password, userPassword.c_str(), size_t(length));
698   return length;
699 }
700
701 void SSLContext::setSSLLockTypes(std::map<int, LockType> inLockTypes) {
702   folly::ssl::setLockTypes(inLockTypes);
703 }
704
705 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
706 void SSLContext::enableFalseStart() {
707   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
708 }
709 #endif
710
711 void SSLContext::initializeOpenSSL() {
712   folly::ssl::init();
713 }
714
715 void SSLContext::setOptions(long options) {
716   long newOpt = SSL_CTX_set_options(ctx_, options);
717   if ((newOpt & options) != options) {
718     throw std::runtime_error("SSL_CTX_set_options failed");
719   }
720 }
721
722 std::string SSLContext::getErrors(int errnoCopy) {
723   std::string errors;
724   unsigned long  errorCode;
725   char   message[256];
726
727   errors.reserve(512);
728   while ((errorCode = ERR_get_error()) != 0) {
729     if (!errors.empty()) {
730       errors += "; ";
731     }
732     const char* reason = ERR_reason_error_string(errorCode);
733     if (reason == nullptr) {
734       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
735       reason = message;
736     }
737     errors += reason;
738   }
739   if (errors.empty()) {
740     errors = "error code: " + folly::to<std::string>(errnoCopy);
741   }
742   return errors;
743 }
744
745 std::ostream&
746 operator<<(std::ostream& os, const PasswordCollector& collector) {
747   os << collector.describe();
748   return os;
749 }
750
751 } // folly