Create the pthread.h portability header
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/Random.h>
27 #include <folly/SpinLock.h>
28
29 // ---------------------------------------------------------------------
30 // SSLContext implementation
31 // ---------------------------------------------------------------------
32
33 struct CRYPTO_dynlock_value {
34   std::mutex mutex;
35 };
36
37 namespace folly {
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 } // anonymous namespace
49
50 #ifdef OPENSSL_NPN_NEGOTIATED
51 int SSLContext::sNextProtocolsExDataIndex_ = -1;
52 #endif
53
54 // SSLContext implementation
55 SSLContext::SSLContext(SSLVersion version) {
56   {
57     std::lock_guard<std::mutex> g(initMutex());
58     initializeOpenSSLLocked();
59   }
60
61   ctx_ = SSL_CTX_new(SSLv23_method());
62   if (ctx_ == nullptr) {
63     throw std::runtime_error("SSL_CTX_new: " + getErrors());
64   }
65
66   int opt = 0;
67   switch (version) {
68     case TLSv1:
69       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
70       break;
71     case SSLv3:
72       opt = SSL_OP_NO_SSLv2;
73       break;
74     default:
75       // do nothing
76       break;
77   }
78   int newOpt = SSL_CTX_set_options(ctx_, opt);
79   DCHECK((newOpt & opt) == opt);
80
81   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82
83   checkPeerName_ = false;
84
85 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
86   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
87   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
88 #endif
89 }
90
91 SSLContext::~SSLContext() {
92   if (ctx_ != nullptr) {
93     SSL_CTX_free(ctx_);
94     ctx_ = nullptr;
95   }
96
97 #ifdef OPENSSL_NPN_NEGOTIATED
98   deleteNextProtocolsStrings();
99 #endif
100 }
101
102 void SSLContext::ciphers(const std::string& ciphers) {
103   providedCiphersString_ = ciphers;
104   setCiphersOrThrow(ciphers);
105 }
106
107 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
108   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
109   if (ERR_peek_error() != 0) {
110     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
111   }
112   if (rc == 0) {
113     throw std::runtime_error("None of specified ciphers are supported");
114   }
115 }
116
117 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
118     verifyPeer) {
119   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
120   verifyPeer_ = verifyPeer;
121 }
122
123 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
124     verifyPeer) {
125   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
126   int mode = SSL_VERIFY_NONE;
127   switch(verifyPeer) {
128     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
129     // break;
130
131     case SSLVerifyPeerEnum::VERIFY:
132       mode = SSL_VERIFY_PEER;
133       break;
134
135     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
136       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
137       break;
138
139     case SSLVerifyPeerEnum::NO_VERIFY:
140       mode = SSL_VERIFY_NONE;
141       break;
142
143     default:
144       break;
145   }
146   return mode;
147 }
148
149 int SSLContext::getVerificationMode() {
150   return getVerificationMode(verifyPeer_);
151 }
152
153 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
154                               const std::string& peerName) {
155   int mode;
156   if (checkPeerCert) {
157     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
158     checkPeerName_ = checkPeerName;
159     peerFixedName_ = peerName;
160   } else {
161     mode = SSL_VERIFY_NONE;
162     checkPeerName_ = false; // can't check name without cert!
163     peerFixedName_.clear();
164   }
165   SSL_CTX_set_verify(ctx_, mode, nullptr);
166 }
167
168 void SSLContext::loadCertificate(const char* path, const char* format) {
169   if (path == nullptr || format == nullptr) {
170     throw std::invalid_argument(
171          "loadCertificateChain: either <path> or <format> is nullptr");
172   }
173   if (strcmp(format, "PEM") == 0) {
174     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
175       int errnoCopy = errno;
176       std::string reason("SSL_CTX_use_certificate_chain_file: ");
177       reason.append(path);
178       reason.append(": ");
179       reason.append(getErrors(errnoCopy));
180       throw std::runtime_error(reason);
181     }
182   } else {
183     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
184   }
185 }
186
187 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
188   if (cert.data() == nullptr) {
189     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
190   }
191
192   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
193   if (bio == nullptr) {
194     throw std::runtime_error("BIO_new: " + getErrors());
195   }
196
197   int written = BIO_write(bio.get(), cert.data(), cert.size());
198   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
199     throw std::runtime_error("BIO_write: " + getErrors());
200   }
201
202   ssl::X509UniquePtr x509(
203       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
204   if (x509 == nullptr) {
205     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
206   }
207
208   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
209     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
210   }
211 }
212
213 void SSLContext::loadPrivateKey(const char* path, const char* format) {
214   if (path == nullptr || format == nullptr) {
215     throw std::invalid_argument(
216         "loadPrivateKey: either <path> or <format> is nullptr");
217   }
218   if (strcmp(format, "PEM") == 0) {
219     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
220       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
221     }
222   } else {
223     throw std::runtime_error("Unsupported private key format: " + std::string(format));
224   }
225 }
226
227 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
228   if (pkey.data() == nullptr) {
229     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
230   }
231
232   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
233   if (bio == nullptr) {
234     throw std::runtime_error("BIO_new: " + getErrors());
235   }
236
237   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
238   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
239     throw std::runtime_error("BIO_write: " + getErrors());
240   }
241
242   ssl::EvpPkeyUniquePtr key(
243       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
244   if (key == nullptr) {
245     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
246   }
247
248   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
249     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
250   }
251 }
252
253 void SSLContext::loadTrustedCertificates(const char* path) {
254   if (path == nullptr) {
255     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
256   }
257   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
258     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
259   }
260 }
261
262 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
263   SSL_CTX_set_cert_store(ctx_, store);
264 }
265
266 void SSLContext::loadClientCAList(const char* path) {
267   auto clientCAs = SSL_load_client_CA_file(path);
268   if (clientCAs == nullptr) {
269     LOG(ERROR) << "Unable to load ca file: " << path;
270     return;
271   }
272   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
273 }
274
275 void SSLContext::randomize() {
276   RAND_poll();
277 }
278
279 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
280   if (collector == nullptr) {
281     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
282     return;
283   }
284   collector_ = collector;
285   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
286   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
287 }
288
289 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
290
291 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
292   serverNameCb_ = cb;
293 }
294
295 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
296   clientHelloCbs_.push_back(cb);
297 }
298
299 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
300   SSLContext* context = (SSLContext*)data;
301
302   if (context == nullptr) {
303     return SSL_TLSEXT_ERR_NOACK;
304   }
305
306   for (auto& cb : context->clientHelloCbs_) {
307     // Generic callbacks to happen after we receive the Client Hello.
308     // For example, we use one to switch which cipher we use depending
309     // on the user's TLS version.  Because the primary purpose of
310     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
311     // are side-uses, we ignore any possible failures other than just logging
312     // them.
313     cb(ssl);
314   }
315
316   if (!context->serverNameCb_) {
317     return SSL_TLSEXT_ERR_NOACK;
318   }
319
320   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
321   switch (ret) {
322     case SERVER_NAME_FOUND:
323       return SSL_TLSEXT_ERR_OK;
324     case SERVER_NAME_NOT_FOUND:
325       return SSL_TLSEXT_ERR_NOACK;
326     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
327       *al = TLS1_AD_UNRECOGNIZED_NAME;
328       return SSL_TLSEXT_ERR_ALERT_FATAL;
329     default:
330       CHECK(false);
331   }
332
333   return SSL_TLSEXT_ERR_NOACK;
334 }
335
336 void SSLContext::switchCiphersIfTLS11(
337     SSL* ssl,
338     const std::string& tls11CipherString,
339     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
340   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
341       << "Shouldn't call if empty ciphers / alt ciphers";
342
343   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
344     // We only do this for TLS v 1.1 and later
345     return;
346   }
347
348   const std::string* ciphers = &tls11CipherString;
349   if (!tls11AltCipherlist.empty()) {
350     if (!cipherListPicker_) {
351       std::vector<int> weights;
352       std::for_each(
353           tls11AltCipherlist.begin(),
354           tls11AltCipherlist.end(),
355           [&](const std::pair<std::string, int>& e) {
356             weights.push_back(e.second);
357           });
358       cipherListPicker_.reset(
359           new std::discrete_distribution<int>(weights.begin(), weights.end()));
360     }
361     auto rng = ThreadLocalPRNG();
362     auto index = (*cipherListPicker_)(rng);
363     if ((size_t)index >= tls11AltCipherlist.size()) {
364       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
365                  << ", but tls11AltCipherlist is of length "
366                  << tls11AltCipherlist.size();
367     } else {
368       ciphers = &tls11AltCipherlist[index].first;
369     }
370   }
371
372   // Prefer AES for TLS versions 1.1 and later since these are not
373   // vulnerable to BEAST attacks on AES.  Note that we're setting the
374   // cipher list on the SSL object, not the SSL_CTX object, so it will
375   // only last for this request.
376   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
377   if ((rc == 0) || ERR_peek_error() != 0) {
378     // This shouldn't happen since we checked for this when proxygen
379     // started up.
380     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
381     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
382   }
383 }
384 #endif
385
386 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
387 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
388                                    const unsigned char** out,
389                                    unsigned char* outlen,
390                                    const unsigned char* in,
391                                    unsigned int inlen,
392                                    void* data) {
393   SSLContext* context = (SSLContext*)data;
394   CHECK(context);
395   if (context->advertisedNextProtocols_.empty()) {
396     *out = nullptr;
397     *outlen = 0;
398   } else {
399     auto i = context->pickNextProtocols();
400     const auto& item = context->advertisedNextProtocols_[i];
401     if (SSL_select_next_proto((unsigned char**)out,
402                               outlen,
403                               item.protocols,
404                               item.length,
405                               in,
406                               inlen) != OPENSSL_NPN_NEGOTIATED) {
407       return SSL_TLSEXT_ERR_NOACK;
408     }
409   }
410   return SSL_TLSEXT_ERR_OK;
411 }
412 #endif
413
414 #ifdef OPENSSL_NPN_NEGOTIATED
415
416 bool SSLContext::setAdvertisedNextProtocols(
417     const std::list<std::string>& protocols, NextProtocolType protocolType) {
418   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
419 }
420
421 bool SSLContext::setRandomizedAdvertisedNextProtocols(
422     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
423   unsetNextProtocols();
424   if (items.size() == 0) {
425     return false;
426   }
427   int total_weight = 0;
428   for (const auto &item : items) {
429     if (item.protocols.size() == 0) {
430       continue;
431     }
432     AdvertisedNextProtocolsItem advertised_item;
433     advertised_item.length = 0;
434     for (const auto& proto : item.protocols) {
435       ++advertised_item.length;
436       unsigned protoLength = proto.length();
437       if (protoLength >= 256) {
438         deleteNextProtocolsStrings();
439         return false;
440       }
441       advertised_item.length += protoLength;
442     }
443     advertised_item.protocols = new unsigned char[advertised_item.length];
444     if (!advertised_item.protocols) {
445       throw std::runtime_error("alloc failure");
446     }
447     unsigned char* dst = advertised_item.protocols;
448     for (auto& proto : item.protocols) {
449       unsigned protoLength = proto.length();
450       *dst++ = (unsigned char)protoLength;
451       memcpy(dst, proto.data(), protoLength);
452       dst += protoLength;
453     }
454     total_weight += item.weight;
455     advertisedNextProtocols_.push_back(advertised_item);
456     advertisedNextProtocolWeights_.push_back(item.weight);
457   }
458   if (total_weight == 0) {
459     deleteNextProtocolsStrings();
460     return false;
461   }
462   nextProtocolDistribution_ =
463       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
464                                    advertisedNextProtocolWeights_.end());
465   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
466     SSL_CTX_set_next_protos_advertised_cb(
467         ctx_, advertisedNextProtocolCallback, this);
468     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
469   }
470 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
471   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
472     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
473     // Client cannot really use randomized alpn
474     SSL_CTX_set_alpn_protos(ctx_,
475                             advertisedNextProtocols_[0].protocols,
476                             advertisedNextProtocols_[0].length);
477   }
478 #endif
479   return true;
480 }
481
482 void SSLContext::deleteNextProtocolsStrings() {
483   for (auto protocols : advertisedNextProtocols_) {
484     delete[] protocols.protocols;
485   }
486   advertisedNextProtocols_.clear();
487   advertisedNextProtocolWeights_.clear();
488 }
489
490 void SSLContext::unsetNextProtocols() {
491   deleteNextProtocolsStrings();
492   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
493   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
494 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
495   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
496   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
497 #endif
498 }
499
500 size_t SSLContext::pickNextProtocols() {
501   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
502   auto rng = ThreadLocalPRNG();
503   return nextProtocolDistribution_(rng);
504 }
505
506 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
507       const unsigned char** out, unsigned int* outlen, void* data) {
508   SSLContext* context = (SSLContext*)data;
509   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
510     *out = nullptr;
511     *outlen = 0;
512   } else if (context->advertisedNextProtocols_.size() == 1) {
513     *out = context->advertisedNextProtocols_[0].protocols;
514     *outlen = context->advertisedNextProtocols_[0].length;
515   } else {
516     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
517           sNextProtocolsExDataIndex_));
518     if (selected_index) {
519       --selected_index;
520       *out = context->advertisedNextProtocols_[selected_index].protocols;
521       *outlen = context->advertisedNextProtocols_[selected_index].length;
522     } else {
523       auto i = context->pickNextProtocols();
524       uintptr_t selected = i + 1;
525       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
526       *out = context->advertisedNextProtocols_[i].protocols;
527       *outlen = context->advertisedNextProtocols_[i].length;
528     }
529   }
530   return SSL_TLSEXT_ERR_OK;
531 }
532
533 int SSLContext::selectNextProtocolCallback(SSL* ssl,
534                                            unsigned char** out,
535                                            unsigned char* outlen,
536                                            const unsigned char* server,
537                                            unsigned int server_len,
538                                            void* data) {
539   (void)ssl; // Make -Wunused-parameters happy
540   SSLContext* ctx = (SSLContext*)data;
541   if (ctx->advertisedNextProtocols_.size() > 1) {
542     VLOG(3) << "SSLContext::selectNextProcolCallback() "
543             << "client should be deterministic in selecting protocols.";
544   }
545
546   unsigned char *client;
547   unsigned int client_len;
548   bool filtered = false;
549   auto cpf = ctx->getClientProtocolFilterCallback();
550   if (cpf) {
551     filtered = (*cpf)(&client, &client_len, server, server_len);
552   }
553
554   if (!filtered) {
555     if (ctx->advertisedNextProtocols_.empty()) {
556       client = (unsigned char *) "";
557       client_len = 0;
558     } else {
559       client = ctx->advertisedNextProtocols_[0].protocols;
560       client_len = ctx->advertisedNextProtocols_[0].length;
561     }
562   }
563
564   int retval = SSL_select_next_proto(out, outlen, server, server_len,
565                                      client, client_len);
566   if (retval != OPENSSL_NPN_NEGOTIATED) {
567     VLOG(3) << "SSLContext::selectNextProcolCallback() "
568             << "unable to pick a next protocol.";
569   }
570   return SSL_TLSEXT_ERR_OK;
571 }
572 #endif // OPENSSL_NPN_NEGOTIATED
573
574 SSL* SSLContext::createSSL() const {
575   SSL* ssl = SSL_new(ctx_);
576   if (ssl == nullptr) {
577     throw std::runtime_error("SSL_new: " + getErrors());
578   }
579   return ssl;
580 }
581
582 void SSLContext::setSessionCacheContext(const std::string& context) {
583   SSL_CTX_set_session_id_context(
584       ctx_,
585       reinterpret_cast<const unsigned char*>(context.data()),
586       std::min(
587           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
588 }
589
590 /**
591  * Match a name with a pattern. The pattern may include wildcard. A single
592  * wildcard "*" can match up to one component in the domain name.
593  *
594  * @param  host    Host name, typically the name of the remote host
595  * @param  pattern Name retrieved from certificate
596  * @param  size    Size of "pattern"
597  * @return True, if "host" matches "pattern". False otherwise.
598  */
599 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
600   bool match = false;
601   int i = 0, j = 0;
602   while (i < size && host[j] != '\0') {
603     if (toupper(pattern[i]) == toupper(host[j])) {
604       i++;
605       j++;
606       continue;
607     }
608     if (pattern[i] == '*') {
609       while (host[j] != '.' && host[j] != '\0') {
610         j++;
611       }
612       i++;
613       continue;
614     }
615     break;
616   }
617   if (i == size && host[j] == '\0') {
618     match = true;
619   }
620   return match;
621 }
622
623 int SSLContext::passwordCallback(char* password,
624                                  int size,
625                                  int,
626                                  void* data) {
627   SSLContext* context = (SSLContext*)data;
628   if (context == nullptr || context->passwordCollector() == nullptr) {
629     return 0;
630   }
631   std::string userPassword;
632   // call user defined password collector to get password
633   context->passwordCollector()->getPassword(userPassword, size);
634   int length = userPassword.size();
635   if (length > size) {
636     length = size;
637   }
638   strncpy(password, userPassword.c_str(), length);
639   return length;
640 }
641
642 struct SSLLock {
643   explicit SSLLock(
644     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
645       lockType(inLockType) {
646   }
647
648   void lock() {
649     if (lockType == SSLContext::LOCK_MUTEX) {
650       mutex.lock();
651     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
652       spinLock.lock();
653     }
654     // lockType == LOCK_NONE, no-op
655   }
656
657   void unlock() {
658     if (lockType == SSLContext::LOCK_MUTEX) {
659       mutex.unlock();
660     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
661       spinLock.unlock();
662     }
663     // lockType == LOCK_NONE, no-op
664   }
665
666   SSLContext::SSLLockType lockType;
667   folly::SpinLock spinLock{};
668   std::mutex mutex;
669 };
670
671 // Statics are unsafe in environments that call exit().
672 // If one thread calls exit() while another thread is
673 // references a member of SSLContext, bad things can happen.
674 // SSLContext runs in such environments.
675 // Instead of declaring a static member we "new" the static
676 // member so that it won't be destructed on exit().
677 static std::unique_ptr<SSLLock[]>& locks() {
678   static auto locksInst = new std::unique_ptr<SSLLock[]>();
679   return *locksInst;
680 }
681
682 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
683   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
684   return *lockTypesInst;
685 }
686
687 static void callbackLocking(int mode, int n, const char*, int) {
688   if (mode & CRYPTO_LOCK) {
689     locks()[n].lock();
690   } else {
691     locks()[n].unlock();
692   }
693 }
694
695 static unsigned long callbackThreadID() {
696   return static_cast<unsigned long>(
697 #ifdef __APPLE__
698     pthread_mach_thread_np(pthread_self())
699 #elif _MSC_VER
700     pthread_getw32threadid_np(pthread_self())
701 #else
702     pthread_self()
703 #endif
704   );
705 }
706
707 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
708   return new CRYPTO_dynlock_value;
709 }
710
711 static void dyn_lock(int mode,
712                      struct CRYPTO_dynlock_value* lock,
713                      const char*, int) {
714   if (lock != nullptr) {
715     if (mode & CRYPTO_LOCK) {
716       lock->mutex.lock();
717     } else {
718       lock->mutex.unlock();
719     }
720   }
721 }
722
723 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
724   delete lock;
725 }
726
727 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
728   lockTypes() = inLockTypes;
729 }
730
731 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
732 void SSLContext::enableFalseStart() {
733   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
734 }
735 #endif
736
737 void SSLContext::markInitialized() {
738   std::lock_guard<std::mutex> g(initMutex());
739   initialized_ = true;
740 }
741
742 void SSLContext::initializeOpenSSL() {
743   std::lock_guard<std::mutex> g(initMutex());
744   initializeOpenSSLLocked();
745 }
746
747 void SSLContext::initializeOpenSSLLocked() {
748   if (initialized_) {
749     return;
750   }
751   SSL_library_init();
752   SSL_load_error_strings();
753   ERR_load_crypto_strings();
754   // static locking
755   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
756   for (auto it: lockTypes()) {
757     locks()[it.first].lockType = it.second;
758   }
759   CRYPTO_set_id_callback(callbackThreadID);
760   CRYPTO_set_locking_callback(callbackLocking);
761   // dynamic locking
762   CRYPTO_set_dynlock_create_callback(dyn_create);
763   CRYPTO_set_dynlock_lock_callback(dyn_lock);
764   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
765   randomize();
766 #ifdef OPENSSL_NPN_NEGOTIATED
767   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
768       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
769 #endif
770   initialized_ = true;
771 }
772
773 void SSLContext::cleanupOpenSSL() {
774   std::lock_guard<std::mutex> g(initMutex());
775   cleanupOpenSSLLocked();
776 }
777
778 void SSLContext::cleanupOpenSSLLocked() {
779   if (!initialized_) {
780     return;
781   }
782
783   CRYPTO_set_id_callback(nullptr);
784   CRYPTO_set_locking_callback(nullptr);
785   CRYPTO_set_dynlock_create_callback(nullptr);
786   CRYPTO_set_dynlock_lock_callback(nullptr);
787   CRYPTO_set_dynlock_destroy_callback(nullptr);
788   CRYPTO_cleanup_all_ex_data();
789   ERR_free_strings();
790   EVP_cleanup();
791   ERR_remove_state(0);
792   locks().reset();
793   initialized_ = false;
794 }
795
796 void SSLContext::setOptions(long options) {
797   long newOpt = SSL_CTX_set_options(ctx_, options);
798   if ((newOpt & options) != options) {
799     throw std::runtime_error("SSL_CTX_set_options failed");
800   }
801 }
802
803 std::string SSLContext::getErrors(int errnoCopy) {
804   std::string errors;
805   unsigned long  errorCode;
806   char   message[256];
807
808   errors.reserve(512);
809   while ((errorCode = ERR_get_error()) != 0) {
810     if (!errors.empty()) {
811       errors += "; ";
812     }
813     const char* reason = ERR_reason_error_string(errorCode);
814     if (reason == nullptr) {
815       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
816       reason = message;
817     }
818     errors += reason;
819   }
820   if (errors.empty()) {
821     errors = "error code: " + folly::to<std::string>(errnoCopy);
822   }
823   return errors;
824 }
825
826 std::ostream&
827 operator<<(std::ostream& os, const PasswordCollector& collector) {
828   os << collector.describe();
829   return os;
830 }
831
832 } // folly