Use folly::getCurrentThreadId() in SSLContext
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SpinLock.h>
23 #include <folly/ThreadId.h>
24
25 // ---------------------------------------------------------------------
26 // SSLContext implementation
27 // ---------------------------------------------------------------------
28
29 struct CRYPTO_dynlock_value {
30   std::mutex mutex;
31 };
32
33 namespace folly {
34 //
35 // For OpenSSL portability API
36 using namespace folly::ssl;
37
38 bool SSLContext::initialized_ = false;
39
40 namespace {
41
42 std::mutex& initMutex() {
43   static std::mutex m;
44   return m;
45 }
46
47 } // anonymous namespace
48
49 #ifdef OPENSSL_NPN_NEGOTIATED
50 int SSLContext::sNextProtocolsExDataIndex_ = -1;
51 #endif
52
53 // SSLContext implementation
54 SSLContext::SSLContext(SSLVersion version) {
55   {
56     std::lock_guard<std::mutex> g(initMutex());
57     initializeOpenSSLLocked();
58   }
59
60   ctx_ = SSL_CTX_new(SSLv23_method());
61   if (ctx_ == nullptr) {
62     throw std::runtime_error("SSL_CTX_new: " + getErrors());
63   }
64
65   int opt = 0;
66   switch (version) {
67     case TLSv1:
68       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
69       break;
70     case SSLv3:
71       opt = SSL_OP_NO_SSLv2;
72       break;
73     default:
74       // do nothing
75       break;
76   }
77   int newOpt = SSL_CTX_set_options(ctx_, opt);
78   DCHECK((newOpt & opt) == opt);
79
80   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
81
82   checkPeerName_ = false;
83
84   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
85
86 #if FOLLY_OPENSSL_HAS_SNI
87   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
88   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
89 #endif
90 }
91
92 SSLContext::~SSLContext() {
93   if (ctx_ != nullptr) {
94     SSL_CTX_free(ctx_);
95     ctx_ = nullptr;
96   }
97
98 #ifdef OPENSSL_NPN_NEGOTIATED
99   deleteNextProtocolsStrings();
100 #endif
101 }
102
103 void SSLContext::ciphers(const std::string& ciphers) {
104   providedCiphersString_ = ciphers;
105   setCiphersOrThrow(ciphers);
106 }
107
108 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
109   if (ciphers.size() == 0) {
110     return;
111   }
112   std::string opensslCipherList;
113   join(":", ciphers, opensslCipherList);
114   setCiphersOrThrow(opensslCipherList);
115 }
116
117 void SSLContext::setSignatureAlgorithms(
118     const std::vector<std::string>& sigalgs) {
119   if (sigalgs.size() == 0) {
120     return;
121   }
122 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
123   std::string opensslSigAlgsList;
124   join(":", sigalgs, opensslSigAlgsList);
125   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
126   if (rc == 0) {
127     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
128   }
129 #endif
130 }
131
132 void SSLContext::setClientECCurvesList(
133     const std::vector<std::string>& ecCurves) {
134   if (ecCurves.size() == 0) {
135     return;
136   }
137 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
138   std::string ecCurvesList;
139   join(":", ecCurves, ecCurvesList);
140   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
141   if (rc == 0) {
142     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
143   }
144 #endif
145 }
146
147 void SSLContext::setServerECCurve(const std::string& curveName) {
148 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
149   EC_KEY* ecdh = nullptr;
150   int nid;
151
152   /*
153    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
154    * from RFC 4492 section 5.1.1, or explicitly described curves over
155    * binary fields. OpenSSL only supports the "named curves", which provide
156    * maximum interoperability.
157    */
158
159   nid = OBJ_sn2nid(curveName.c_str());
160   if (nid == 0) {
161     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
162   }
163   ecdh = EC_KEY_new_by_curve_name(nid);
164   if (ecdh == nullptr) {
165     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
166   }
167
168   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
169   EC_KEY_free(ecdh);
170 #else
171   throw std::runtime_error("Elliptic curve encryption not allowed");
172 #endif
173 }
174
175 void SSLContext::setX509VerifyParam(
176     const ssl::X509VerifyParam& x509VerifyParam) {
177   if (!x509VerifyParam) {
178     return;
179   }
180   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
181     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
182   }
183 }
184
185 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
186   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
187   if (rc == 0) {
188     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
189   }
190 }
191
192 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
193     verifyPeer) {
194   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
195   verifyPeer_ = verifyPeer;
196 }
197
198 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
199     verifyPeer) {
200   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
201   int mode = SSL_VERIFY_NONE;
202   switch(verifyPeer) {
203     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
204     // break;
205
206     case SSLVerifyPeerEnum::VERIFY:
207       mode = SSL_VERIFY_PEER;
208       break;
209
210     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
211       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
212       break;
213
214     case SSLVerifyPeerEnum::NO_VERIFY:
215       mode = SSL_VERIFY_NONE;
216       break;
217
218     default:
219       break;
220   }
221   return mode;
222 }
223
224 int SSLContext::getVerificationMode() {
225   return getVerificationMode(verifyPeer_);
226 }
227
228 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
229                               const std::string& peerName) {
230   int mode;
231   if (checkPeerCert) {
232     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
233     checkPeerName_ = checkPeerName;
234     peerFixedName_ = peerName;
235   } else {
236     mode = SSL_VERIFY_NONE;
237     checkPeerName_ = false; // can't check name without cert!
238     peerFixedName_.clear();
239   }
240   SSL_CTX_set_verify(ctx_, mode, nullptr);
241 }
242
243 void SSLContext::loadCertificate(const char* path, const char* format) {
244   if (path == nullptr || format == nullptr) {
245     throw std::invalid_argument(
246          "loadCertificateChain: either <path> or <format> is nullptr");
247   }
248   if (strcmp(format, "PEM") == 0) {
249     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
250       int errnoCopy = errno;
251       std::string reason("SSL_CTX_use_certificate_chain_file: ");
252       reason.append(path);
253       reason.append(": ");
254       reason.append(getErrors(errnoCopy));
255       throw std::runtime_error(reason);
256     }
257   } else {
258     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
259   }
260 }
261
262 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
263   if (cert.data() == nullptr) {
264     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
265   }
266
267   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
268   if (bio == nullptr) {
269     throw std::runtime_error("BIO_new: " + getErrors());
270   }
271
272   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
273   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
274     throw std::runtime_error("BIO_write: " + getErrors());
275   }
276
277   ssl::X509UniquePtr x509(
278       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
279   if (x509 == nullptr) {
280     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
281   }
282
283   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
284     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
285   }
286 }
287
288 void SSLContext::loadPrivateKey(const char* path, const char* format) {
289   if (path == nullptr || format == nullptr) {
290     throw std::invalid_argument(
291         "loadPrivateKey: either <path> or <format> is nullptr");
292   }
293   if (strcmp(format, "PEM") == 0) {
294     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
295       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
296     }
297   } else {
298     throw std::runtime_error("Unsupported private key format: " + std::string(format));
299   }
300 }
301
302 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
303   if (pkey.data() == nullptr) {
304     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
305   }
306
307   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
308   if (bio == nullptr) {
309     throw std::runtime_error("BIO_new: " + getErrors());
310   }
311
312   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
313   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
314     throw std::runtime_error("BIO_write: " + getErrors());
315   }
316
317   ssl::EvpPkeyUniquePtr key(
318       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
319   if (key == nullptr) {
320     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
321   }
322
323   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
324     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
325   }
326 }
327
328 void SSLContext::loadTrustedCertificates(const char* path) {
329   if (path == nullptr) {
330     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
331   }
332   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
333     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
334   }
335   ERR_clear_error();
336 }
337
338 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
339   SSL_CTX_set_cert_store(ctx_, store);
340 }
341
342 void SSLContext::loadClientCAList(const char* path) {
343   auto clientCAs = SSL_load_client_CA_file(path);
344   if (clientCAs == nullptr) {
345     LOG(ERROR) << "Unable to load ca file: " << path;
346     return;
347   }
348   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
349 }
350
351 void SSLContext::randomize() {
352   RAND_poll();
353 }
354
355 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
356   if (collector == nullptr) {
357     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
358     return;
359   }
360   collector_ = collector;
361   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
362   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
363 }
364
365 #if FOLLY_OPENSSL_HAS_SNI
366
367 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
368   serverNameCb_ = cb;
369 }
370
371 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
372   clientHelloCbs_.push_back(cb);
373 }
374
375 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
376   SSLContext* context = (SSLContext*)data;
377
378   if (context == nullptr) {
379     return SSL_TLSEXT_ERR_NOACK;
380   }
381
382   for (auto& cb : context->clientHelloCbs_) {
383     // Generic callbacks to happen after we receive the Client Hello.
384     // For example, we use one to switch which cipher we use depending
385     // on the user's TLS version.  Because the primary purpose of
386     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
387     // are side-uses, we ignore any possible failures other than just logging
388     // them.
389     cb(ssl);
390   }
391
392   if (!context->serverNameCb_) {
393     return SSL_TLSEXT_ERR_NOACK;
394   }
395
396   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
397   switch (ret) {
398     case SERVER_NAME_FOUND:
399       return SSL_TLSEXT_ERR_OK;
400     case SERVER_NAME_NOT_FOUND:
401       return SSL_TLSEXT_ERR_NOACK;
402     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
403       *al = TLS1_AD_UNRECOGNIZED_NAME;
404       return SSL_TLSEXT_ERR_ALERT_FATAL;
405     default:
406       CHECK(false);
407   }
408
409   return SSL_TLSEXT_ERR_NOACK;
410 }
411
412 void SSLContext::switchCiphersIfTLS11(
413     SSL* ssl,
414     const std::string& tls11CipherString,
415     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
416   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
417       << "Shouldn't call if empty ciphers / alt ciphers";
418
419   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
420     // We only do this for TLS v 1.1 and later
421     return;
422   }
423
424   const std::string* ciphers = &tls11CipherString;
425   if (!tls11AltCipherlist.empty()) {
426     if (!cipherListPicker_) {
427       std::vector<int> weights;
428       std::for_each(
429           tls11AltCipherlist.begin(),
430           tls11AltCipherlist.end(),
431           [&](const std::pair<std::string, int>& e) {
432             weights.push_back(e.second);
433           });
434       cipherListPicker_.reset(
435           new std::discrete_distribution<int>(weights.begin(), weights.end()));
436     }
437     auto rng = ThreadLocalPRNG();
438     auto index = (*cipherListPicker_)(rng);
439     if ((size_t)index >= tls11AltCipherlist.size()) {
440       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
441                  << ", but tls11AltCipherlist is of length "
442                  << tls11AltCipherlist.size();
443     } else {
444       ciphers = &tls11AltCipherlist[size_t(index)].first;
445     }
446   }
447
448   // Prefer AES for TLS versions 1.1 and later since these are not
449   // vulnerable to BEAST attacks on AES.  Note that we're setting the
450   // cipher list on the SSL object, not the SSL_CTX object, so it will
451   // only last for this request.
452   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
453   if ((rc == 0) || ERR_peek_error() != 0) {
454     // This shouldn't happen since we checked for this when proxygen
455     // started up.
456     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
457     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
458   }
459 }
460 #endif // FOLLY_OPENSSL_HAS_SNI
461
462 #if FOLLY_OPENSSL_HAS_ALPN
463 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
464                                    const unsigned char** out,
465                                    unsigned char* outlen,
466                                    const unsigned char* in,
467                                    unsigned int inlen,
468                                    void* data) {
469   SSLContext* context = (SSLContext*)data;
470   CHECK(context);
471   if (context->advertisedNextProtocols_.empty()) {
472     *out = nullptr;
473     *outlen = 0;
474   } else {
475     auto i = context->pickNextProtocols();
476     const auto& item = context->advertisedNextProtocols_[i];
477     if (SSL_select_next_proto((unsigned char**)out,
478                               outlen,
479                               item.protocols,
480                               item.length,
481                               in,
482                               inlen) != OPENSSL_NPN_NEGOTIATED) {
483       return SSL_TLSEXT_ERR_NOACK;
484     }
485   }
486   return SSL_TLSEXT_ERR_OK;
487 }
488 #endif // FOLLY_OPENSSL_HAS_ALPN
489
490 #ifdef OPENSSL_NPN_NEGOTIATED
491
492 bool SSLContext::setAdvertisedNextProtocols(
493     const std::list<std::string>& protocols, NextProtocolType protocolType) {
494   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
495 }
496
497 bool SSLContext::setRandomizedAdvertisedNextProtocols(
498     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
499   unsetNextProtocols();
500   if (items.size() == 0) {
501     return false;
502   }
503   int total_weight = 0;
504   for (const auto &item : items) {
505     if (item.protocols.size() == 0) {
506       continue;
507     }
508     AdvertisedNextProtocolsItem advertised_item;
509     advertised_item.length = 0;
510     for (const auto& proto : item.protocols) {
511       ++advertised_item.length;
512       auto protoLength = proto.length();
513       if (protoLength >= 256) {
514         deleteNextProtocolsStrings();
515         return false;
516       }
517       advertised_item.length += unsigned(protoLength);
518     }
519     advertised_item.protocols = new unsigned char[advertised_item.length];
520     if (!advertised_item.protocols) {
521       throw std::runtime_error("alloc failure");
522     }
523     unsigned char* dst = advertised_item.protocols;
524     for (auto& proto : item.protocols) {
525       uint8_t protoLength = uint8_t(proto.length());
526       *dst++ = (unsigned char)protoLength;
527       memcpy(dst, proto.data(), protoLength);
528       dst += protoLength;
529     }
530     total_weight += item.weight;
531     advertisedNextProtocols_.push_back(advertised_item);
532     advertisedNextProtocolWeights_.push_back(item.weight);
533   }
534   if (total_weight == 0) {
535     deleteNextProtocolsStrings();
536     return false;
537   }
538   nextProtocolDistribution_ =
539       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
540                                    advertisedNextProtocolWeights_.end());
541   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
542     SSL_CTX_set_next_protos_advertised_cb(
543         ctx_, advertisedNextProtocolCallback, this);
544     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
545   }
546 #if FOLLY_OPENSSL_HAS_ALPN
547   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
548     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
549     // Client cannot really use randomized alpn
550     SSL_CTX_set_alpn_protos(ctx_,
551                             advertisedNextProtocols_[0].protocols,
552                             advertisedNextProtocols_[0].length);
553   }
554 #endif
555   return true;
556 }
557
558 void SSLContext::deleteNextProtocolsStrings() {
559   for (auto protocols : advertisedNextProtocols_) {
560     delete[] protocols.protocols;
561   }
562   advertisedNextProtocols_.clear();
563   advertisedNextProtocolWeights_.clear();
564 }
565
566 void SSLContext::unsetNextProtocols() {
567   deleteNextProtocolsStrings();
568   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
569   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
570 #if FOLLY_OPENSSL_HAS_ALPN
571   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
572   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
573 #endif
574 }
575
576 size_t SSLContext::pickNextProtocols() {
577   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
578   auto rng = ThreadLocalPRNG();
579   return size_t(nextProtocolDistribution_(rng));
580 }
581
582 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
583       const unsigned char** out, unsigned int* outlen, void* data) {
584   SSLContext* context = (SSLContext*)data;
585   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
586     *out = nullptr;
587     *outlen = 0;
588   } else if (context->advertisedNextProtocols_.size() == 1) {
589     *out = context->advertisedNextProtocols_[0].protocols;
590     *outlen = context->advertisedNextProtocols_[0].length;
591   } else {
592     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
593           sNextProtocolsExDataIndex_));
594     if (selected_index) {
595       --selected_index;
596       *out = context->advertisedNextProtocols_[selected_index].protocols;
597       *outlen = context->advertisedNextProtocols_[selected_index].length;
598     } else {
599       auto i = context->pickNextProtocols();
600       uintptr_t selected = i + 1;
601       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
602       *out = context->advertisedNextProtocols_[i].protocols;
603       *outlen = context->advertisedNextProtocols_[i].length;
604     }
605   }
606   return SSL_TLSEXT_ERR_OK;
607 }
608
609 int SSLContext::selectNextProtocolCallback(SSL* ssl,
610                                            unsigned char** out,
611                                            unsigned char* outlen,
612                                            const unsigned char* server,
613                                            unsigned int server_len,
614                                            void* data) {
615   (void)ssl; // Make -Wunused-parameters happy
616   SSLContext* ctx = (SSLContext*)data;
617   if (ctx->advertisedNextProtocols_.size() > 1) {
618     VLOG(3) << "SSLContext::selectNextProcolCallback() "
619             << "client should be deterministic in selecting protocols.";
620   }
621
622   unsigned char* client = nullptr;
623   unsigned int client_len = 0;
624   bool filtered = false;
625   auto cpf = ctx->getClientProtocolFilterCallback();
626   if (cpf) {
627     filtered = (*cpf)(&client, &client_len, server, server_len);
628   }
629
630   if (!filtered) {
631     if (ctx->advertisedNextProtocols_.empty()) {
632       client = (unsigned char *) "";
633       client_len = 0;
634     } else {
635       client = ctx->advertisedNextProtocols_[0].protocols;
636       client_len = ctx->advertisedNextProtocols_[0].length;
637     }
638   }
639
640   int retval = SSL_select_next_proto(out, outlen, server, server_len,
641                                      client, client_len);
642   if (retval != OPENSSL_NPN_NEGOTIATED) {
643     VLOG(3) << "SSLContext::selectNextProcolCallback() "
644             << "unable to pick a next protocol.";
645   }
646   return SSL_TLSEXT_ERR_OK;
647 }
648 #endif // OPENSSL_NPN_NEGOTIATED
649
650 SSL* SSLContext::createSSL() const {
651   SSL* ssl = SSL_new(ctx_);
652   if (ssl == nullptr) {
653     throw std::runtime_error("SSL_new: " + getErrors());
654   }
655   return ssl;
656 }
657
658 void SSLContext::setSessionCacheContext(const std::string& context) {
659   SSL_CTX_set_session_id_context(
660       ctx_,
661       reinterpret_cast<const unsigned char*>(context.data()),
662       std::min<unsigned int>(
663           static_cast<unsigned int>(context.length()),
664           SSL_MAX_SSL_SESSION_ID_LENGTH));
665 }
666
667 /**
668  * Match a name with a pattern. The pattern may include wildcard. A single
669  * wildcard "*" can match up to one component in the domain name.
670  *
671  * @param  host    Host name, typically the name of the remote host
672  * @param  pattern Name retrieved from certificate
673  * @param  size    Size of "pattern"
674  * @return True, if "host" matches "pattern". False otherwise.
675  */
676 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
677   bool match = false;
678   int i = 0, j = 0;
679   while (i < size && host[j] != '\0') {
680     if (toupper(pattern[i]) == toupper(host[j])) {
681       i++;
682       j++;
683       continue;
684     }
685     if (pattern[i] == '*') {
686       while (host[j] != '.' && host[j] != '\0') {
687         j++;
688       }
689       i++;
690       continue;
691     }
692     break;
693   }
694   if (i == size && host[j] == '\0') {
695     match = true;
696   }
697   return match;
698 }
699
700 int SSLContext::passwordCallback(char* password,
701                                  int size,
702                                  int,
703                                  void* data) {
704   SSLContext* context = (SSLContext*)data;
705   if (context == nullptr || context->passwordCollector() == nullptr) {
706     return 0;
707   }
708   std::string userPassword;
709   // call user defined password collector to get password
710   context->passwordCollector()->getPassword(userPassword, size);
711   auto length = int(userPassword.size());
712   if (length > size) {
713     length = size;
714   }
715   strncpy(password, userPassword.c_str(), size_t(length));
716   return length;
717 }
718
719 struct SSLLock {
720   explicit SSLLock(
721     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
722       lockType(inLockType) {
723   }
724
725   void lock() {
726     if (lockType == SSLContext::LOCK_MUTEX) {
727       mutex.lock();
728     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
729       spinLock.lock();
730     }
731     // lockType == LOCK_NONE, no-op
732   }
733
734   void unlock() {
735     if (lockType == SSLContext::LOCK_MUTEX) {
736       mutex.unlock();
737     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
738       spinLock.unlock();
739     }
740     // lockType == LOCK_NONE, no-op
741   }
742
743   SSLContext::SSLLockType lockType;
744   folly::SpinLock spinLock{};
745   std::mutex mutex;
746 };
747
748 // Statics are unsafe in environments that call exit().
749 // If one thread calls exit() while another thread is
750 // references a member of SSLContext, bad things can happen.
751 // SSLContext runs in such environments.
752 // Instead of declaring a static member we "new" the static
753 // member so that it won't be destructed on exit().
754 static std::unique_ptr<SSLLock[]>& locks() {
755   static auto locksInst = new std::unique_ptr<SSLLock[]>();
756   return *locksInst;
757 }
758
759 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
760   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
761   return *lockTypesInst;
762 }
763
764 static void callbackLocking(int mode, int n, const char*, int) {
765   if (mode & CRYPTO_LOCK) {
766     locks()[size_t(n)].lock();
767   } else {
768     locks()[size_t(n)].unlock();
769   }
770 }
771
772 static unsigned long callbackThreadID() {
773   return static_cast<unsigned long>(folly::getCurrentThreadID());
774 }
775
776 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
777   return new CRYPTO_dynlock_value;
778 }
779
780 static void dyn_lock(int mode,
781                      struct CRYPTO_dynlock_value* lock,
782                      const char*, int) {
783   if (lock != nullptr) {
784     if (mode & CRYPTO_LOCK) {
785       lock->mutex.lock();
786     } else {
787       lock->mutex.unlock();
788     }
789   }
790 }
791
792 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
793   delete lock;
794 }
795
796 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
797   lockTypes() = inLockTypes;
798 }
799
800 bool SSLContext::isSSLLockDisabled(int lockId) {
801   const auto& sslLocks = lockTypes();
802   const auto it = sslLocks.find(lockId);
803   return it != sslLocks.end() &&
804       it->second == SSLContext::SSLLockType::LOCK_NONE;
805 }
806
807 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
808 void SSLContext::enableFalseStart() {
809   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
810 }
811 #endif
812
813 void SSLContext::markInitialized() {
814   std::lock_guard<std::mutex> g(initMutex());
815   initialized_ = true;
816 }
817
818 void SSLContext::initializeOpenSSL() {
819   std::lock_guard<std::mutex> g(initMutex());
820   initializeOpenSSLLocked();
821 }
822
823 void SSLContext::initializeOpenSSLLocked() {
824   if (initialized_) {
825     return;
826   }
827   SSL_library_init();
828   SSL_load_error_strings();
829   ERR_load_crypto_strings();
830   // static locking
831   locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
832   for (auto it: lockTypes()) {
833     locks()[size_t(it.first)].lockType = it.second;
834   }
835   CRYPTO_set_id_callback(callbackThreadID);
836   CRYPTO_set_locking_callback(callbackLocking);
837   // dynamic locking
838   CRYPTO_set_dynlock_create_callback(dyn_create);
839   CRYPTO_set_dynlock_lock_callback(dyn_lock);
840   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
841   randomize();
842 #ifdef OPENSSL_NPN_NEGOTIATED
843   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
844       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
845 #endif
846   initialized_ = true;
847 }
848
849 void SSLContext::cleanupOpenSSL() {
850   std::lock_guard<std::mutex> g(initMutex());
851   cleanupOpenSSLLocked();
852 }
853
854 void SSLContext::cleanupOpenSSLLocked() {
855   if (!initialized_) {
856     return;
857   }
858
859   CRYPTO_set_id_callback(nullptr);
860   CRYPTO_set_locking_callback(nullptr);
861   CRYPTO_set_dynlock_create_callback(nullptr);
862   CRYPTO_set_dynlock_lock_callback(nullptr);
863   CRYPTO_set_dynlock_destroy_callback(nullptr);
864   CRYPTO_cleanup_all_ex_data();
865   ERR_free_strings();
866   EVP_cleanup();
867   ERR_clear_error();
868   locks().reset();
869   initialized_ = false;
870 }
871
872 void SSLContext::setOptions(long options) {
873   long newOpt = SSL_CTX_set_options(ctx_, options);
874   if ((newOpt & options) != options) {
875     throw std::runtime_error("SSL_CTX_set_options failed");
876   }
877 }
878
879 std::string SSLContext::getErrors(int errnoCopy) {
880   std::string errors;
881   unsigned long  errorCode;
882   char   message[256];
883
884   errors.reserve(512);
885   while ((errorCode = ERR_get_error()) != 0) {
886     if (!errors.empty()) {
887       errors += "; ";
888     }
889     const char* reason = ERR_reason_error_string(errorCode);
890     if (reason == nullptr) {
891       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
892       reason = message;
893     }
894     errors += reason;
895   }
896   if (errors.empty()) {
897     errors = "error code: " + folly::to<std::string>(errnoCopy);
898   }
899   return errors;
900 }
901
902 std::ostream&
903 operator<<(std::ostream& os, const PasswordCollector& collector) {
904   os << collector.describe();
905   return os;
906 }
907
908 } // folly