Update providedCiphersStr_ in one place.
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ThreadId.h>
25
26 // ---------------------------------------------------------------------
27 // SSLContext implementation
28 // ---------------------------------------------------------------------
29
30 struct CRYPTO_dynlock_value {
31   std::mutex mutex;
32 };
33
34 namespace folly {
35 //
36 // For OpenSSL portability API
37 using namespace folly::ssl;
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 } // anonymous namespace
49
50 #ifdef OPENSSL_NPN_NEGOTIATED
51 int SSLContext::sNextProtocolsExDataIndex_ = -1;
52 #endif
53
54 // SSLContext implementation
55 SSLContext::SSLContext(SSLVersion version) {
56   {
57     std::lock_guard<std::mutex> g(initMutex());
58     initializeOpenSSLLocked();
59   }
60
61   ctx_ = SSL_CTX_new(SSLv23_method());
62   if (ctx_ == nullptr) {
63     throw std::runtime_error("SSL_CTX_new: " + getErrors());
64   }
65
66   int opt = 0;
67   switch (version) {
68     case TLSv1:
69       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
70       break;
71     case SSLv3:
72       opt = SSL_OP_NO_SSLv2;
73       break;
74     default:
75       // do nothing
76       break;
77   }
78   int newOpt = SSL_CTX_set_options(ctx_, opt);
79   DCHECK((newOpt & opt) == opt);
80
81   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82
83   checkPeerName_ = false;
84
85   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
86
87 #if FOLLY_OPENSSL_HAS_SNI
88   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
89   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
90 #endif
91 }
92
93 SSLContext::~SSLContext() {
94   if (ctx_ != nullptr) {
95     SSL_CTX_free(ctx_);
96     ctx_ = nullptr;
97   }
98
99 #ifdef OPENSSL_NPN_NEGOTIATED
100   deleteNextProtocolsStrings();
101 #endif
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   setCiphersOrThrow(ciphers);
106 }
107
108 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
109   if (ciphers.size() == 0) {
110     return;
111   }
112   std::string opensslCipherList;
113   join(":", ciphers, opensslCipherList);
114   setCiphersOrThrow(opensslCipherList);
115 }
116
117 void SSLContext::setSignatureAlgorithms(
118     const std::vector<std::string>& sigalgs) {
119   if (sigalgs.size() == 0) {
120     return;
121   }
122 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
123   std::string opensslSigAlgsList;
124   join(":", sigalgs, opensslSigAlgsList);
125   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
126   if (rc == 0) {
127     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
128   }
129 #endif
130 }
131
132 void SSLContext::setClientECCurvesList(
133     const std::vector<std::string>& ecCurves) {
134   if (ecCurves.size() == 0) {
135     return;
136   }
137 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
138   std::string ecCurvesList;
139   join(":", ecCurves, ecCurvesList);
140   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
141   if (rc == 0) {
142     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
143   }
144 #endif
145 }
146
147 void SSLContext::setServerECCurve(const std::string& curveName) {
148 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
149   EC_KEY* ecdh = nullptr;
150   int nid;
151
152   /*
153    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
154    * from RFC 4492 section 5.1.1, or explicitly described curves over
155    * binary fields. OpenSSL only supports the "named curves", which provide
156    * maximum interoperability.
157    */
158
159   nid = OBJ_sn2nid(curveName.c_str());
160   if (nid == 0) {
161     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
162   }
163   ecdh = EC_KEY_new_by_curve_name(nid);
164   if (ecdh == nullptr) {
165     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
166   }
167
168   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
169   EC_KEY_free(ecdh);
170 #else
171   throw std::runtime_error("Elliptic curve encryption not allowed");
172 #endif
173 }
174
175 void SSLContext::setX509VerifyParam(
176     const ssl::X509VerifyParam& x509VerifyParam) {
177   if (!x509VerifyParam) {
178     return;
179   }
180   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
181     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
182   }
183 }
184
185 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
186   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
187   if (rc == 0) {
188     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
189   }
190   providedCiphersString_ = ciphers;
191 }
192
193 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
194     verifyPeer) {
195   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
196   verifyPeer_ = verifyPeer;
197 }
198
199 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
200     verifyPeer) {
201   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
202   int mode = SSL_VERIFY_NONE;
203   switch(verifyPeer) {
204     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
205     // break;
206
207     case SSLVerifyPeerEnum::VERIFY:
208       mode = SSL_VERIFY_PEER;
209       break;
210
211     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
212       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
213       break;
214
215     case SSLVerifyPeerEnum::NO_VERIFY:
216       mode = SSL_VERIFY_NONE;
217       break;
218
219     default:
220       break;
221   }
222   return mode;
223 }
224
225 int SSLContext::getVerificationMode() {
226   return getVerificationMode(verifyPeer_);
227 }
228
229 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
230                               const std::string& peerName) {
231   int mode;
232   if (checkPeerCert) {
233     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
234     checkPeerName_ = checkPeerName;
235     peerFixedName_ = peerName;
236   } else {
237     mode = SSL_VERIFY_NONE;
238     checkPeerName_ = false; // can't check name without cert!
239     peerFixedName_.clear();
240   }
241   SSL_CTX_set_verify(ctx_, mode, nullptr);
242 }
243
244 void SSLContext::loadCertificate(const char* path, const char* format) {
245   if (path == nullptr || format == nullptr) {
246     throw std::invalid_argument(
247          "loadCertificateChain: either <path> or <format> is nullptr");
248   }
249   if (strcmp(format, "PEM") == 0) {
250     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
251       int errnoCopy = errno;
252       std::string reason("SSL_CTX_use_certificate_chain_file: ");
253       reason.append(path);
254       reason.append(": ");
255       reason.append(getErrors(errnoCopy));
256       throw std::runtime_error(reason);
257     }
258   } else {
259     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
260   }
261 }
262
263 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
264   if (cert.data() == nullptr) {
265     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
266   }
267
268   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
269   if (bio == nullptr) {
270     throw std::runtime_error("BIO_new: " + getErrors());
271   }
272
273   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
274   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
275     throw std::runtime_error("BIO_write: " + getErrors());
276   }
277
278   ssl::X509UniquePtr x509(
279       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
280   if (x509 == nullptr) {
281     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
282   }
283
284   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
285     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
286   }
287 }
288
289 void SSLContext::loadPrivateKey(const char* path, const char* format) {
290   if (path == nullptr || format == nullptr) {
291     throw std::invalid_argument(
292         "loadPrivateKey: either <path> or <format> is nullptr");
293   }
294   if (strcmp(format, "PEM") == 0) {
295     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
296       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
297     }
298   } else {
299     throw std::runtime_error("Unsupported private key format: " + std::string(format));
300   }
301 }
302
303 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
304   if (pkey.data() == nullptr) {
305     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
306   }
307
308   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
309   if (bio == nullptr) {
310     throw std::runtime_error("BIO_new: " + getErrors());
311   }
312
313   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
314   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
315     throw std::runtime_error("BIO_write: " + getErrors());
316   }
317
318   ssl::EvpPkeyUniquePtr key(
319       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
320   if (key == nullptr) {
321     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
322   }
323
324   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
325     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
326   }
327 }
328
329 void SSLContext::loadTrustedCertificates(const char* path) {
330   if (path == nullptr) {
331     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
332   }
333   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
334     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
335   }
336   ERR_clear_error();
337 }
338
339 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
340   SSL_CTX_set_cert_store(ctx_, store);
341 }
342
343 void SSLContext::loadClientCAList(const char* path) {
344   auto clientCAs = SSL_load_client_CA_file(path);
345   if (clientCAs == nullptr) {
346     LOG(ERROR) << "Unable to load ca file: " << path;
347     return;
348   }
349   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
350 }
351
352 void SSLContext::randomize() {
353   RAND_poll();
354 }
355
356 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
357   if (collector == nullptr) {
358     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
359     return;
360   }
361   collector_ = collector;
362   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
363   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
364 }
365
366 #if FOLLY_OPENSSL_HAS_SNI
367
368 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
369   serverNameCb_ = cb;
370 }
371
372 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
373   clientHelloCbs_.push_back(cb);
374 }
375
376 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
377   SSLContext* context = (SSLContext*)data;
378
379   if (context == nullptr) {
380     return SSL_TLSEXT_ERR_NOACK;
381   }
382
383   for (auto& cb : context->clientHelloCbs_) {
384     // Generic callbacks to happen after we receive the Client Hello.
385     // For example, we use one to switch which cipher we use depending
386     // on the user's TLS version.  Because the primary purpose of
387     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
388     // are side-uses, we ignore any possible failures other than just logging
389     // them.
390     cb(ssl);
391   }
392
393   if (!context->serverNameCb_) {
394     return SSL_TLSEXT_ERR_NOACK;
395   }
396
397   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
398   switch (ret) {
399     case SERVER_NAME_FOUND:
400       return SSL_TLSEXT_ERR_OK;
401     case SERVER_NAME_NOT_FOUND:
402       return SSL_TLSEXT_ERR_NOACK;
403     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
404       *al = TLS1_AD_UNRECOGNIZED_NAME;
405       return SSL_TLSEXT_ERR_ALERT_FATAL;
406     default:
407       CHECK(false);
408   }
409
410   return SSL_TLSEXT_ERR_NOACK;
411 }
412
413 void SSLContext::switchCiphersIfTLS11(
414     SSL* ssl,
415     const std::string& tls11CipherString,
416     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
417   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
418       << "Shouldn't call if empty ciphers / alt ciphers";
419
420   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
421     // We only do this for TLS v 1.1 and later
422     return;
423   }
424
425   const std::string* ciphers = &tls11CipherString;
426   if (!tls11AltCipherlist.empty()) {
427     if (!cipherListPicker_) {
428       std::vector<int> weights;
429       std::for_each(
430           tls11AltCipherlist.begin(),
431           tls11AltCipherlist.end(),
432           [&](const std::pair<std::string, int>& e) {
433             weights.push_back(e.second);
434           });
435       cipherListPicker_.reset(
436           new std::discrete_distribution<int>(weights.begin(), weights.end()));
437     }
438     auto rng = ThreadLocalPRNG();
439     auto index = (*cipherListPicker_)(rng);
440     if ((size_t)index >= tls11AltCipherlist.size()) {
441       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
442                  << ", but tls11AltCipherlist is of length "
443                  << tls11AltCipherlist.size();
444     } else {
445       ciphers = &tls11AltCipherlist[size_t(index)].first;
446     }
447   }
448
449   // Prefer AES for TLS versions 1.1 and later since these are not
450   // vulnerable to BEAST attacks on AES.  Note that we're setting the
451   // cipher list on the SSL object, not the SSL_CTX object, so it will
452   // only last for this request.
453   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
454   if ((rc == 0) || ERR_peek_error() != 0) {
455     // This shouldn't happen since we checked for this when proxygen
456     // started up.
457     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
458     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
459   }
460 }
461 #endif // FOLLY_OPENSSL_HAS_SNI
462
463 #if FOLLY_OPENSSL_HAS_ALPN
464 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
465                                    const unsigned char** out,
466                                    unsigned char* outlen,
467                                    const unsigned char* in,
468                                    unsigned int inlen,
469                                    void* data) {
470   SSLContext* context = (SSLContext*)data;
471   CHECK(context);
472   if (context->advertisedNextProtocols_.empty()) {
473     *out = nullptr;
474     *outlen = 0;
475   } else {
476     auto i = context->pickNextProtocols();
477     const auto& item = context->advertisedNextProtocols_[i];
478     if (SSL_select_next_proto((unsigned char**)out,
479                               outlen,
480                               item.protocols,
481                               item.length,
482                               in,
483                               inlen) != OPENSSL_NPN_NEGOTIATED) {
484       return SSL_TLSEXT_ERR_NOACK;
485     }
486   }
487   return SSL_TLSEXT_ERR_OK;
488 }
489 #endif // FOLLY_OPENSSL_HAS_ALPN
490
491 #ifdef OPENSSL_NPN_NEGOTIATED
492
493 bool SSLContext::setAdvertisedNextProtocols(
494     const std::list<std::string>& protocols, NextProtocolType protocolType) {
495   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
496 }
497
498 bool SSLContext::setRandomizedAdvertisedNextProtocols(
499     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
500   unsetNextProtocols();
501   if (items.size() == 0) {
502     return false;
503   }
504   int total_weight = 0;
505   for (const auto &item : items) {
506     if (item.protocols.size() == 0) {
507       continue;
508     }
509     AdvertisedNextProtocolsItem advertised_item;
510     advertised_item.length = 0;
511     for (const auto& proto : item.protocols) {
512       ++advertised_item.length;
513       auto protoLength = proto.length();
514       if (protoLength >= 256) {
515         deleteNextProtocolsStrings();
516         return false;
517       }
518       advertised_item.length += unsigned(protoLength);
519     }
520     advertised_item.protocols = new unsigned char[advertised_item.length];
521     if (!advertised_item.protocols) {
522       throw std::runtime_error("alloc failure");
523     }
524     unsigned char* dst = advertised_item.protocols;
525     for (auto& proto : item.protocols) {
526       uint8_t protoLength = uint8_t(proto.length());
527       *dst++ = (unsigned char)protoLength;
528       memcpy(dst, proto.data(), protoLength);
529       dst += protoLength;
530     }
531     total_weight += item.weight;
532     advertisedNextProtocols_.push_back(advertised_item);
533     advertisedNextProtocolWeights_.push_back(item.weight);
534   }
535   if (total_weight == 0) {
536     deleteNextProtocolsStrings();
537     return false;
538   }
539   nextProtocolDistribution_ =
540       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
541                                    advertisedNextProtocolWeights_.end());
542   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
543     SSL_CTX_set_next_protos_advertised_cb(
544         ctx_, advertisedNextProtocolCallback, this);
545     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
546   }
547 #if FOLLY_OPENSSL_HAS_ALPN
548   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
549     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
550     // Client cannot really use randomized alpn
551     SSL_CTX_set_alpn_protos(ctx_,
552                             advertisedNextProtocols_[0].protocols,
553                             advertisedNextProtocols_[0].length);
554   }
555 #endif
556   return true;
557 }
558
559 void SSLContext::deleteNextProtocolsStrings() {
560   for (auto protocols : advertisedNextProtocols_) {
561     delete[] protocols.protocols;
562   }
563   advertisedNextProtocols_.clear();
564   advertisedNextProtocolWeights_.clear();
565 }
566
567 void SSLContext::unsetNextProtocols() {
568   deleteNextProtocolsStrings();
569   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
570   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
571 #if FOLLY_OPENSSL_HAS_ALPN
572   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
573   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
574 #endif
575 }
576
577 size_t SSLContext::pickNextProtocols() {
578   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
579   auto rng = ThreadLocalPRNG();
580   return size_t(nextProtocolDistribution_(rng));
581 }
582
583 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
584       const unsigned char** out, unsigned int* outlen, void* data) {
585   SSLContext* context = (SSLContext*)data;
586   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
587     *out = nullptr;
588     *outlen = 0;
589   } else if (context->advertisedNextProtocols_.size() == 1) {
590     *out = context->advertisedNextProtocols_[0].protocols;
591     *outlen = context->advertisedNextProtocols_[0].length;
592   } else {
593     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
594           sNextProtocolsExDataIndex_));
595     if (selected_index) {
596       --selected_index;
597       *out = context->advertisedNextProtocols_[selected_index].protocols;
598       *outlen = context->advertisedNextProtocols_[selected_index].length;
599     } else {
600       auto i = context->pickNextProtocols();
601       uintptr_t selected = i + 1;
602       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
603       *out = context->advertisedNextProtocols_[i].protocols;
604       *outlen = context->advertisedNextProtocols_[i].length;
605     }
606   }
607   return SSL_TLSEXT_ERR_OK;
608 }
609
610 int SSLContext::selectNextProtocolCallback(SSL* ssl,
611                                            unsigned char** out,
612                                            unsigned char* outlen,
613                                            const unsigned char* server,
614                                            unsigned int server_len,
615                                            void* data) {
616   (void)ssl; // Make -Wunused-parameters happy
617   SSLContext* ctx = (SSLContext*)data;
618   if (ctx->advertisedNextProtocols_.size() > 1) {
619     VLOG(3) << "SSLContext::selectNextProcolCallback() "
620             << "client should be deterministic in selecting protocols.";
621   }
622
623   unsigned char* client = nullptr;
624   unsigned int client_len = 0;
625   bool filtered = false;
626   auto cpf = ctx->getClientProtocolFilterCallback();
627   if (cpf) {
628     filtered = (*cpf)(&client, &client_len, server, server_len);
629   }
630
631   if (!filtered) {
632     if (ctx->advertisedNextProtocols_.empty()) {
633       client = (unsigned char *) "";
634       client_len = 0;
635     } else {
636       client = ctx->advertisedNextProtocols_[0].protocols;
637       client_len = ctx->advertisedNextProtocols_[0].length;
638     }
639   }
640
641   int retval = SSL_select_next_proto(out, outlen, server, server_len,
642                                      client, client_len);
643   if (retval != OPENSSL_NPN_NEGOTIATED) {
644     VLOG(3) << "SSLContext::selectNextProcolCallback() "
645             << "unable to pick a next protocol.";
646   }
647   return SSL_TLSEXT_ERR_OK;
648 }
649 #endif // OPENSSL_NPN_NEGOTIATED
650
651 SSL* SSLContext::createSSL() const {
652   SSL* ssl = SSL_new(ctx_);
653   if (ssl == nullptr) {
654     throw std::runtime_error("SSL_new: " + getErrors());
655   }
656   return ssl;
657 }
658
659 void SSLContext::setSessionCacheContext(const std::string& context) {
660   SSL_CTX_set_session_id_context(
661       ctx_,
662       reinterpret_cast<const unsigned char*>(context.data()),
663       std::min<unsigned int>(
664           static_cast<unsigned int>(context.length()),
665           SSL_MAX_SSL_SESSION_ID_LENGTH));
666 }
667
668 /**
669  * Match a name with a pattern. The pattern may include wildcard. A single
670  * wildcard "*" can match up to one component in the domain name.
671  *
672  * @param  host    Host name, typically the name of the remote host
673  * @param  pattern Name retrieved from certificate
674  * @param  size    Size of "pattern"
675  * @return True, if "host" matches "pattern". False otherwise.
676  */
677 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
678   bool match = false;
679   int i = 0, j = 0;
680   while (i < size && host[j] != '\0') {
681     if (toupper(pattern[i]) == toupper(host[j])) {
682       i++;
683       j++;
684       continue;
685     }
686     if (pattern[i] == '*') {
687       while (host[j] != '.' && host[j] != '\0') {
688         j++;
689       }
690       i++;
691       continue;
692     }
693     break;
694   }
695   if (i == size && host[j] == '\0') {
696     match = true;
697   }
698   return match;
699 }
700
701 int SSLContext::passwordCallback(char* password,
702                                  int size,
703                                  int,
704                                  void* data) {
705   SSLContext* context = (SSLContext*)data;
706   if (context == nullptr || context->passwordCollector() == nullptr) {
707     return 0;
708   }
709   std::string userPassword;
710   // call user defined password collector to get password
711   context->passwordCollector()->getPassword(userPassword, size);
712   auto length = int(userPassword.size());
713   if (length > size) {
714     length = size;
715   }
716   strncpy(password, userPassword.c_str(), size_t(length));
717   return length;
718 }
719
720 struct SSLLock {
721   explicit SSLLock(
722     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
723       lockType(inLockType) {
724   }
725
726   void lock(bool read) {
727     if (lockType == SSLContext::LOCK_MUTEX) {
728       mutex.lock();
729     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
730       spinLock.lock();
731     } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
732       if (read) {
733         sharedMutex.lock_shared();
734       } else {
735         sharedMutex.lock();
736       }
737     }
738     // lockType == LOCK_NONE, no-op
739   }
740
741   void unlock(bool read) {
742     if (lockType == SSLContext::LOCK_MUTEX) {
743       mutex.unlock();
744     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
745       spinLock.unlock();
746     } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
747       if (read) {
748         sharedMutex.unlock_shared();
749       } else {
750         sharedMutex.unlock();
751       }
752     }
753     // lockType == LOCK_NONE, no-op
754   }
755
756   SSLContext::SSLLockType lockType;
757   folly::SpinLock spinLock{};
758   std::mutex mutex;
759   SharedMutex sharedMutex;
760 };
761
762 // Statics are unsafe in environments that call exit().
763 // If one thread calls exit() while another thread is
764 // references a member of SSLContext, bad things can happen.
765 // SSLContext runs in such environments.
766 // Instead of declaring a static member we "new" the static
767 // member so that it won't be destructed on exit().
768 static std::unique_ptr<SSLLock[]>& locks() {
769   static auto locksInst = new std::unique_ptr<SSLLock[]>();
770   return *locksInst;
771 }
772
773 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
774   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
775   return *lockTypesInst;
776 }
777
778 static void callbackLocking(int mode, int n, const char*, int) {
779   if (mode & CRYPTO_LOCK) {
780     locks()[size_t(n)].lock(mode & CRYPTO_READ);
781   } else {
782     locks()[size_t(n)].unlock(mode & CRYPTO_READ);
783   }
784 }
785
786 static unsigned long callbackThreadID() {
787   return static_cast<unsigned long>(folly::getCurrentThreadID());
788 }
789
790 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
791   return new CRYPTO_dynlock_value;
792 }
793
794 static void dyn_lock(int mode,
795                      struct CRYPTO_dynlock_value* lock,
796                      const char*, int) {
797   if (lock != nullptr) {
798     if (mode & CRYPTO_LOCK) {
799       lock->mutex.lock();
800     } else {
801       lock->mutex.unlock();
802     }
803   }
804 }
805
806 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
807   delete lock;
808 }
809
810 void SSLContext::setSSLLockTypesLocked(std::map<int, SSLLockType> inLockTypes) {
811   lockTypes() = inLockTypes;
812 }
813
814 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
815   std::lock_guard<std::mutex> g(initMutex());
816   if (initialized_) {
817     // We set the locks on initialization, so if we are already initialized
818     // this would have no affect.
819     LOG(INFO) << "Ignoring setSSLLockTypes after initialization";
820     return;
821   }
822   setSSLLockTypesLocked(std::move(inLockTypes));
823 }
824
825 void SSLContext::setSSLLockTypesAndInitOpenSSL(
826     std::map<int, SSLLockType> inLockTypes) {
827   std::lock_guard<std::mutex> g(initMutex());
828   CHECK(!initialized_) << "OpenSSL is already initialized";
829   setSSLLockTypesLocked(std::move(inLockTypes));
830   initializeOpenSSLLocked();
831 }
832
833 bool SSLContext::isSSLLockDisabled(int lockId) {
834   std::lock_guard<std::mutex> g(initMutex());
835   CHECK(initialized_) << "OpenSSL is not initialized yet";
836   const auto& sslLocks = lockTypes();
837   const auto it = sslLocks.find(lockId);
838   return it != sslLocks.end() &&
839       it->second == SSLContext::SSLLockType::LOCK_NONE;
840 }
841
842 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
843 void SSLContext::enableFalseStart() {
844   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
845 }
846 #endif
847
848 void SSLContext::markInitialized() {
849   std::lock_guard<std::mutex> g(initMutex());
850   initialized_ = true;
851 }
852
853 void SSLContext::initializeOpenSSL() {
854   std::lock_guard<std::mutex> g(initMutex());
855   initializeOpenSSLLocked();
856 }
857
858 void SSLContext::initializeOpenSSLLocked() {
859   if (initialized_) {
860     return;
861   }
862   SSL_library_init();
863   SSL_load_error_strings();
864   ERR_load_crypto_strings();
865   // static locking
866   locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
867   for (auto it: lockTypes()) {
868     locks()[size_t(it.first)].lockType = it.second;
869   }
870   CRYPTO_set_id_callback(callbackThreadID);
871   CRYPTO_set_locking_callback(callbackLocking);
872   // dynamic locking
873   CRYPTO_set_dynlock_create_callback(dyn_create);
874   CRYPTO_set_dynlock_lock_callback(dyn_lock);
875   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
876   randomize();
877 #ifdef OPENSSL_NPN_NEGOTIATED
878   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
879       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
880 #endif
881   initialized_ = true;
882 }
883
884 void SSLContext::cleanupOpenSSL() {
885   std::lock_guard<std::mutex> g(initMutex());
886   cleanupOpenSSLLocked();
887 }
888
889 void SSLContext::cleanupOpenSSLLocked() {
890   if (!initialized_) {
891     return;
892   }
893
894   CRYPTO_set_id_callback(nullptr);
895   CRYPTO_set_locking_callback(nullptr);
896   CRYPTO_set_dynlock_create_callback(nullptr);
897   CRYPTO_set_dynlock_lock_callback(nullptr);
898   CRYPTO_set_dynlock_destroy_callback(nullptr);
899   CRYPTO_cleanup_all_ex_data();
900   ERR_free_strings();
901   EVP_cleanup();
902   ERR_clear_error();
903   locks().reset();
904   initialized_ = false;
905 }
906
907 void SSLContext::setOptions(long options) {
908   long newOpt = SSL_CTX_set_options(ctx_, options);
909   if ((newOpt & options) != options) {
910     throw std::runtime_error("SSL_CTX_set_options failed");
911   }
912 }
913
914 std::string SSLContext::getErrors(int errnoCopy) {
915   std::string errors;
916   unsigned long  errorCode;
917   char   message[256];
918
919   errors.reserve(512);
920   while ((errorCode = ERR_get_error()) != 0) {
921     if (!errors.empty()) {
922       errors += "; ";
923     }
924     const char* reason = ERR_reason_error_string(errorCode);
925     if (reason == nullptr) {
926       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
927       reason = message;
928     }
929     errors += reason;
930   }
931   if (errors.empty()) {
932     errors = "error code: " + folly::to<std::string>(errnoCopy);
933   }
934   return errors;
935 }
936
937 std::ostream&
938 operator<<(std::ostream& os, const PasswordCollector& collector) {
939   os << collector.describe();
940   return os;
941 }
942
943 } // folly