Add TLS 1.2+ version for contexts
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ThreadId.h>
25 #include <folly/ssl/Init.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30 namespace folly {
31 //
32 // For OpenSSL portability API
33 using namespace folly::ssl;
34
35 // SSLContext implementation
36 SSLContext::SSLContext(SSLVersion version) {
37   folly::ssl::init();
38
39   ctx_ = SSL_CTX_new(SSLv23_method());
40   if (ctx_ == nullptr) {
41     throw std::runtime_error("SSL_CTX_new: " + getErrors());
42   }
43
44   int opt = 0;
45   switch (version) {
46     case TLSv1:
47       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
48       break;
49     case SSLv3:
50       opt = SSL_OP_NO_SSLv2;
51       break;
52     case TLSv1_2:
53       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
54           SSL_OP_NO_TLSv1_1;
55       break;
56     default:
57       // do nothing
58       break;
59   }
60   int newOpt = SSL_CTX_set_options(ctx_, opt);
61   DCHECK((newOpt & opt) == opt);
62
63   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
64
65   checkPeerName_ = false;
66
67   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
68
69 #if FOLLY_OPENSSL_HAS_SNI
70   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
71   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
72 #endif
73 }
74
75 SSLContext::~SSLContext() {
76   if (ctx_ != nullptr) {
77     SSL_CTX_free(ctx_);
78     ctx_ = nullptr;
79   }
80
81 #ifdef OPENSSL_NPN_NEGOTIATED
82   deleteNextProtocolsStrings();
83 #endif
84 }
85
86 void SSLContext::ciphers(const std::string& ciphers) {
87   setCiphersOrThrow(ciphers);
88 }
89
90 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
91   if (ciphers.size() == 0) {
92     return;
93   }
94   std::string opensslCipherList;
95   join(":", ciphers, opensslCipherList);
96   setCiphersOrThrow(opensslCipherList);
97 }
98
99 void SSLContext::setSignatureAlgorithms(
100     const std::vector<std::string>& sigalgs) {
101   if (sigalgs.size() == 0) {
102     return;
103   }
104 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
105   std::string opensslSigAlgsList;
106   join(":", sigalgs, opensslSigAlgsList);
107   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
108   if (rc == 0) {
109     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
110   }
111 #endif
112 }
113
114 void SSLContext::setClientECCurvesList(
115     const std::vector<std::string>& ecCurves) {
116   if (ecCurves.size() == 0) {
117     return;
118   }
119 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
120   std::string ecCurvesList;
121   join(":", ecCurves, ecCurvesList);
122   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
123   if (rc == 0) {
124     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
125   }
126 #endif
127 }
128
129 void SSLContext::setServerECCurve(const std::string& curveName) {
130 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
131   EC_KEY* ecdh = nullptr;
132   int nid;
133
134   /*
135    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
136    * from RFC 4492 section 5.1.1, or explicitly described curves over
137    * binary fields. OpenSSL only supports the "named curves", which provide
138    * maximum interoperability.
139    */
140
141   nid = OBJ_sn2nid(curveName.c_str());
142   if (nid == 0) {
143     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
144   }
145   ecdh = EC_KEY_new_by_curve_name(nid);
146   if (ecdh == nullptr) {
147     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
148   }
149
150   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
151   EC_KEY_free(ecdh);
152 #else
153   throw std::runtime_error("Elliptic curve encryption not allowed");
154 #endif
155 }
156
157 void SSLContext::setX509VerifyParam(
158     const ssl::X509VerifyParam& x509VerifyParam) {
159   if (!x509VerifyParam) {
160     return;
161   }
162   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
163     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
164   }
165 }
166
167 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
168   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
169   if (rc == 0) {
170     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
171   }
172   providedCiphersString_ = ciphers;
173 }
174
175 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
176     verifyPeer) {
177   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
178   verifyPeer_ = verifyPeer;
179 }
180
181 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
182     verifyPeer) {
183   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
184   int mode = SSL_VERIFY_NONE;
185   switch(verifyPeer) {
186     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
187     // break;
188
189     case SSLVerifyPeerEnum::VERIFY:
190       mode = SSL_VERIFY_PEER;
191       break;
192
193     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
194       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
195       break;
196
197     case SSLVerifyPeerEnum::NO_VERIFY:
198       mode = SSL_VERIFY_NONE;
199       break;
200
201     default:
202       break;
203   }
204   return mode;
205 }
206
207 int SSLContext::getVerificationMode() {
208   return getVerificationMode(verifyPeer_);
209 }
210
211 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
212                               const std::string& peerName) {
213   int mode;
214   if (checkPeerCert) {
215     mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
216         SSL_VERIFY_CLIENT_ONCE;
217     checkPeerName_ = checkPeerName;
218     peerFixedName_ = peerName;
219   } else {
220     mode = SSL_VERIFY_NONE;
221     checkPeerName_ = false; // can't check name without cert!
222     peerFixedName_.clear();
223   }
224   SSL_CTX_set_verify(ctx_, mode, nullptr);
225 }
226
227 void SSLContext::loadCertificate(const char* path, const char* format) {
228   if (path == nullptr || format == nullptr) {
229     throw std::invalid_argument(
230          "loadCertificateChain: either <path> or <format> is nullptr");
231   }
232   if (strcmp(format, "PEM") == 0) {
233     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
234       int errnoCopy = errno;
235       std::string reason("SSL_CTX_use_certificate_chain_file: ");
236       reason.append(path);
237       reason.append(": ");
238       reason.append(getErrors(errnoCopy));
239       throw std::runtime_error(reason);
240     }
241   } else {
242     throw std::runtime_error(
243         "Unsupported certificate format: " + std::string(format));
244   }
245 }
246
247 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
248   if (cert.data() == nullptr) {
249     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
250   }
251
252   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
253   if (bio == nullptr) {
254     throw std::runtime_error("BIO_new: " + getErrors());
255   }
256
257   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
258   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
259     throw std::runtime_error("BIO_write: " + getErrors());
260   }
261
262   ssl::X509UniquePtr x509(
263       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
264   if (x509 == nullptr) {
265     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
266   }
267
268   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
269     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
270   }
271 }
272
273 void SSLContext::loadPrivateKey(const char* path, const char* format) {
274   if (path == nullptr || format == nullptr) {
275     throw std::invalid_argument(
276         "loadPrivateKey: either <path> or <format> is nullptr");
277   }
278   if (strcmp(format, "PEM") == 0) {
279     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
280       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
281     }
282   } else {
283     throw std::runtime_error(
284         "Unsupported private key format: " + std::string(format));
285   }
286 }
287
288 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
289   if (pkey.data() == nullptr) {
290     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
291   }
292
293   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
294   if (bio == nullptr) {
295     throw std::runtime_error("BIO_new: " + getErrors());
296   }
297
298   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
299   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
300     throw std::runtime_error("BIO_write: " + getErrors());
301   }
302
303   ssl::EvpPkeyUniquePtr key(
304       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
305   if (key == nullptr) {
306     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
307   }
308
309   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
310     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
311   }
312 }
313
314 void SSLContext::loadTrustedCertificates(const char* path) {
315   if (path == nullptr) {
316     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
317   }
318   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
319     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
320   }
321   ERR_clear_error();
322 }
323
324 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
325   SSL_CTX_set_cert_store(ctx_, store);
326 }
327
328 void SSLContext::loadClientCAList(const char* path) {
329   auto clientCAs = SSL_load_client_CA_file(path);
330   if (clientCAs == nullptr) {
331     LOG(ERROR) << "Unable to load ca file: " << path;
332     return;
333   }
334   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
335 }
336
337 void SSLContext::passwordCollector(
338     std::shared_ptr<PasswordCollector> collector) {
339   if (collector == nullptr) {
340     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
341     return;
342   }
343   collector_ = collector;
344   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
345   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
346 }
347
348 #if FOLLY_OPENSSL_HAS_SNI
349
350 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
351   serverNameCb_ = cb;
352 }
353
354 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
355   clientHelloCbs_.push_back(cb);
356 }
357
358 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
359   SSLContext* context = (SSLContext*)data;
360
361   if (context == nullptr) {
362     return SSL_TLSEXT_ERR_NOACK;
363   }
364
365   for (auto& cb : context->clientHelloCbs_) {
366     // Generic callbacks to happen after we receive the Client Hello.
367     // For example, we use one to switch which cipher we use depending
368     // on the user's TLS version.  Because the primary purpose of
369     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
370     // are side-uses, we ignore any possible failures other than just logging
371     // them.
372     cb(ssl);
373   }
374
375   if (!context->serverNameCb_) {
376     return SSL_TLSEXT_ERR_NOACK;
377   }
378
379   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
380   switch (ret) {
381     case SERVER_NAME_FOUND:
382       return SSL_TLSEXT_ERR_OK;
383     case SERVER_NAME_NOT_FOUND:
384       return SSL_TLSEXT_ERR_NOACK;
385     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
386       *al = TLS1_AD_UNRECOGNIZED_NAME;
387       return SSL_TLSEXT_ERR_ALERT_FATAL;
388     default:
389       CHECK(false);
390   }
391
392   return SSL_TLSEXT_ERR_NOACK;
393 }
394 #endif // FOLLY_OPENSSL_HAS_SNI
395
396 #if FOLLY_OPENSSL_HAS_ALPN
397 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
398                                    const unsigned char** out,
399                                    unsigned char* outlen,
400                                    const unsigned char* in,
401                                    unsigned int inlen,
402                                    void* data) {
403   SSLContext* context = (SSLContext*)data;
404   CHECK(context);
405   if (context->advertisedNextProtocols_.empty()) {
406     *out = nullptr;
407     *outlen = 0;
408   } else {
409     auto i = context->pickNextProtocols();
410     const auto& item = context->advertisedNextProtocols_[i];
411     if (SSL_select_next_proto((unsigned char**)out,
412                               outlen,
413                               item.protocols,
414                               item.length,
415                               in,
416                               inlen) != OPENSSL_NPN_NEGOTIATED) {
417       return SSL_TLSEXT_ERR_NOACK;
418     }
419   }
420   return SSL_TLSEXT_ERR_OK;
421 }
422 #endif // FOLLY_OPENSSL_HAS_ALPN
423
424 #ifdef OPENSSL_NPN_NEGOTIATED
425
426 bool SSLContext::setAdvertisedNextProtocols(
427     const std::list<std::string>& protocols, NextProtocolType protocolType) {
428   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
429 }
430
431 bool SSLContext::setRandomizedAdvertisedNextProtocols(
432     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
433   unsetNextProtocols();
434   if (items.size() == 0) {
435     return false;
436   }
437   int total_weight = 0;
438   for (const auto &item : items) {
439     if (item.protocols.size() == 0) {
440       continue;
441     }
442     AdvertisedNextProtocolsItem advertised_item;
443     advertised_item.length = 0;
444     for (const auto& proto : item.protocols) {
445       ++advertised_item.length;
446       auto protoLength = proto.length();
447       if (protoLength >= 256) {
448         deleteNextProtocolsStrings();
449         return false;
450       }
451       advertised_item.length += unsigned(protoLength);
452     }
453     advertised_item.protocols = new unsigned char[advertised_item.length];
454     if (!advertised_item.protocols) {
455       throw std::runtime_error("alloc failure");
456     }
457     unsigned char* dst = advertised_item.protocols;
458     for (auto& proto : item.protocols) {
459       uint8_t protoLength = uint8_t(proto.length());
460       *dst++ = (unsigned char)protoLength;
461       memcpy(dst, proto.data(), protoLength);
462       dst += protoLength;
463     }
464     total_weight += item.weight;
465     advertisedNextProtocols_.push_back(advertised_item);
466     advertisedNextProtocolWeights_.push_back(item.weight);
467   }
468   if (total_weight == 0) {
469     deleteNextProtocolsStrings();
470     return false;
471   }
472   nextProtocolDistribution_ =
473       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
474                                    advertisedNextProtocolWeights_.end());
475   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
476     SSL_CTX_set_next_protos_advertised_cb(
477         ctx_, advertisedNextProtocolCallback, this);
478     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
479   }
480 #if FOLLY_OPENSSL_HAS_ALPN
481   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
482     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
483     // Client cannot really use randomized alpn
484     SSL_CTX_set_alpn_protos(ctx_,
485                             advertisedNextProtocols_[0].protocols,
486                             advertisedNextProtocols_[0].length);
487   }
488 #endif
489   return true;
490 }
491
492 void SSLContext::deleteNextProtocolsStrings() {
493   for (auto protocols : advertisedNextProtocols_) {
494     delete[] protocols.protocols;
495   }
496   advertisedNextProtocols_.clear();
497   advertisedNextProtocolWeights_.clear();
498 }
499
500 void SSLContext::unsetNextProtocols() {
501   deleteNextProtocolsStrings();
502   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
503   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
504 #if FOLLY_OPENSSL_HAS_ALPN
505   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
506   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
507 #endif
508 }
509
510 size_t SSLContext::pickNextProtocols() {
511   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
512   auto rng = ThreadLocalPRNG();
513   return size_t(nextProtocolDistribution_(rng));
514 }
515
516 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
517       const unsigned char** out, unsigned int* outlen, void* data) {
518   static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
519       0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
520
521   SSLContext* context = (SSLContext*)data;
522   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
523     *out = nullptr;
524     *outlen = 0;
525   } else if (context->advertisedNextProtocols_.size() == 1) {
526     *out = context->advertisedNextProtocols_[0].protocols;
527     *outlen = context->advertisedNextProtocols_[0].length;
528   } else {
529     uintptr_t selected_index = reinterpret_cast<uintptr_t>(
530         SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
531     if (selected_index) {
532       --selected_index;
533       *out = context->advertisedNextProtocols_[selected_index].protocols;
534       *outlen = context->advertisedNextProtocols_[selected_index].length;
535     } else {
536       auto i = context->pickNextProtocols();
537       uintptr_t selected = i + 1;
538       SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
539       *out = context->advertisedNextProtocols_[i].protocols;
540       *outlen = context->advertisedNextProtocols_[i].length;
541     }
542   }
543   return SSL_TLSEXT_ERR_OK;
544 }
545
546 int SSLContext::selectNextProtocolCallback(SSL* ssl,
547                                            unsigned char** out,
548                                            unsigned char* outlen,
549                                            const unsigned char* server,
550                                            unsigned int server_len,
551                                            void* data) {
552   (void)ssl; // Make -Wunused-parameters happy
553   SSLContext* ctx = (SSLContext*)data;
554   if (ctx->advertisedNextProtocols_.size() > 1) {
555     VLOG(3) << "SSLContext::selectNextProcolCallback() "
556             << "client should be deterministic in selecting protocols.";
557   }
558
559   unsigned char* client = nullptr;
560   unsigned int client_len = 0;
561   bool filtered = false;
562   auto cpf = ctx->getClientProtocolFilterCallback();
563   if (cpf) {
564     filtered = (*cpf)(&client, &client_len, server, server_len);
565   }
566
567   if (!filtered) {
568     if (ctx->advertisedNextProtocols_.empty()) {
569       client = (unsigned char *) "";
570       client_len = 0;
571     } else {
572       client = ctx->advertisedNextProtocols_[0].protocols;
573       client_len = ctx->advertisedNextProtocols_[0].length;
574     }
575   }
576
577   int retval = SSL_select_next_proto(out, outlen, server, server_len,
578                                      client, client_len);
579   if (retval != OPENSSL_NPN_NEGOTIATED) {
580     VLOG(3) << "SSLContext::selectNextProcolCallback() "
581             << "unable to pick a next protocol.";
582   }
583   return SSL_TLSEXT_ERR_OK;
584 }
585 #endif // OPENSSL_NPN_NEGOTIATED
586
587 SSL* SSLContext::createSSL() const {
588   SSL* ssl = SSL_new(ctx_);
589   if (ssl == nullptr) {
590     throw std::runtime_error("SSL_new: " + getErrors());
591   }
592   return ssl;
593 }
594
595 void SSLContext::setSessionCacheContext(const std::string& context) {
596   SSL_CTX_set_session_id_context(
597       ctx_,
598       reinterpret_cast<const unsigned char*>(context.data()),
599       std::min<unsigned int>(
600           static_cast<unsigned int>(context.length()),
601           SSL_MAX_SSL_SESSION_ID_LENGTH));
602 }
603
604 /**
605  * Match a name with a pattern. The pattern may include wildcard. A single
606  * wildcard "*" can match up to one component in the domain name.
607  *
608  * @param  host    Host name, typically the name of the remote host
609  * @param  pattern Name retrieved from certificate
610  * @param  size    Size of "pattern"
611  * @return True, if "host" matches "pattern". False otherwise.
612  */
613 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
614   bool match = false;
615   int i = 0, j = 0;
616   while (i < size && host[j] != '\0') {
617     if (toupper(pattern[i]) == toupper(host[j])) {
618       i++;
619       j++;
620       continue;
621     }
622     if (pattern[i] == '*') {
623       while (host[j] != '.' && host[j] != '\0') {
624         j++;
625       }
626       i++;
627       continue;
628     }
629     break;
630   }
631   if (i == size && host[j] == '\0') {
632     match = true;
633   }
634   return match;
635 }
636
637 int SSLContext::passwordCallback(char* password,
638                                  int size,
639                                  int,
640                                  void* data) {
641   SSLContext* context = (SSLContext*)data;
642   if (context == nullptr || context->passwordCollector() == nullptr) {
643     return 0;
644   }
645   std::string userPassword;
646   // call user defined password collector to get password
647   context->passwordCollector()->getPassword(userPassword, size);
648   auto const length = std::min(userPassword.size(), size_t(size));
649   std::memcpy(password, userPassword.data(), length);
650   return int(length);
651 }
652
653 void SSLContext::setSSLLockTypes(std::map<int, LockType> inLockTypes) {
654   folly::ssl::setLockTypes(inLockTypes);
655 }
656
657 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
658 void SSLContext::enableFalseStart() {
659   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
660 }
661 #endif
662
663 void SSLContext::initializeOpenSSL() {
664   folly::ssl::init();
665 }
666
667 void SSLContext::setOptions(long options) {
668   long newOpt = SSL_CTX_set_options(ctx_, options);
669   if ((newOpt & options) != options) {
670     throw std::runtime_error("SSL_CTX_set_options failed");
671   }
672 }
673
674 std::string SSLContext::getErrors(int errnoCopy) {
675   std::string errors;
676   unsigned long  errorCode;
677   char   message[256];
678
679   errors.reserve(512);
680   while ((errorCode = ERR_get_error()) != 0) {
681     if (!errors.empty()) {
682       errors += "; ";
683     }
684     const char* reason = ERR_reason_error_string(errorCode);
685     if (reason == nullptr) {
686       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
687       reason = message;
688     }
689     errors += reason;
690   }
691   if (errors.empty()) {
692     errors = "error code: " + folly::to<std::string>(errnoCopy);
693   }
694   return errors;
695 }
696
697 std::ostream&
698 operator<<(std::ostream& os, const PasswordCollector& collector) {
699   os << collector.describe();
700   return os;
701 }
702
703 } // namespace folly