Fix copyright lines
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2014-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ssl/Init.h>
25 #include <folly/system/ThreadId.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30 namespace folly {
31 //
32 // For OpenSSL portability API
33 using namespace folly::ssl;
34
35 // SSLContext implementation
36 SSLContext::SSLContext(SSLVersion version) {
37   folly::ssl::init();
38
39   ctx_ = SSL_CTX_new(SSLv23_method());
40   if (ctx_ == nullptr) {
41     throw std::runtime_error("SSL_CTX_new: " + getErrors());
42   }
43
44   int opt = 0;
45   switch (version) {
46     case TLSv1:
47       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
48       break;
49     case SSLv3:
50       opt = SSL_OP_NO_SSLv2;
51       break;
52     case TLSv1_2:
53       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
54           SSL_OP_NO_TLSv1_1;
55       break;
56     default:
57       // do nothing
58       break;
59   }
60   int newOpt = SSL_CTX_set_options(ctx_, opt);
61   DCHECK((newOpt & opt) == opt);
62
63   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
64
65   checkPeerName_ = false;
66
67   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
68
69 #if FOLLY_OPENSSL_HAS_SNI
70   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
71   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
72 #endif
73 }
74
75 SSLContext::~SSLContext() {
76   if (ctx_ != nullptr) {
77     SSL_CTX_free(ctx_);
78     ctx_ = nullptr;
79   }
80
81 #ifdef OPENSSL_NPN_NEGOTIATED
82   deleteNextProtocolsStrings();
83 #endif
84 }
85
86 void SSLContext::ciphers(const std::string& ciphers) {
87   setCiphersOrThrow(ciphers);
88 }
89
90 void SSLContext::setClientECCurvesList(
91     const std::vector<std::string>& ecCurves) {
92   if (ecCurves.size() == 0) {
93     return;
94   }
95 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
96   std::string ecCurvesList;
97   join(":", ecCurves, ecCurvesList);
98   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
99   if (rc == 0) {
100     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
101   }
102 #endif
103 }
104
105 void SSLContext::setServerECCurve(const std::string& curveName) {
106 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
107   EC_KEY* ecdh = nullptr;
108   int nid;
109
110   /*
111    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
112    * from RFC 4492 section 5.1.1, or explicitly described curves over
113    * binary fields. OpenSSL only supports the "named curves", which provide
114    * maximum interoperability.
115    */
116
117   nid = OBJ_sn2nid(curveName.c_str());
118   if (nid == 0) {
119     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
120   }
121   ecdh = EC_KEY_new_by_curve_name(nid);
122   if (ecdh == nullptr) {
123     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
124   }
125
126   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
127   EC_KEY_free(ecdh);
128 #else
129   throw std::runtime_error("Elliptic curve encryption not allowed");
130 #endif
131 }
132
133 void SSLContext::setX509VerifyParam(
134     const ssl::X509VerifyParam& x509VerifyParam) {
135   if (!x509VerifyParam) {
136     return;
137   }
138   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
139     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
140   }
141 }
142
143 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
144   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
145   if (rc == 0) {
146     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
147   }
148   providedCiphersString_ = ciphers;
149 }
150
151 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
152     verifyPeer) {
153   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
154   verifyPeer_ = verifyPeer;
155 }
156
157 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
158     verifyPeer) {
159   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
160   int mode = SSL_VERIFY_NONE;
161   switch(verifyPeer) {
162     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
163     // break;
164
165     case SSLVerifyPeerEnum::VERIFY:
166       mode = SSL_VERIFY_PEER;
167       break;
168
169     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
170       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
171       break;
172
173     case SSLVerifyPeerEnum::NO_VERIFY:
174       mode = SSL_VERIFY_NONE;
175       break;
176
177     default:
178       break;
179   }
180   return mode;
181 }
182
183 int SSLContext::getVerificationMode() {
184   return getVerificationMode(verifyPeer_);
185 }
186
187 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
188                               const std::string& peerName) {
189   int mode;
190   if (checkPeerCert) {
191     mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
192         SSL_VERIFY_CLIENT_ONCE;
193     checkPeerName_ = checkPeerName;
194     peerFixedName_ = peerName;
195   } else {
196     mode = SSL_VERIFY_NONE;
197     checkPeerName_ = false; // can't check name without cert!
198     peerFixedName_.clear();
199   }
200   SSL_CTX_set_verify(ctx_, mode, nullptr);
201 }
202
203 void SSLContext::loadCertificate(const char* path, const char* format) {
204   if (path == nullptr || format == nullptr) {
205     throw std::invalid_argument(
206          "loadCertificateChain: either <path> or <format> is nullptr");
207   }
208   if (strcmp(format, "PEM") == 0) {
209     if (SSL_CTX_use_certificate_chain_file(ctx_, path) != 1) {
210       int errnoCopy = errno;
211       std::string reason("SSL_CTX_use_certificate_chain_file: ");
212       reason.append(path);
213       reason.append(": ");
214       reason.append(getErrors(errnoCopy));
215       throw std::runtime_error(reason);
216     }
217   } else {
218     throw std::runtime_error(
219         "Unsupported certificate format: " + std::string(format));
220   }
221 }
222
223 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
224   if (cert.data() == nullptr) {
225     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
226   }
227
228   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
229   if (bio == nullptr) {
230     throw std::runtime_error("BIO_new: " + getErrors());
231   }
232
233   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
234   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
235     throw std::runtime_error("BIO_write: " + getErrors());
236   }
237
238   ssl::X509UniquePtr x509(
239       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
240   if (x509 == nullptr) {
241     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
242   }
243
244   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
245     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
246   }
247 }
248
249 void SSLContext::loadPrivateKey(const char* path, const char* format) {
250   if (path == nullptr || format == nullptr) {
251     throw std::invalid_argument(
252         "loadPrivateKey: either <path> or <format> is nullptr");
253   }
254   if (strcmp(format, "PEM") == 0) {
255     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
256       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
257     }
258   } else {
259     throw std::runtime_error(
260         "Unsupported private key format: " + std::string(format));
261   }
262 }
263
264 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
265   if (pkey.data() == nullptr) {
266     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
267   }
268
269   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
270   if (bio == nullptr) {
271     throw std::runtime_error("BIO_new: " + getErrors());
272   }
273
274   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
275   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
276     throw std::runtime_error("BIO_write: " + getErrors());
277   }
278
279   ssl::EvpPkeyUniquePtr key(
280       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
281   if (key == nullptr) {
282     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
283   }
284
285   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
286     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
287   }
288 }
289
290 void SSLContext::loadCertKeyPairFromBufferPEM(
291     folly::StringPiece cert,
292     folly::StringPiece pkey) {
293   loadCertificateFromBufferPEM(cert);
294   loadPrivateKeyFromBufferPEM(pkey);
295   if (!isCertKeyPairValid()) {
296     throw std::runtime_error("SSL certificate and private key do not match");
297   }
298 }
299
300 void SSLContext::loadCertKeyPairFromFiles(
301     const char* certPath,
302     const char* keyPath,
303     const char* certFormat,
304     const char* keyFormat) {
305   loadCertificate(certPath, certFormat);
306   loadPrivateKey(keyPath, keyFormat);
307   if (!isCertKeyPairValid()) {
308     throw std::runtime_error("SSL certificate and private key do not match");
309   }
310 }
311
312 bool SSLContext::isCertKeyPairValid() const {
313   return SSL_CTX_check_private_key(ctx_) == 1;
314 }
315
316 void SSLContext::loadTrustedCertificates(const char* path) {
317   if (path == nullptr) {
318     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
319   }
320   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
321     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
322   }
323   ERR_clear_error();
324 }
325
326 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
327   SSL_CTX_set_cert_store(ctx_, store);
328 }
329
330 void SSLContext::loadClientCAList(const char* path) {
331   auto clientCAs = SSL_load_client_CA_file(path);
332   if (clientCAs == nullptr) {
333     LOG(ERROR) << "Unable to load ca file: " << path;
334     return;
335   }
336   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
337 }
338
339 void SSLContext::passwordCollector(
340     std::shared_ptr<PasswordCollector> collector) {
341   if (collector == nullptr) {
342     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
343     return;
344   }
345   collector_ = collector;
346   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
347   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
348 }
349
350 #if FOLLY_OPENSSL_HAS_SNI
351
352 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
353   serverNameCb_ = cb;
354 }
355
356 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
357   clientHelloCbs_.push_back(cb);
358 }
359
360 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
361   SSLContext* context = (SSLContext*)data;
362
363   if (context == nullptr) {
364     return SSL_TLSEXT_ERR_NOACK;
365   }
366
367   for (auto& cb : context->clientHelloCbs_) {
368     // Generic callbacks to happen after we receive the Client Hello.
369     // For example, we use one to switch which cipher we use depending
370     // on the user's TLS version.  Because the primary purpose of
371     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
372     // are side-uses, we ignore any possible failures other than just logging
373     // them.
374     cb(ssl);
375   }
376
377   if (!context->serverNameCb_) {
378     return SSL_TLSEXT_ERR_NOACK;
379   }
380
381   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
382   switch (ret) {
383     case SERVER_NAME_FOUND:
384       return SSL_TLSEXT_ERR_OK;
385     case SERVER_NAME_NOT_FOUND:
386       return SSL_TLSEXT_ERR_NOACK;
387     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
388       *al = TLS1_AD_UNRECOGNIZED_NAME;
389       return SSL_TLSEXT_ERR_ALERT_FATAL;
390     default:
391       CHECK(false);
392   }
393
394   return SSL_TLSEXT_ERR_NOACK;
395 }
396 #endif // FOLLY_OPENSSL_HAS_SNI
397
398 #if FOLLY_OPENSSL_HAS_ALPN
399 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
400                                    const unsigned char** out,
401                                    unsigned char* outlen,
402                                    const unsigned char* in,
403                                    unsigned int inlen,
404                                    void* data) {
405   SSLContext* context = (SSLContext*)data;
406   CHECK(context);
407   if (context->advertisedNextProtocols_.empty()) {
408     *out = nullptr;
409     *outlen = 0;
410   } else {
411     auto i = context->pickNextProtocols();
412     const auto& item = context->advertisedNextProtocols_[i];
413     if (SSL_select_next_proto((unsigned char**)out,
414                               outlen,
415                               item.protocols,
416                               item.length,
417                               in,
418                               inlen) != OPENSSL_NPN_NEGOTIATED) {
419       return SSL_TLSEXT_ERR_NOACK;
420     }
421   }
422   return SSL_TLSEXT_ERR_OK;
423 }
424 #endif // FOLLY_OPENSSL_HAS_ALPN
425
426 #ifdef OPENSSL_NPN_NEGOTIATED
427
428 bool SSLContext::setAdvertisedNextProtocols(
429     const std::list<std::string>& protocols, NextProtocolType protocolType) {
430   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
431 }
432
433 bool SSLContext::setRandomizedAdvertisedNextProtocols(
434     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
435   unsetNextProtocols();
436   if (items.size() == 0) {
437     return false;
438   }
439   int total_weight = 0;
440   for (const auto &item : items) {
441     if (item.protocols.size() == 0) {
442       continue;
443     }
444     AdvertisedNextProtocolsItem advertised_item;
445     advertised_item.length = 0;
446     for (const auto& proto : item.protocols) {
447       ++advertised_item.length;
448       auto protoLength = proto.length();
449       if (protoLength >= 256) {
450         deleteNextProtocolsStrings();
451         return false;
452       }
453       advertised_item.length += unsigned(protoLength);
454     }
455     advertised_item.protocols = new unsigned char[advertised_item.length];
456     if (!advertised_item.protocols) {
457       throw std::runtime_error("alloc failure");
458     }
459     unsigned char* dst = advertised_item.protocols;
460     for (auto& proto : item.protocols) {
461       uint8_t protoLength = uint8_t(proto.length());
462       *dst++ = (unsigned char)protoLength;
463       memcpy(dst, proto.data(), protoLength);
464       dst += protoLength;
465     }
466     total_weight += item.weight;
467     advertisedNextProtocols_.push_back(advertised_item);
468     advertisedNextProtocolWeights_.push_back(item.weight);
469   }
470   if (total_weight == 0) {
471     deleteNextProtocolsStrings();
472     return false;
473   }
474   nextProtocolDistribution_ =
475       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
476                                    advertisedNextProtocolWeights_.end());
477   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
478     SSL_CTX_set_next_protos_advertised_cb(
479         ctx_, advertisedNextProtocolCallback, this);
480     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
481   }
482 #if FOLLY_OPENSSL_HAS_ALPN
483   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
484     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
485     // Client cannot really use randomized alpn
486     SSL_CTX_set_alpn_protos(ctx_,
487                             advertisedNextProtocols_[0].protocols,
488                             advertisedNextProtocols_[0].length);
489   }
490 #endif
491   return true;
492 }
493
494 void SSLContext::deleteNextProtocolsStrings() {
495   for (auto protocols : advertisedNextProtocols_) {
496     delete[] protocols.protocols;
497   }
498   advertisedNextProtocols_.clear();
499   advertisedNextProtocolWeights_.clear();
500 }
501
502 void SSLContext::unsetNextProtocols() {
503   deleteNextProtocolsStrings();
504   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
505   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
506 #if FOLLY_OPENSSL_HAS_ALPN
507   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
508   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
509 #endif
510 }
511
512 size_t SSLContext::pickNextProtocols() {
513   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
514   auto rng = ThreadLocalPRNG();
515   return size_t(nextProtocolDistribution_(rng));
516 }
517
518 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
519       const unsigned char** out, unsigned int* outlen, void* data) {
520   static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
521       0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
522
523   SSLContext* context = (SSLContext*)data;
524   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
525     *out = nullptr;
526     *outlen = 0;
527   } else if (context->advertisedNextProtocols_.size() == 1) {
528     *out = context->advertisedNextProtocols_[0].protocols;
529     *outlen = context->advertisedNextProtocols_[0].length;
530   } else {
531     uintptr_t selected_index = reinterpret_cast<uintptr_t>(
532         SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
533     if (selected_index) {
534       --selected_index;
535       *out = context->advertisedNextProtocols_[selected_index].protocols;
536       *outlen = context->advertisedNextProtocols_[selected_index].length;
537     } else {
538       auto i = context->pickNextProtocols();
539       uintptr_t selected = i + 1;
540       SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
541       *out = context->advertisedNextProtocols_[i].protocols;
542       *outlen = context->advertisedNextProtocols_[i].length;
543     }
544   }
545   return SSL_TLSEXT_ERR_OK;
546 }
547
548 int SSLContext::selectNextProtocolCallback(SSL* ssl,
549                                            unsigned char** out,
550                                            unsigned char* outlen,
551                                            const unsigned char* server,
552                                            unsigned int server_len,
553                                            void* data) {
554   (void)ssl; // Make -Wunused-parameters happy
555   SSLContext* ctx = (SSLContext*)data;
556   if (ctx->advertisedNextProtocols_.size() > 1) {
557     VLOG(3) << "SSLContext::selectNextProcolCallback() "
558             << "client should be deterministic in selecting protocols.";
559   }
560
561   unsigned char* client = nullptr;
562   unsigned int client_len = 0;
563   bool filtered = false;
564   auto cpf = ctx->getClientProtocolFilterCallback();
565   if (cpf) {
566     filtered = (*cpf)(&client, &client_len, server, server_len);
567   }
568
569   if (!filtered) {
570     if (ctx->advertisedNextProtocols_.empty()) {
571       client = (unsigned char *) "";
572       client_len = 0;
573     } else {
574       client = ctx->advertisedNextProtocols_[0].protocols;
575       client_len = ctx->advertisedNextProtocols_[0].length;
576     }
577   }
578
579   int retval = SSL_select_next_proto(out, outlen, server, server_len,
580                                      client, client_len);
581   if (retval != OPENSSL_NPN_NEGOTIATED) {
582     VLOG(3) << "SSLContext::selectNextProcolCallback() "
583             << "unable to pick a next protocol.";
584   }
585   return SSL_TLSEXT_ERR_OK;
586 }
587 #endif // OPENSSL_NPN_NEGOTIATED
588
589 SSL* SSLContext::createSSL() const {
590   SSL* ssl = SSL_new(ctx_);
591   if (ssl == nullptr) {
592     throw std::runtime_error("SSL_new: " + getErrors());
593   }
594   return ssl;
595 }
596
597 void SSLContext::setSessionCacheContext(const std::string& context) {
598   SSL_CTX_set_session_id_context(
599       ctx_,
600       reinterpret_cast<const unsigned char*>(context.data()),
601       std::min<unsigned int>(
602           static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH));
603 }
604
605 /**
606  * Match a name with a pattern. The pattern may include wildcard. A single
607  * wildcard "*" can match up to one component in the domain name.
608  *
609  * @param  host    Host name, typically the name of the remote host
610  * @param  pattern Name retrieved from certificate
611  * @param  size    Size of "pattern"
612  * @return True, if "host" matches "pattern". False otherwise.
613  */
614 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
615   bool match = false;
616   int i = 0, j = 0;
617   while (i < size && host[j] != '\0') {
618     if (toupper(pattern[i]) == toupper(host[j])) {
619       i++;
620       j++;
621       continue;
622     }
623     if (pattern[i] == '*') {
624       while (host[j] != '.' && host[j] != '\0') {
625         j++;
626       }
627       i++;
628       continue;
629     }
630     break;
631   }
632   if (i == size && host[j] == '\0') {
633     match = true;
634   }
635   return match;
636 }
637
638 int SSLContext::passwordCallback(char* password,
639                                  int size,
640                                  int,
641                                  void* data) {
642   SSLContext* context = (SSLContext*)data;
643   if (context == nullptr || context->passwordCollector() == nullptr) {
644     return 0;
645   }
646   std::string userPassword;
647   // call user defined password collector to get password
648   context->passwordCollector()->getPassword(userPassword, size);
649   auto const length = std::min(userPassword.size(), size_t(size));
650   std::memcpy(password, userPassword.data(), length);
651   return int(length);
652 }
653
654 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
655 void SSLContext::enableFalseStart() {
656   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
657 }
658 #endif
659
660 void SSLContext::initializeOpenSSL() {
661   folly::ssl::init();
662 }
663
664 void SSLContext::setOptions(long options) {
665   long newOpt = SSL_CTX_set_options(ctx_, options);
666   if ((newOpt & options) != options) {
667     throw std::runtime_error("SSL_CTX_set_options failed");
668   }
669 }
670
671 std::string SSLContext::getErrors(int errnoCopy) {
672   std::string errors;
673   unsigned long  errorCode;
674   char   message[256];
675
676   errors.reserve(512);
677   while ((errorCode = ERR_get_error()) != 0) {
678     if (!errors.empty()) {
679       errors += "; ";
680     }
681     const char* reason = ERR_reason_error_string(errorCode);
682     if (reason == nullptr) {
683       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
684       reason = message;
685     }
686     errors += reason;
687   }
688   if (errors.empty()) {
689     errors = "error code: " + folly::to<std::string>(errnoCopy);
690   }
691   return errors;
692 }
693
694 std::ostream&
695 operator<<(std::ostream& os, const PasswordCollector& collector) {
696   os << collector.describe();
697   return os;
698 }
699
700 } // namespace folly