Treat OpenSSL as a non-portable include
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SpinLock.h>
23
24 // ---------------------------------------------------------------------
25 // SSLContext implementation
26 // ---------------------------------------------------------------------
27
28 struct CRYPTO_dynlock_value {
29   std::mutex mutex;
30 };
31
32 namespace folly {
33 //
34 // For OpenSSL portability API
35 using namespace folly::ssl;
36
37 bool SSLContext::initialized_ = false;
38
39 namespace {
40
41 std::mutex& initMutex() {
42   static std::mutex m;
43   return m;
44 }
45
46 } // anonymous namespace
47
48 #ifdef OPENSSL_NPN_NEGOTIATED
49 int SSLContext::sNextProtocolsExDataIndex_ = -1;
50 #endif
51
52 // SSLContext implementation
53 SSLContext::SSLContext(SSLVersion version) {
54   {
55     std::lock_guard<std::mutex> g(initMutex());
56     initializeOpenSSLLocked();
57   }
58
59   ctx_ = SSL_CTX_new(SSLv23_method());
60   if (ctx_ == nullptr) {
61     throw std::runtime_error("SSL_CTX_new: " + getErrors());
62   }
63
64   int opt = 0;
65   switch (version) {
66     case TLSv1:
67       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
68       break;
69     case SSLv3:
70       opt = SSL_OP_NO_SSLv2;
71       break;
72     default:
73       // do nothing
74       break;
75   }
76   int newOpt = SSL_CTX_set_options(ctx_, opt);
77   DCHECK((newOpt & opt) == opt);
78
79   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
80
81   checkPeerName_ = false;
82
83   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
84
85 #if FOLLY_OPENSSL_HAS_SNI
86   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
87   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
88 #endif
89 }
90
91 SSLContext::~SSLContext() {
92   if (ctx_ != nullptr) {
93     SSL_CTX_free(ctx_);
94     ctx_ = nullptr;
95   }
96
97 #ifdef OPENSSL_NPN_NEGOTIATED
98   deleteNextProtocolsStrings();
99 #endif
100 }
101
102 void SSLContext::ciphers(const std::string& ciphers) {
103   providedCiphersString_ = ciphers;
104   setCiphersOrThrow(ciphers);
105 }
106
107 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
108   if (ciphers.size() == 0) {
109     return;
110   }
111   std::string opensslCipherList;
112   join(":", ciphers, opensslCipherList);
113   setCiphersOrThrow(opensslCipherList);
114 }
115
116 void SSLContext::setSignatureAlgorithms(
117     const std::vector<std::string>& sigalgs) {
118   if (sigalgs.size() == 0) {
119     return;
120   }
121 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
122   std::string opensslSigAlgsList;
123   join(":", sigalgs, opensslSigAlgsList);
124   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
125   if (rc == 0) {
126     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
127   }
128 #endif
129 }
130
131 void SSLContext::setClientECCurvesList(
132     const std::vector<std::string>& ecCurves) {
133   if (ecCurves.size() == 0) {
134     return;
135   }
136 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
137   std::string ecCurvesList;
138   join(":", ecCurves, ecCurvesList);
139   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
140   if (rc == 0) {
141     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
142   }
143 #endif
144 }
145
146 void SSLContext::setServerECCurve(const std::string& curveName) {
147 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
148   EC_KEY* ecdh = nullptr;
149   int nid;
150
151   /*
152    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
153    * from RFC 4492 section 5.1.1, or explicitly described curves over
154    * binary fields. OpenSSL only supports the "named curves", which provide
155    * maximum interoperability.
156    */
157
158   nid = OBJ_sn2nid(curveName.c_str());
159   if (nid == 0) {
160     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
161   }
162   ecdh = EC_KEY_new_by_curve_name(nid);
163   if (ecdh == nullptr) {
164     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
165   }
166
167   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
168   EC_KEY_free(ecdh);
169 #else
170   throw std::runtime_error("Elliptic curve encryption not allowed");
171 #endif
172 }
173
174 void SSLContext::setX509VerifyParam(
175     const ssl::X509VerifyParam& x509VerifyParam) {
176   if (!x509VerifyParam) {
177     return;
178   }
179   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
180     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
181   }
182 }
183
184 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
185   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
186   if (rc == 0) {
187     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
188   }
189 }
190
191 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
192     verifyPeer) {
193   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
194   verifyPeer_ = verifyPeer;
195 }
196
197 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
198     verifyPeer) {
199   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
200   int mode = SSL_VERIFY_NONE;
201   switch(verifyPeer) {
202     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
203     // break;
204
205     case SSLVerifyPeerEnum::VERIFY:
206       mode = SSL_VERIFY_PEER;
207       break;
208
209     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
210       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
211       break;
212
213     case SSLVerifyPeerEnum::NO_VERIFY:
214       mode = SSL_VERIFY_NONE;
215       break;
216
217     default:
218       break;
219   }
220   return mode;
221 }
222
223 int SSLContext::getVerificationMode() {
224   return getVerificationMode(verifyPeer_);
225 }
226
227 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
228                               const std::string& peerName) {
229   int mode;
230   if (checkPeerCert) {
231     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
232     checkPeerName_ = checkPeerName;
233     peerFixedName_ = peerName;
234   } else {
235     mode = SSL_VERIFY_NONE;
236     checkPeerName_ = false; // can't check name without cert!
237     peerFixedName_.clear();
238   }
239   SSL_CTX_set_verify(ctx_, mode, nullptr);
240 }
241
242 void SSLContext::loadCertificate(const char* path, const char* format) {
243   if (path == nullptr || format == nullptr) {
244     throw std::invalid_argument(
245          "loadCertificateChain: either <path> or <format> is nullptr");
246   }
247   if (strcmp(format, "PEM") == 0) {
248     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
249       int errnoCopy = errno;
250       std::string reason("SSL_CTX_use_certificate_chain_file: ");
251       reason.append(path);
252       reason.append(": ");
253       reason.append(getErrors(errnoCopy));
254       throw std::runtime_error(reason);
255     }
256   } else {
257     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
258   }
259 }
260
261 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
262   if (cert.data() == nullptr) {
263     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
264   }
265
266   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
267   if (bio == nullptr) {
268     throw std::runtime_error("BIO_new: " + getErrors());
269   }
270
271   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
272   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
273     throw std::runtime_error("BIO_write: " + getErrors());
274   }
275
276   ssl::X509UniquePtr x509(
277       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
278   if (x509 == nullptr) {
279     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
280   }
281
282   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
283     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
284   }
285 }
286
287 void SSLContext::loadPrivateKey(const char* path, const char* format) {
288   if (path == nullptr || format == nullptr) {
289     throw std::invalid_argument(
290         "loadPrivateKey: either <path> or <format> is nullptr");
291   }
292   if (strcmp(format, "PEM") == 0) {
293     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
294       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
295     }
296   } else {
297     throw std::runtime_error("Unsupported private key format: " + std::string(format));
298   }
299 }
300
301 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
302   if (pkey.data() == nullptr) {
303     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
304   }
305
306   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
307   if (bio == nullptr) {
308     throw std::runtime_error("BIO_new: " + getErrors());
309   }
310
311   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
312   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
313     throw std::runtime_error("BIO_write: " + getErrors());
314   }
315
316   ssl::EvpPkeyUniquePtr key(
317       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
318   if (key == nullptr) {
319     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
320   }
321
322   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
323     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
324   }
325 }
326
327 void SSLContext::loadTrustedCertificates(const char* path) {
328   if (path == nullptr) {
329     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
330   }
331   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
332     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
333   }
334   ERR_clear_error();
335 }
336
337 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
338   SSL_CTX_set_cert_store(ctx_, store);
339 }
340
341 void SSLContext::loadClientCAList(const char* path) {
342   auto clientCAs = SSL_load_client_CA_file(path);
343   if (clientCAs == nullptr) {
344     LOG(ERROR) << "Unable to load ca file: " << path;
345     return;
346   }
347   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
348 }
349
350 void SSLContext::randomize() {
351   RAND_poll();
352 }
353
354 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
355   if (collector == nullptr) {
356     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
357     return;
358   }
359   collector_ = collector;
360   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
361   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
362 }
363
364 #if FOLLY_OPENSSL_HAS_SNI
365
366 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
367   serverNameCb_ = cb;
368 }
369
370 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
371   clientHelloCbs_.push_back(cb);
372 }
373
374 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
375   SSLContext* context = (SSLContext*)data;
376
377   if (context == nullptr) {
378     return SSL_TLSEXT_ERR_NOACK;
379   }
380
381   for (auto& cb : context->clientHelloCbs_) {
382     // Generic callbacks to happen after we receive the Client Hello.
383     // For example, we use one to switch which cipher we use depending
384     // on the user's TLS version.  Because the primary purpose of
385     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
386     // are side-uses, we ignore any possible failures other than just logging
387     // them.
388     cb(ssl);
389   }
390
391   if (!context->serverNameCb_) {
392     return SSL_TLSEXT_ERR_NOACK;
393   }
394
395   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
396   switch (ret) {
397     case SERVER_NAME_FOUND:
398       return SSL_TLSEXT_ERR_OK;
399     case SERVER_NAME_NOT_FOUND:
400       return SSL_TLSEXT_ERR_NOACK;
401     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
402       *al = TLS1_AD_UNRECOGNIZED_NAME;
403       return SSL_TLSEXT_ERR_ALERT_FATAL;
404     default:
405       CHECK(false);
406   }
407
408   return SSL_TLSEXT_ERR_NOACK;
409 }
410
411 void SSLContext::switchCiphersIfTLS11(
412     SSL* ssl,
413     const std::string& tls11CipherString,
414     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
415   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
416       << "Shouldn't call if empty ciphers / alt ciphers";
417
418   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
419     // We only do this for TLS v 1.1 and later
420     return;
421   }
422
423   const std::string* ciphers = &tls11CipherString;
424   if (!tls11AltCipherlist.empty()) {
425     if (!cipherListPicker_) {
426       std::vector<int> weights;
427       std::for_each(
428           tls11AltCipherlist.begin(),
429           tls11AltCipherlist.end(),
430           [&](const std::pair<std::string, int>& e) {
431             weights.push_back(e.second);
432           });
433       cipherListPicker_.reset(
434           new std::discrete_distribution<int>(weights.begin(), weights.end()));
435     }
436     auto rng = ThreadLocalPRNG();
437     auto index = (*cipherListPicker_)(rng);
438     if ((size_t)index >= tls11AltCipherlist.size()) {
439       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
440                  << ", but tls11AltCipherlist is of length "
441                  << tls11AltCipherlist.size();
442     } else {
443       ciphers = &tls11AltCipherlist[size_t(index)].first;
444     }
445   }
446
447   // Prefer AES for TLS versions 1.1 and later since these are not
448   // vulnerable to BEAST attacks on AES.  Note that we're setting the
449   // cipher list on the SSL object, not the SSL_CTX object, so it will
450   // only last for this request.
451   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
452   if ((rc == 0) || ERR_peek_error() != 0) {
453     // This shouldn't happen since we checked for this when proxygen
454     // started up.
455     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
456     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
457   }
458 }
459 #endif // FOLLY_OPENSSL_HAS_SNI
460
461 #if FOLLY_OPENSSL_HAS_ALPN
462 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
463                                    const unsigned char** out,
464                                    unsigned char* outlen,
465                                    const unsigned char* in,
466                                    unsigned int inlen,
467                                    void* data) {
468   SSLContext* context = (SSLContext*)data;
469   CHECK(context);
470   if (context->advertisedNextProtocols_.empty()) {
471     *out = nullptr;
472     *outlen = 0;
473   } else {
474     auto i = context->pickNextProtocols();
475     const auto& item = context->advertisedNextProtocols_[i];
476     if (SSL_select_next_proto((unsigned char**)out,
477                               outlen,
478                               item.protocols,
479                               item.length,
480                               in,
481                               inlen) != OPENSSL_NPN_NEGOTIATED) {
482       return SSL_TLSEXT_ERR_NOACK;
483     }
484   }
485   return SSL_TLSEXT_ERR_OK;
486 }
487 #endif // FOLLY_OPENSSL_HAS_ALPN
488
489 #ifdef OPENSSL_NPN_NEGOTIATED
490
491 bool SSLContext::setAdvertisedNextProtocols(
492     const std::list<std::string>& protocols, NextProtocolType protocolType) {
493   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
494 }
495
496 bool SSLContext::setRandomizedAdvertisedNextProtocols(
497     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
498   unsetNextProtocols();
499   if (items.size() == 0) {
500     return false;
501   }
502   int total_weight = 0;
503   for (const auto &item : items) {
504     if (item.protocols.size() == 0) {
505       continue;
506     }
507     AdvertisedNextProtocolsItem advertised_item;
508     advertised_item.length = 0;
509     for (const auto& proto : item.protocols) {
510       ++advertised_item.length;
511       auto protoLength = proto.length();
512       if (protoLength >= 256) {
513         deleteNextProtocolsStrings();
514         return false;
515       }
516       advertised_item.length += unsigned(protoLength);
517     }
518     advertised_item.protocols = new unsigned char[advertised_item.length];
519     if (!advertised_item.protocols) {
520       throw std::runtime_error("alloc failure");
521     }
522     unsigned char* dst = advertised_item.protocols;
523     for (auto& proto : item.protocols) {
524       uint8_t protoLength = uint8_t(proto.length());
525       *dst++ = (unsigned char)protoLength;
526       memcpy(dst, proto.data(), protoLength);
527       dst += protoLength;
528     }
529     total_weight += item.weight;
530     advertisedNextProtocols_.push_back(advertised_item);
531     advertisedNextProtocolWeights_.push_back(item.weight);
532   }
533   if (total_weight == 0) {
534     deleteNextProtocolsStrings();
535     return false;
536   }
537   nextProtocolDistribution_ =
538       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
539                                    advertisedNextProtocolWeights_.end());
540   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
541     SSL_CTX_set_next_protos_advertised_cb(
542         ctx_, advertisedNextProtocolCallback, this);
543     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
544   }
545 #if FOLLY_OPENSSL_HAS_ALPN
546   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
547     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
548     // Client cannot really use randomized alpn
549     SSL_CTX_set_alpn_protos(ctx_,
550                             advertisedNextProtocols_[0].protocols,
551                             advertisedNextProtocols_[0].length);
552   }
553 #endif
554   return true;
555 }
556
557 void SSLContext::deleteNextProtocolsStrings() {
558   for (auto protocols : advertisedNextProtocols_) {
559     delete[] protocols.protocols;
560   }
561   advertisedNextProtocols_.clear();
562   advertisedNextProtocolWeights_.clear();
563 }
564
565 void SSLContext::unsetNextProtocols() {
566   deleteNextProtocolsStrings();
567   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
568   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
569 #if FOLLY_OPENSSL_HAS_ALPN
570   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
571   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
572 #endif
573 }
574
575 size_t SSLContext::pickNextProtocols() {
576   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
577   auto rng = ThreadLocalPRNG();
578   return size_t(nextProtocolDistribution_(rng));
579 }
580
581 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
582       const unsigned char** out, unsigned int* outlen, void* data) {
583   SSLContext* context = (SSLContext*)data;
584   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
585     *out = nullptr;
586     *outlen = 0;
587   } else if (context->advertisedNextProtocols_.size() == 1) {
588     *out = context->advertisedNextProtocols_[0].protocols;
589     *outlen = context->advertisedNextProtocols_[0].length;
590   } else {
591     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
592           sNextProtocolsExDataIndex_));
593     if (selected_index) {
594       --selected_index;
595       *out = context->advertisedNextProtocols_[selected_index].protocols;
596       *outlen = context->advertisedNextProtocols_[selected_index].length;
597     } else {
598       auto i = context->pickNextProtocols();
599       uintptr_t selected = i + 1;
600       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
601       *out = context->advertisedNextProtocols_[i].protocols;
602       *outlen = context->advertisedNextProtocols_[i].length;
603     }
604   }
605   return SSL_TLSEXT_ERR_OK;
606 }
607
608 int SSLContext::selectNextProtocolCallback(SSL* ssl,
609                                            unsigned char** out,
610                                            unsigned char* outlen,
611                                            const unsigned char* server,
612                                            unsigned int server_len,
613                                            void* data) {
614   (void)ssl; // Make -Wunused-parameters happy
615   SSLContext* ctx = (SSLContext*)data;
616   if (ctx->advertisedNextProtocols_.size() > 1) {
617     VLOG(3) << "SSLContext::selectNextProcolCallback() "
618             << "client should be deterministic in selecting protocols.";
619   }
620
621   unsigned char* client = nullptr;
622   unsigned int client_len = 0;
623   bool filtered = false;
624   auto cpf = ctx->getClientProtocolFilterCallback();
625   if (cpf) {
626     filtered = (*cpf)(&client, &client_len, server, server_len);
627   }
628
629   if (!filtered) {
630     if (ctx->advertisedNextProtocols_.empty()) {
631       client = (unsigned char *) "";
632       client_len = 0;
633     } else {
634       client = ctx->advertisedNextProtocols_[0].protocols;
635       client_len = ctx->advertisedNextProtocols_[0].length;
636     }
637   }
638
639   int retval = SSL_select_next_proto(out, outlen, server, server_len,
640                                      client, client_len);
641   if (retval != OPENSSL_NPN_NEGOTIATED) {
642     VLOG(3) << "SSLContext::selectNextProcolCallback() "
643             << "unable to pick a next protocol.";
644   }
645   return SSL_TLSEXT_ERR_OK;
646 }
647 #endif // OPENSSL_NPN_NEGOTIATED
648
649 SSL* SSLContext::createSSL() const {
650   SSL* ssl = SSL_new(ctx_);
651   if (ssl == nullptr) {
652     throw std::runtime_error("SSL_new: " + getErrors());
653   }
654   return ssl;
655 }
656
657 void SSLContext::setSessionCacheContext(const std::string& context) {
658   SSL_CTX_set_session_id_context(
659       ctx_,
660       reinterpret_cast<const unsigned char*>(context.data()),
661       std::min<unsigned int>(
662           static_cast<unsigned int>(context.length()),
663           SSL_MAX_SSL_SESSION_ID_LENGTH));
664 }
665
666 /**
667  * Match a name with a pattern. The pattern may include wildcard. A single
668  * wildcard "*" can match up to one component in the domain name.
669  *
670  * @param  host    Host name, typically the name of the remote host
671  * @param  pattern Name retrieved from certificate
672  * @param  size    Size of "pattern"
673  * @return True, if "host" matches "pattern". False otherwise.
674  */
675 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
676   bool match = false;
677   int i = 0, j = 0;
678   while (i < size && host[j] != '\0') {
679     if (toupper(pattern[i]) == toupper(host[j])) {
680       i++;
681       j++;
682       continue;
683     }
684     if (pattern[i] == '*') {
685       while (host[j] != '.' && host[j] != '\0') {
686         j++;
687       }
688       i++;
689       continue;
690     }
691     break;
692   }
693   if (i == size && host[j] == '\0') {
694     match = true;
695   }
696   return match;
697 }
698
699 int SSLContext::passwordCallback(char* password,
700                                  int size,
701                                  int,
702                                  void* data) {
703   SSLContext* context = (SSLContext*)data;
704   if (context == nullptr || context->passwordCollector() == nullptr) {
705     return 0;
706   }
707   std::string userPassword;
708   // call user defined password collector to get password
709   context->passwordCollector()->getPassword(userPassword, size);
710   auto length = int(userPassword.size());
711   if (length > size) {
712     length = size;
713   }
714   strncpy(password, userPassword.c_str(), size_t(length));
715   return length;
716 }
717
718 struct SSLLock {
719   explicit SSLLock(
720     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
721       lockType(inLockType) {
722   }
723
724   void lock() {
725     if (lockType == SSLContext::LOCK_MUTEX) {
726       mutex.lock();
727     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
728       spinLock.lock();
729     }
730     // lockType == LOCK_NONE, no-op
731   }
732
733   void unlock() {
734     if (lockType == SSLContext::LOCK_MUTEX) {
735       mutex.unlock();
736     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
737       spinLock.unlock();
738     }
739     // lockType == LOCK_NONE, no-op
740   }
741
742   SSLContext::SSLLockType lockType;
743   folly::SpinLock spinLock{};
744   std::mutex mutex;
745 };
746
747 // Statics are unsafe in environments that call exit().
748 // If one thread calls exit() while another thread is
749 // references a member of SSLContext, bad things can happen.
750 // SSLContext runs in such environments.
751 // Instead of declaring a static member we "new" the static
752 // member so that it won't be destructed on exit().
753 static std::unique_ptr<SSLLock[]>& locks() {
754   static auto locksInst = new std::unique_ptr<SSLLock[]>();
755   return *locksInst;
756 }
757
758 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
759   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
760   return *lockTypesInst;
761 }
762
763 static void callbackLocking(int mode, int n, const char*, int) {
764   if (mode & CRYPTO_LOCK) {
765     locks()[size_t(n)].lock();
766   } else {
767     locks()[size_t(n)].unlock();
768   }
769 }
770
771 static unsigned long callbackThreadID() {
772   return static_cast<unsigned long>(
773 #ifdef __APPLE__
774     pthread_mach_thread_np(pthread_self())
775 #elif _MSC_VER
776     pthread_getw32threadid_np(pthread_self())
777 #else
778     pthread_self()
779 #endif
780   );
781 }
782
783 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
784   return new CRYPTO_dynlock_value;
785 }
786
787 static void dyn_lock(int mode,
788                      struct CRYPTO_dynlock_value* lock,
789                      const char*, int) {
790   if (lock != nullptr) {
791     if (mode & CRYPTO_LOCK) {
792       lock->mutex.lock();
793     } else {
794       lock->mutex.unlock();
795     }
796   }
797 }
798
799 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
800   delete lock;
801 }
802
803 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
804   lockTypes() = inLockTypes;
805 }
806
807 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
808 void SSLContext::enableFalseStart() {
809   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
810 }
811 #endif
812
813 void SSLContext::markInitialized() {
814   std::lock_guard<std::mutex> g(initMutex());
815   initialized_ = true;
816 }
817
818 void SSLContext::initializeOpenSSL() {
819   std::lock_guard<std::mutex> g(initMutex());
820   initializeOpenSSLLocked();
821 }
822
823 void SSLContext::initializeOpenSSLLocked() {
824   if (initialized_) {
825     return;
826   }
827   SSL_library_init();
828   SSL_load_error_strings();
829   ERR_load_crypto_strings();
830   // static locking
831   locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
832   for (auto it: lockTypes()) {
833     locks()[size_t(it.first)].lockType = it.second;
834   }
835   CRYPTO_set_id_callback(callbackThreadID);
836   CRYPTO_set_locking_callback(callbackLocking);
837   // dynamic locking
838   CRYPTO_set_dynlock_create_callback(dyn_create);
839   CRYPTO_set_dynlock_lock_callback(dyn_lock);
840   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
841   randomize();
842 #ifdef OPENSSL_NPN_NEGOTIATED
843   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
844       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
845 #endif
846   initialized_ = true;
847 }
848
849 void SSLContext::cleanupOpenSSL() {
850   std::lock_guard<std::mutex> g(initMutex());
851   cleanupOpenSSLLocked();
852 }
853
854 void SSLContext::cleanupOpenSSLLocked() {
855   if (!initialized_) {
856     return;
857   }
858
859   CRYPTO_set_id_callback(nullptr);
860   CRYPTO_set_locking_callback(nullptr);
861   CRYPTO_set_dynlock_create_callback(nullptr);
862   CRYPTO_set_dynlock_lock_callback(nullptr);
863   CRYPTO_set_dynlock_destroy_callback(nullptr);
864   CRYPTO_cleanup_all_ex_data();
865   ERR_free_strings();
866   EVP_cleanup();
867   ERR_clear_error();
868   locks().reset();
869   initialized_ = false;
870 }
871
872 void SSLContext::setOptions(long options) {
873   long newOpt = SSL_CTX_set_options(ctx_, options);
874   if ((newOpt & options) != options) {
875     throw std::runtime_error("SSL_CTX_set_options failed");
876   }
877 }
878
879 std::string SSLContext::getErrors(int errnoCopy) {
880   std::string errors;
881   unsigned long  errorCode;
882   char   message[256];
883
884   errors.reserve(512);
885   while ((errorCode = ERR_get_error()) != 0) {
886     if (!errors.empty()) {
887       errors += "; ";
888     }
889     const char* reason = ERR_reason_error_string(errorCode);
890     if (reason == nullptr) {
891       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
892       reason = message;
893     }
894     errors += reason;
895   }
896   if (errors.empty()) {
897     errors = "error code: " + folly::to<std::string>(errnoCopy);
898   }
899   return errors;
900 }
901
902 std::ostream&
903 operator<<(std::ostream& os, const PasswordCollector& collector) {
904   os << collector.describe();
905   return os;
906 }
907
908 } // folly