5232b6816fed7a3093c925160e113feb878d617c
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
27
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
31
32 struct CRYPTO_dynlock_value {
33   std::mutex mutex;
34 };
35
36 namespace folly {
37
38 bool SSLContext::initialized_ = false;
39
40 namespace {
41
42 std::mutex& initMutex() {
43   static std::mutex m;
44   return m;
45 }
46
47 } // anonymous namespace
48
49 #ifdef OPENSSL_NPN_NEGOTIATED
50 int SSLContext::sNextProtocolsExDataIndex_ = -1;
51 #endif
52
53 // SSLContext implementation
54 SSLContext::SSLContext(SSLVersion version) {
55   {
56     std::lock_guard<std::mutex> g(initMutex());
57     initializeOpenSSLLocked();
58   }
59
60   ctx_ = SSL_CTX_new(SSLv23_method());
61   if (ctx_ == nullptr) {
62     throw std::runtime_error("SSL_CTX_new: " + getErrors());
63   }
64
65   int opt = 0;
66   switch (version) {
67     case TLSv1:
68       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
69       break;
70     case SSLv3:
71       opt = SSL_OP_NO_SSLv2;
72       break;
73     default:
74       // do nothing
75       break;
76   }
77   int newOpt = SSL_CTX_set_options(ctx_, opt);
78   DCHECK((newOpt & opt) == opt);
79
80   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
81
82   checkPeerName_ = false;
83
84 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
85   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
86   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
87 #endif
88
89 #ifdef OPENSSL_NPN_NEGOTIATED
90   Random::seed(nextProtocolPicker_);
91 #endif
92 }
93
94 SSLContext::~SSLContext() {
95   if (ctx_ != nullptr) {
96     SSL_CTX_free(ctx_);
97     ctx_ = nullptr;
98   }
99
100 #ifdef OPENSSL_NPN_NEGOTIATED
101   deleteNextProtocolsStrings();
102 #endif
103 }
104
105 void SSLContext::ciphers(const std::string& ciphers) {
106   providedCiphersString_ = ciphers;
107   setCiphersOrThrow(ciphers);
108 }
109
110 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
111   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
112   if (ERR_peek_error() != 0) {
113     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
114   }
115   if (rc == 0) {
116     throw std::runtime_error("None of specified ciphers are supported");
117   }
118 }
119
120 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
121     verifyPeer) {
122   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
123   verifyPeer_ = verifyPeer;
124 }
125
126 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
127     verifyPeer) {
128   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
129   int mode = SSL_VERIFY_NONE;
130   switch(verifyPeer) {
131     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
132     // break;
133
134     case SSLVerifyPeerEnum::VERIFY:
135       mode = SSL_VERIFY_PEER;
136       break;
137
138     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
139       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
140       break;
141
142     case SSLVerifyPeerEnum::NO_VERIFY:
143       mode = SSL_VERIFY_NONE;
144       break;
145
146     default:
147       break;
148   }
149   return mode;
150 }
151
152 int SSLContext::getVerificationMode() {
153   return getVerificationMode(verifyPeer_);
154 }
155
156 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
157                               const std::string& peerName) {
158   int mode;
159   if (checkPeerCert) {
160     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
161     checkPeerName_ = checkPeerName;
162     peerFixedName_ = peerName;
163   } else {
164     mode = SSL_VERIFY_NONE;
165     checkPeerName_ = false; // can't check name without cert!
166     peerFixedName_.clear();
167   }
168   SSL_CTX_set_verify(ctx_, mode, nullptr);
169 }
170
171 void SSLContext::loadCertificate(const char* path, const char* format) {
172   if (path == nullptr || format == nullptr) {
173     throw std::invalid_argument(
174          "loadCertificateChain: either <path> or <format> is nullptr");
175   }
176   if (strcmp(format, "PEM") == 0) {
177     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
178       int errnoCopy = errno;
179       std::string reason("SSL_CTX_use_certificate_chain_file: ");
180       reason.append(path);
181       reason.append(": ");
182       reason.append(getErrors(errnoCopy));
183       throw std::runtime_error(reason);
184     }
185   } else {
186     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
187   }
188 }
189
190 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
191   if (cert.data() == nullptr) {
192     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
193   }
194
195   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
196   if (bio == nullptr) {
197     throw std::runtime_error("BIO_new: " + getErrors());
198   }
199
200   int written = BIO_write(bio.get(), cert.data(), cert.size());
201   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
202     throw std::runtime_error("BIO_write: " + getErrors());
203   }
204
205   ssl::X509UniquePtr x509(
206       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
207   if (x509 == nullptr) {
208     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
209   }
210
211   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
212     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
213   }
214 }
215
216 void SSLContext::loadPrivateKey(const char* path, const char* format) {
217   if (path == nullptr || format == nullptr) {
218     throw std::invalid_argument(
219         "loadPrivateKey: either <path> or <format> is nullptr");
220   }
221   if (strcmp(format, "PEM") == 0) {
222     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
223       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
224     }
225   } else {
226     throw std::runtime_error("Unsupported private key format: " + std::string(format));
227   }
228 }
229
230 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
231   if (pkey.data() == nullptr) {
232     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
233   }
234
235   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
236   if (bio == nullptr) {
237     throw std::runtime_error("BIO_new: " + getErrors());
238   }
239
240   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
241   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
242     throw std::runtime_error("BIO_write: " + getErrors());
243   }
244
245   ssl::EvpPkeyUniquePtr key(
246       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
247   if (key == nullptr) {
248     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
249   }
250
251   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
252     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
253   }
254 }
255
256 void SSLContext::loadTrustedCertificates(const char* path) {
257   if (path == nullptr) {
258     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
259   }
260   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
261     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
262   }
263 }
264
265 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
266   SSL_CTX_set_cert_store(ctx_, store);
267 }
268
269 void SSLContext::loadClientCAList(const char* path) {
270   auto clientCAs = SSL_load_client_CA_file(path);
271   if (clientCAs == nullptr) {
272     LOG(ERROR) << "Unable to load ca file: " << path;
273     return;
274   }
275   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
276 }
277
278 void SSLContext::randomize() {
279   RAND_poll();
280 }
281
282 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
283   if (collector == nullptr) {
284     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
285     return;
286   }
287   collector_ = collector;
288   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
289   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
290 }
291
292 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
293
294 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
295   serverNameCb_ = cb;
296 }
297
298 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
299   clientHelloCbs_.push_back(cb);
300 }
301
302 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
303   SSLContext* context = (SSLContext*)data;
304
305   if (context == nullptr) {
306     return SSL_TLSEXT_ERR_NOACK;
307   }
308
309   for (auto& cb : context->clientHelloCbs_) {
310     // Generic callbacks to happen after we receive the Client Hello.
311     // For example, we use one to switch which cipher we use depending
312     // on the user's TLS version.  Because the primary purpose of
313     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
314     // are side-uses, we ignore any possible failures other than just logging
315     // them.
316     cb(ssl);
317   }
318
319   if (!context->serverNameCb_) {
320     return SSL_TLSEXT_ERR_NOACK;
321   }
322
323   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
324   switch (ret) {
325     case SERVER_NAME_FOUND:
326       return SSL_TLSEXT_ERR_OK;
327     case SERVER_NAME_NOT_FOUND:
328       return SSL_TLSEXT_ERR_NOACK;
329     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
330       *al = TLS1_AD_UNRECOGNIZED_NAME;
331       return SSL_TLSEXT_ERR_ALERT_FATAL;
332     default:
333       CHECK(false);
334   }
335
336   return SSL_TLSEXT_ERR_NOACK;
337 }
338
339 void SSLContext::switchCiphersIfTLS11(
340     SSL* ssl,
341     const std::string& tls11CipherString) {
342
343   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
344
345   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
346     // We only do this for TLS v 1.1 and later
347     return;
348   }
349
350   // Prefer AES for TLS versions 1.1 and later since these are not
351   // vulnerable to BEAST attacks on AES.  Note that we're setting the
352   // cipher list on the SSL object, not the SSL_CTX object, so it will
353   // only last for this request.
354   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
355   if ((rc == 0) || ERR_peek_error() != 0) {
356     // This shouldn't happen since we checked for this when proxygen
357     // started up.
358     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
359     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
360   }
361 }
362 #endif
363
364 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
365 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
366                                    const unsigned char** out,
367                                    unsigned char* outlen,
368                                    const unsigned char* in,
369                                    unsigned int inlen,
370                                    void* data) {
371   SSLContext* context = (SSLContext*)data;
372   CHECK(context);
373   if (context->advertisedNextProtocols_.empty()) {
374     *out = nullptr;
375     *outlen = 0;
376   } else {
377     auto i = context->pickNextProtocols();
378     const auto& item = context->advertisedNextProtocols_[i];
379     if (SSL_select_next_proto((unsigned char**)out,
380                               outlen,
381                               item.protocols,
382                               item.length,
383                               in,
384                               inlen) != OPENSSL_NPN_NEGOTIATED) {
385       return SSL_TLSEXT_ERR_NOACK;
386     }
387   }
388   return SSL_TLSEXT_ERR_OK;
389 }
390 #endif
391
392 #ifdef OPENSSL_NPN_NEGOTIATED
393
394 bool SSLContext::setAdvertisedNextProtocols(
395     const std::list<std::string>& protocols, NextProtocolType protocolType) {
396   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
397 }
398
399 bool SSLContext::setRandomizedAdvertisedNextProtocols(
400     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
401   unsetNextProtocols();
402   if (items.size() == 0) {
403     return false;
404   }
405   int total_weight = 0;
406   for (const auto &item : items) {
407     if (item.protocols.size() == 0) {
408       continue;
409     }
410     AdvertisedNextProtocolsItem advertised_item;
411     advertised_item.length = 0;
412     for (const auto& proto : item.protocols) {
413       ++advertised_item.length;
414       unsigned protoLength = proto.length();
415       if (protoLength >= 256) {
416         deleteNextProtocolsStrings();
417         return false;
418       }
419       advertised_item.length += protoLength;
420     }
421     advertised_item.protocols = new unsigned char[advertised_item.length];
422     if (!advertised_item.protocols) {
423       throw std::runtime_error("alloc failure");
424     }
425     unsigned char* dst = advertised_item.protocols;
426     for (auto& proto : item.protocols) {
427       unsigned protoLength = proto.length();
428       *dst++ = (unsigned char)protoLength;
429       memcpy(dst, proto.data(), protoLength);
430       dst += protoLength;
431     }
432     total_weight += item.weight;
433     advertisedNextProtocols_.push_back(advertised_item);
434     advertisedNextProtocolWeights_.push_back(item.weight);
435   }
436   if (total_weight == 0) {
437     deleteNextProtocolsStrings();
438     return false;
439   }
440   nextProtocolDistribution_ =
441       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
442                                    advertisedNextProtocolWeights_.end());
443   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
444     SSL_CTX_set_next_protos_advertised_cb(
445         ctx_, advertisedNextProtocolCallback, this);
446     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
447   }
448 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
449   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
450     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
451     // Client cannot really use randomized alpn
452     SSL_CTX_set_alpn_protos(ctx_,
453                             advertisedNextProtocols_[0].protocols,
454                             advertisedNextProtocols_[0].length);
455   }
456 #endif
457   return true;
458 }
459
460 void SSLContext::deleteNextProtocolsStrings() {
461   for (auto protocols : advertisedNextProtocols_) {
462     delete[] protocols.protocols;
463   }
464   advertisedNextProtocols_.clear();
465   advertisedNextProtocolWeights_.clear();
466 }
467
468 void SSLContext::unsetNextProtocols() {
469   deleteNextProtocolsStrings();
470   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
471   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
472 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
473   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
474   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
475 #endif
476 }
477
478 size_t SSLContext::pickNextProtocols() {
479   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
480   return nextProtocolDistribution_(nextProtocolPicker_);
481 }
482
483 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
484       const unsigned char** out, unsigned int* outlen, void* data) {
485   SSLContext* context = (SSLContext*)data;
486   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
487     *out = nullptr;
488     *outlen = 0;
489   } else if (context->advertisedNextProtocols_.size() == 1) {
490     *out = context->advertisedNextProtocols_[0].protocols;
491     *outlen = context->advertisedNextProtocols_[0].length;
492   } else {
493     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
494           sNextProtocolsExDataIndex_));
495     if (selected_index) {
496       --selected_index;
497       *out = context->advertisedNextProtocols_[selected_index].protocols;
498       *outlen = context->advertisedNextProtocols_[selected_index].length;
499     } else {
500       auto i = context->pickNextProtocols();
501       uintptr_t selected = i + 1;
502       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
503       *out = context->advertisedNextProtocols_[i].protocols;
504       *outlen = context->advertisedNextProtocols_[i].length;
505     }
506   }
507   return SSL_TLSEXT_ERR_OK;
508 }
509
510 int SSLContext::selectNextProtocolCallback(SSL* ssl,
511                                            unsigned char** out,
512                                            unsigned char* outlen,
513                                            const unsigned char* server,
514                                            unsigned int server_len,
515                                            void* data) {
516   (void)ssl; // Make -Wunused-parameters happy
517   SSLContext* ctx = (SSLContext*)data;
518   if (ctx->advertisedNextProtocols_.size() > 1) {
519     VLOG(3) << "SSLContext::selectNextProcolCallback() "
520             << "client should be deterministic in selecting protocols.";
521   }
522
523   unsigned char *client;
524   unsigned int client_len;
525   bool filtered = false;
526   auto cpf = ctx->getClientProtocolFilterCallback();
527   if (cpf) {
528     filtered = (*cpf)(&client, &client_len, server, server_len);
529   }
530
531   if (!filtered) {
532     if (ctx->advertisedNextProtocols_.empty()) {
533       client = (unsigned char *) "";
534       client_len = 0;
535     } else {
536       client = ctx->advertisedNextProtocols_[0].protocols;
537       client_len = ctx->advertisedNextProtocols_[0].length;
538     }
539   }
540
541   int retval = SSL_select_next_proto(out, outlen, server, server_len,
542                                      client, client_len);
543   if (retval != OPENSSL_NPN_NEGOTIATED) {
544     VLOG(3) << "SSLContext::selectNextProcolCallback() "
545             << "unable to pick a next protocol.";
546   }
547   return SSL_TLSEXT_ERR_OK;
548 }
549 #endif // OPENSSL_NPN_NEGOTIATED
550
551 SSL* SSLContext::createSSL() const {
552   SSL* ssl = SSL_new(ctx_);
553   if (ssl == nullptr) {
554     throw std::runtime_error("SSL_new: " + getErrors());
555   }
556   return ssl;
557 }
558
559 void SSLContext::setSessionCacheContext(const std::string& context) {
560   SSL_CTX_set_session_id_context(
561       ctx_,
562       reinterpret_cast<const unsigned char*>(context.data()),
563       std::min(
564           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
565 }
566
567 /**
568  * Match a name with a pattern. The pattern may include wildcard. A single
569  * wildcard "*" can match up to one component in the domain name.
570  *
571  * @param  host    Host name, typically the name of the remote host
572  * @param  pattern Name retrieved from certificate
573  * @param  size    Size of "pattern"
574  * @return True, if "host" matches "pattern". False otherwise.
575  */
576 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
577   bool match = false;
578   int i = 0, j = 0;
579   while (i < size && host[j] != '\0') {
580     if (toupper(pattern[i]) == toupper(host[j])) {
581       i++;
582       j++;
583       continue;
584     }
585     if (pattern[i] == '*') {
586       while (host[j] != '.' && host[j] != '\0') {
587         j++;
588       }
589       i++;
590       continue;
591     }
592     break;
593   }
594   if (i == size && host[j] == '\0') {
595     match = true;
596   }
597   return match;
598 }
599
600 int SSLContext::passwordCallback(char* password,
601                                  int size,
602                                  int,
603                                  void* data) {
604   SSLContext* context = (SSLContext*)data;
605   if (context == nullptr || context->passwordCollector() == nullptr) {
606     return 0;
607   }
608   std::string userPassword;
609   // call user defined password collector to get password
610   context->passwordCollector()->getPassword(userPassword, size);
611   int length = userPassword.size();
612   if (length > size) {
613     length = size;
614   }
615   strncpy(password, userPassword.c_str(), length);
616   return length;
617 }
618
619 struct SSLLock {
620   explicit SSLLock(
621     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
622       lockType(inLockType) {
623   }
624
625   void lock() {
626     if (lockType == SSLContext::LOCK_MUTEX) {
627       mutex.lock();
628     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
629       spinLock.lock();
630     }
631     // lockType == LOCK_NONE, no-op
632   }
633
634   void unlock() {
635     if (lockType == SSLContext::LOCK_MUTEX) {
636       mutex.unlock();
637     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
638       spinLock.unlock();
639     }
640     // lockType == LOCK_NONE, no-op
641   }
642
643   SSLContext::SSLLockType lockType;
644   folly::SpinLock spinLock{};
645   std::mutex mutex;
646 };
647
648 // Statics are unsafe in environments that call exit().
649 // If one thread calls exit() while another thread is
650 // references a member of SSLContext, bad things can happen.
651 // SSLContext runs in such environments.
652 // Instead of declaring a static member we "new" the static
653 // member so that it won't be destructed on exit().
654 static std::unique_ptr<SSLLock[]>& locks() {
655   static auto locksInst = new std::unique_ptr<SSLLock[]>();
656   return *locksInst;
657 }
658
659 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
660   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
661   return *lockTypesInst;
662 }
663
664 static void callbackLocking(int mode, int n, const char*, int) {
665   if (mode & CRYPTO_LOCK) {
666     locks()[n].lock();
667   } else {
668     locks()[n].unlock();
669   }
670 }
671
672 static unsigned long callbackThreadID() {
673   return static_cast<unsigned long>(
674 #ifdef __APPLE__
675     pthread_mach_thread_np(pthread_self())
676 #else
677     pthread_self()
678 #endif
679   );
680 }
681
682 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
683   return new CRYPTO_dynlock_value;
684 }
685
686 static void dyn_lock(int mode,
687                      struct CRYPTO_dynlock_value* lock,
688                      const char*, int) {
689   if (lock != nullptr) {
690     if (mode & CRYPTO_LOCK) {
691       lock->mutex.lock();
692     } else {
693       lock->mutex.unlock();
694     }
695   }
696 }
697
698 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
699   delete lock;
700 }
701
702 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
703   lockTypes() = inLockTypes;
704 }
705
706 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
707 void SSLContext::enableFalseStart() {
708   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
709 }
710 #endif
711
712 void SSLContext::markInitialized() {
713   std::lock_guard<std::mutex> g(initMutex());
714   initialized_ = true;
715 }
716
717 void SSLContext::initializeOpenSSL() {
718   std::lock_guard<std::mutex> g(initMutex());
719   initializeOpenSSLLocked();
720 }
721
722 void SSLContext::initializeOpenSSLLocked() {
723   if (initialized_) {
724     return;
725   }
726   SSL_library_init();
727   SSL_load_error_strings();
728   ERR_load_crypto_strings();
729   // static locking
730   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
731   for (auto it: lockTypes()) {
732     locks()[it.first].lockType = it.second;
733   }
734   CRYPTO_set_id_callback(callbackThreadID);
735   CRYPTO_set_locking_callback(callbackLocking);
736   // dynamic locking
737   CRYPTO_set_dynlock_create_callback(dyn_create);
738   CRYPTO_set_dynlock_lock_callback(dyn_lock);
739   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
740   randomize();
741 #ifdef OPENSSL_NPN_NEGOTIATED
742   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
743       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
744 #endif
745   initialized_ = true;
746 }
747
748 void SSLContext::cleanupOpenSSL() {
749   std::lock_guard<std::mutex> g(initMutex());
750   cleanupOpenSSLLocked();
751 }
752
753 void SSLContext::cleanupOpenSSLLocked() {
754   if (!initialized_) {
755     return;
756   }
757
758   CRYPTO_set_id_callback(nullptr);
759   CRYPTO_set_locking_callback(nullptr);
760   CRYPTO_set_dynlock_create_callback(nullptr);
761   CRYPTO_set_dynlock_lock_callback(nullptr);
762   CRYPTO_set_dynlock_destroy_callback(nullptr);
763   CRYPTO_cleanup_all_ex_data();
764   ERR_free_strings();
765   EVP_cleanup();
766   ERR_remove_state(0);
767   locks().reset();
768   initialized_ = false;
769 }
770
771 void SSLContext::setOptions(long options) {
772   long newOpt = SSL_CTX_set_options(ctx_, options);
773   if ((newOpt & options) != options) {
774     throw std::runtime_error("SSL_CTX_set_options failed");
775   }
776 }
777
778 std::string SSLContext::getErrors(int errnoCopy) {
779   std::string errors;
780   unsigned long  errorCode;
781   char   message[256];
782
783   errors.reserve(512);
784   while ((errorCode = ERR_get_error()) != 0) {
785     if (!errors.empty()) {
786       errors += "; ";
787     }
788     const char* reason = ERR_reason_error_string(errorCode);
789     if (reason == nullptr) {
790       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
791       reason = message;
792     }
793     errors += reason;
794   }
795   if (errors.empty()) {
796     errors = "error code: " + folly::to<std::string>(errnoCopy);
797   }
798   return errors;
799 }
800
801 std::ostream&
802 operator<<(std::ostream& os, const PasswordCollector& collector) {
803   os << collector.describe();
804   return os;
805 }
806
807 } // folly