eb9920a73c33569b3e84df90c1d4dfa09cb8f0f0
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/Memory.h>
26 #include <folly/SpinLock.h>
27
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
31
32 struct CRYPTO_dynlock_value {
33   std::mutex mutex;
34 };
35
36 namespace folly {
37
38 bool SSLContext::initialized_ = false;
39
40 namespace {
41
42 std::mutex& initMutex() {
43   static std::mutex m;
44   return m;
45 }
46
47 } // anonymous namespace
48
49 #ifdef OPENSSL_NPN_NEGOTIATED
50 int SSLContext::sNextProtocolsExDataIndex_ = -1;
51 #endif
52
53 // SSLContext implementation
54 SSLContext::SSLContext(SSLVersion version) {
55   {
56     std::lock_guard<std::mutex> g(initMutex());
57     initializeOpenSSLLocked();
58   }
59
60   ctx_ = SSL_CTX_new(SSLv23_method());
61   if (ctx_ == nullptr) {
62     throw std::runtime_error("SSL_CTX_new: " + getErrors());
63   }
64
65   int opt = 0;
66   switch (version) {
67     case TLSv1:
68       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
69       break;
70     case SSLv3:
71       opt = SSL_OP_NO_SSLv2;
72       break;
73     default:
74       // do nothing
75       break;
76   }
77   int newOpt = SSL_CTX_set_options(ctx_, opt);
78   DCHECK((newOpt & opt) == opt);
79
80   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
81
82   checkPeerName_ = false;
83
84 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
85   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
86   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
87 #endif
88
89   Random::seed(randomGenerator_);
90 }
91
92 SSLContext::~SSLContext() {
93   if (ctx_ != nullptr) {
94     SSL_CTX_free(ctx_);
95     ctx_ = nullptr;
96   }
97
98 #ifdef OPENSSL_NPN_NEGOTIATED
99   deleteNextProtocolsStrings();
100 #endif
101 }
102
103 void SSLContext::ciphers(const std::string& ciphers) {
104   providedCiphersString_ = ciphers;
105   setCiphersOrThrow(ciphers);
106 }
107
108 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
109   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
110   if (ERR_peek_error() != 0) {
111     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
112   }
113   if (rc == 0) {
114     throw std::runtime_error("None of specified ciphers are supported");
115   }
116 }
117
118 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
119     verifyPeer) {
120   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
121   verifyPeer_ = verifyPeer;
122 }
123
124 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
125     verifyPeer) {
126   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
127   int mode = SSL_VERIFY_NONE;
128   switch(verifyPeer) {
129     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
130     // break;
131
132     case SSLVerifyPeerEnum::VERIFY:
133       mode = SSL_VERIFY_PEER;
134       break;
135
136     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
137       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
138       break;
139
140     case SSLVerifyPeerEnum::NO_VERIFY:
141       mode = SSL_VERIFY_NONE;
142       break;
143
144     default:
145       break;
146   }
147   return mode;
148 }
149
150 int SSLContext::getVerificationMode() {
151   return getVerificationMode(verifyPeer_);
152 }
153
154 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
155                               const std::string& peerName) {
156   int mode;
157   if (checkPeerCert) {
158     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
159     checkPeerName_ = checkPeerName;
160     peerFixedName_ = peerName;
161   } else {
162     mode = SSL_VERIFY_NONE;
163     checkPeerName_ = false; // can't check name without cert!
164     peerFixedName_.clear();
165   }
166   SSL_CTX_set_verify(ctx_, mode, nullptr);
167 }
168
169 void SSLContext::loadCertificate(const char* path, const char* format) {
170   if (path == nullptr || format == nullptr) {
171     throw std::invalid_argument(
172          "loadCertificateChain: either <path> or <format> is nullptr");
173   }
174   if (strcmp(format, "PEM") == 0) {
175     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
176       int errnoCopy = errno;
177       std::string reason("SSL_CTX_use_certificate_chain_file: ");
178       reason.append(path);
179       reason.append(": ");
180       reason.append(getErrors(errnoCopy));
181       throw std::runtime_error(reason);
182     }
183   } else {
184     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
185   }
186 }
187
188 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
189   if (cert.data() == nullptr) {
190     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
191   }
192
193   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
194   if (bio == nullptr) {
195     throw std::runtime_error("BIO_new: " + getErrors());
196   }
197
198   int written = BIO_write(bio.get(), cert.data(), cert.size());
199   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
200     throw std::runtime_error("BIO_write: " + getErrors());
201   }
202
203   ssl::X509UniquePtr x509(
204       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
205   if (x509 == nullptr) {
206     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
207   }
208
209   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
210     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
211   }
212 }
213
214 void SSLContext::loadPrivateKey(const char* path, const char* format) {
215   if (path == nullptr || format == nullptr) {
216     throw std::invalid_argument(
217         "loadPrivateKey: either <path> or <format> is nullptr");
218   }
219   if (strcmp(format, "PEM") == 0) {
220     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
221       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
222     }
223   } else {
224     throw std::runtime_error("Unsupported private key format: " + std::string(format));
225   }
226 }
227
228 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
229   if (pkey.data() == nullptr) {
230     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
231   }
232
233   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
234   if (bio == nullptr) {
235     throw std::runtime_error("BIO_new: " + getErrors());
236   }
237
238   int written = BIO_write(bio.get(), pkey.data(), pkey.size());
239   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
240     throw std::runtime_error("BIO_write: " + getErrors());
241   }
242
243   ssl::EvpPkeyUniquePtr key(
244       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
245   if (key == nullptr) {
246     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
247   }
248
249   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
250     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
251   }
252 }
253
254 void SSLContext::loadTrustedCertificates(const char* path) {
255   if (path == nullptr) {
256     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
257   }
258   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
259     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
260   }
261 }
262
263 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
264   SSL_CTX_set_cert_store(ctx_, store);
265 }
266
267 void SSLContext::loadClientCAList(const char* path) {
268   auto clientCAs = SSL_load_client_CA_file(path);
269   if (clientCAs == nullptr) {
270     LOG(ERROR) << "Unable to load ca file: " << path;
271     return;
272   }
273   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
274 }
275
276 void SSLContext::randomize() {
277   RAND_poll();
278 }
279
280 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
281   if (collector == nullptr) {
282     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
283     return;
284   }
285   collector_ = collector;
286   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
287   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
288 }
289
290 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
291
292 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
293   serverNameCb_ = cb;
294 }
295
296 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
297   clientHelloCbs_.push_back(cb);
298 }
299
300 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
301   SSLContext* context = (SSLContext*)data;
302
303   if (context == nullptr) {
304     return SSL_TLSEXT_ERR_NOACK;
305   }
306
307   for (auto& cb : context->clientHelloCbs_) {
308     // Generic callbacks to happen after we receive the Client Hello.
309     // For example, we use one to switch which cipher we use depending
310     // on the user's TLS version.  Because the primary purpose of
311     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
312     // are side-uses, we ignore any possible failures other than just logging
313     // them.
314     cb(ssl);
315   }
316
317   if (!context->serverNameCb_) {
318     return SSL_TLSEXT_ERR_NOACK;
319   }
320
321   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
322   switch (ret) {
323     case SERVER_NAME_FOUND:
324       return SSL_TLSEXT_ERR_OK;
325     case SERVER_NAME_NOT_FOUND:
326       return SSL_TLSEXT_ERR_NOACK;
327     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
328       *al = TLS1_AD_UNRECOGNIZED_NAME;
329       return SSL_TLSEXT_ERR_ALERT_FATAL;
330     default:
331       CHECK(false);
332   }
333
334   return SSL_TLSEXT_ERR_NOACK;
335 }
336
337 void SSLContext::switchCiphersIfTLS11(
338     SSL* ssl,
339     const std::string& tls11CipherString,
340     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
341   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
342       << "Shouldn't call if empty ciphers / alt ciphers";
343
344   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
345     // We only do this for TLS v 1.1 and later
346     return;
347   }
348
349   const std::string* ciphers = &tls11CipherString;
350   if (!tls11AltCipherlist.empty()) {
351     if (!cipherListPicker_) {
352       std::vector<int> weights;
353       std::for_each(
354           tls11AltCipherlist.begin(),
355           tls11AltCipherlist.end(),
356           [&](const std::pair<std::string, int>& e) {
357             weights.push_back(e.second);
358           });
359       cipherListPicker_.reset(
360           new std::discrete_distribution<int>(weights.begin(), weights.end()));
361     }
362     auto index = (*cipherListPicker_)(randomGenerator_);
363     if ((size_t)index >= tls11AltCipherlist.size()) {
364       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
365                  << ", but tls11AltCipherlist is of length "
366                  << tls11AltCipherlist.size();
367     } else {
368       ciphers = &tls11AltCipherlist[index].first;
369     }
370   }
371
372   // Prefer AES for TLS versions 1.1 and later since these are not
373   // vulnerable to BEAST attacks on AES.  Note that we're setting the
374   // cipher list on the SSL object, not the SSL_CTX object, so it will
375   // only last for this request.
376   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
377   if ((rc == 0) || ERR_peek_error() != 0) {
378     // This shouldn't happen since we checked for this when proxygen
379     // started up.
380     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
381     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
382   }
383 }
384 #endif
385
386 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
387 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
388                                    const unsigned char** out,
389                                    unsigned char* outlen,
390                                    const unsigned char* in,
391                                    unsigned int inlen,
392                                    void* data) {
393   SSLContext* context = (SSLContext*)data;
394   CHECK(context);
395   if (context->advertisedNextProtocols_.empty()) {
396     *out = nullptr;
397     *outlen = 0;
398   } else {
399     auto i = context->pickNextProtocols();
400     const auto& item = context->advertisedNextProtocols_[i];
401     if (SSL_select_next_proto((unsigned char**)out,
402                               outlen,
403                               item.protocols,
404                               item.length,
405                               in,
406                               inlen) != OPENSSL_NPN_NEGOTIATED) {
407       return SSL_TLSEXT_ERR_NOACK;
408     }
409   }
410   return SSL_TLSEXT_ERR_OK;
411 }
412 #endif
413
414 #ifdef OPENSSL_NPN_NEGOTIATED
415
416 bool SSLContext::setAdvertisedNextProtocols(
417     const std::list<std::string>& protocols, NextProtocolType protocolType) {
418   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
419 }
420
421 bool SSLContext::setRandomizedAdvertisedNextProtocols(
422     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
423   unsetNextProtocols();
424   if (items.size() == 0) {
425     return false;
426   }
427   int total_weight = 0;
428   for (const auto &item : items) {
429     if (item.protocols.size() == 0) {
430       continue;
431     }
432     AdvertisedNextProtocolsItem advertised_item;
433     advertised_item.length = 0;
434     for (const auto& proto : item.protocols) {
435       ++advertised_item.length;
436       unsigned protoLength = proto.length();
437       if (protoLength >= 256) {
438         deleteNextProtocolsStrings();
439         return false;
440       }
441       advertised_item.length += protoLength;
442     }
443     advertised_item.protocols = new unsigned char[advertised_item.length];
444     if (!advertised_item.protocols) {
445       throw std::runtime_error("alloc failure");
446     }
447     unsigned char* dst = advertised_item.protocols;
448     for (auto& proto : item.protocols) {
449       unsigned protoLength = proto.length();
450       *dst++ = (unsigned char)protoLength;
451       memcpy(dst, proto.data(), protoLength);
452       dst += protoLength;
453     }
454     total_weight += item.weight;
455     advertisedNextProtocols_.push_back(advertised_item);
456     advertisedNextProtocolWeights_.push_back(item.weight);
457   }
458   if (total_weight == 0) {
459     deleteNextProtocolsStrings();
460     return false;
461   }
462   nextProtocolDistribution_ =
463       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
464                                    advertisedNextProtocolWeights_.end());
465   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
466     SSL_CTX_set_next_protos_advertised_cb(
467         ctx_, advertisedNextProtocolCallback, this);
468     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
469   }
470 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
471   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
472     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
473     // Client cannot really use randomized alpn
474     SSL_CTX_set_alpn_protos(ctx_,
475                             advertisedNextProtocols_[0].protocols,
476                             advertisedNextProtocols_[0].length);
477   }
478 #endif
479   return true;
480 }
481
482 void SSLContext::deleteNextProtocolsStrings() {
483   for (auto protocols : advertisedNextProtocols_) {
484     delete[] protocols.protocols;
485   }
486   advertisedNextProtocols_.clear();
487   advertisedNextProtocolWeights_.clear();
488 }
489
490 void SSLContext::unsetNextProtocols() {
491   deleteNextProtocolsStrings();
492   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
493   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
494 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
495   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
496   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
497 #endif
498 }
499
500 size_t SSLContext::pickNextProtocols() {
501   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
502   return nextProtocolDistribution_(randomGenerator_);
503 }
504
505 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
506       const unsigned char** out, unsigned int* outlen, void* data) {
507   SSLContext* context = (SSLContext*)data;
508   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
509     *out = nullptr;
510     *outlen = 0;
511   } else if (context->advertisedNextProtocols_.size() == 1) {
512     *out = context->advertisedNextProtocols_[0].protocols;
513     *outlen = context->advertisedNextProtocols_[0].length;
514   } else {
515     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
516           sNextProtocolsExDataIndex_));
517     if (selected_index) {
518       --selected_index;
519       *out = context->advertisedNextProtocols_[selected_index].protocols;
520       *outlen = context->advertisedNextProtocols_[selected_index].length;
521     } else {
522       auto i = context->pickNextProtocols();
523       uintptr_t selected = i + 1;
524       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
525       *out = context->advertisedNextProtocols_[i].protocols;
526       *outlen = context->advertisedNextProtocols_[i].length;
527     }
528   }
529   return SSL_TLSEXT_ERR_OK;
530 }
531
532 int SSLContext::selectNextProtocolCallback(SSL* ssl,
533                                            unsigned char** out,
534                                            unsigned char* outlen,
535                                            const unsigned char* server,
536                                            unsigned int server_len,
537                                            void* data) {
538   (void)ssl; // Make -Wunused-parameters happy
539   SSLContext* ctx = (SSLContext*)data;
540   if (ctx->advertisedNextProtocols_.size() > 1) {
541     VLOG(3) << "SSLContext::selectNextProcolCallback() "
542             << "client should be deterministic in selecting protocols.";
543   }
544
545   unsigned char *client;
546   unsigned int client_len;
547   bool filtered = false;
548   auto cpf = ctx->getClientProtocolFilterCallback();
549   if (cpf) {
550     filtered = (*cpf)(&client, &client_len, server, server_len);
551   }
552
553   if (!filtered) {
554     if (ctx->advertisedNextProtocols_.empty()) {
555       client = (unsigned char *) "";
556       client_len = 0;
557     } else {
558       client = ctx->advertisedNextProtocols_[0].protocols;
559       client_len = ctx->advertisedNextProtocols_[0].length;
560     }
561   }
562
563   int retval = SSL_select_next_proto(out, outlen, server, server_len,
564                                      client, client_len);
565   if (retval != OPENSSL_NPN_NEGOTIATED) {
566     VLOG(3) << "SSLContext::selectNextProcolCallback() "
567             << "unable to pick a next protocol.";
568   }
569   return SSL_TLSEXT_ERR_OK;
570 }
571 #endif // OPENSSL_NPN_NEGOTIATED
572
573 SSL* SSLContext::createSSL() const {
574   SSL* ssl = SSL_new(ctx_);
575   if (ssl == nullptr) {
576     throw std::runtime_error("SSL_new: " + getErrors());
577   }
578   return ssl;
579 }
580
581 void SSLContext::setSessionCacheContext(const std::string& context) {
582   SSL_CTX_set_session_id_context(
583       ctx_,
584       reinterpret_cast<const unsigned char*>(context.data()),
585       std::min(
586           static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
587 }
588
589 /**
590  * Match a name with a pattern. The pattern may include wildcard. A single
591  * wildcard "*" can match up to one component in the domain name.
592  *
593  * @param  host    Host name, typically the name of the remote host
594  * @param  pattern Name retrieved from certificate
595  * @param  size    Size of "pattern"
596  * @return True, if "host" matches "pattern". False otherwise.
597  */
598 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
599   bool match = false;
600   int i = 0, j = 0;
601   while (i < size && host[j] != '\0') {
602     if (toupper(pattern[i]) == toupper(host[j])) {
603       i++;
604       j++;
605       continue;
606     }
607     if (pattern[i] == '*') {
608       while (host[j] != '.' && host[j] != '\0') {
609         j++;
610       }
611       i++;
612       continue;
613     }
614     break;
615   }
616   if (i == size && host[j] == '\0') {
617     match = true;
618   }
619   return match;
620 }
621
622 int SSLContext::passwordCallback(char* password,
623                                  int size,
624                                  int,
625                                  void* data) {
626   SSLContext* context = (SSLContext*)data;
627   if (context == nullptr || context->passwordCollector() == nullptr) {
628     return 0;
629   }
630   std::string userPassword;
631   // call user defined password collector to get password
632   context->passwordCollector()->getPassword(userPassword, size);
633   int length = userPassword.size();
634   if (length > size) {
635     length = size;
636   }
637   strncpy(password, userPassword.c_str(), length);
638   return length;
639 }
640
641 struct SSLLock {
642   explicit SSLLock(
643     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
644       lockType(inLockType) {
645   }
646
647   void lock() {
648     if (lockType == SSLContext::LOCK_MUTEX) {
649       mutex.lock();
650     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
651       spinLock.lock();
652     }
653     // lockType == LOCK_NONE, no-op
654   }
655
656   void unlock() {
657     if (lockType == SSLContext::LOCK_MUTEX) {
658       mutex.unlock();
659     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
660       spinLock.unlock();
661     }
662     // lockType == LOCK_NONE, no-op
663   }
664
665   SSLContext::SSLLockType lockType;
666   folly::SpinLock spinLock{};
667   std::mutex mutex;
668 };
669
670 // Statics are unsafe in environments that call exit().
671 // If one thread calls exit() while another thread is
672 // references a member of SSLContext, bad things can happen.
673 // SSLContext runs in such environments.
674 // Instead of declaring a static member we "new" the static
675 // member so that it won't be destructed on exit().
676 static std::unique_ptr<SSLLock[]>& locks() {
677   static auto locksInst = new std::unique_ptr<SSLLock[]>();
678   return *locksInst;
679 }
680
681 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
682   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
683   return *lockTypesInst;
684 }
685
686 static void callbackLocking(int mode, int n, const char*, int) {
687   if (mode & CRYPTO_LOCK) {
688     locks()[n].lock();
689   } else {
690     locks()[n].unlock();
691   }
692 }
693
694 static unsigned long callbackThreadID() {
695   return static_cast<unsigned long>(
696 #ifdef __APPLE__
697     pthread_mach_thread_np(pthread_self())
698 #else
699     pthread_self()
700 #endif
701   );
702 }
703
704 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
705   return new CRYPTO_dynlock_value;
706 }
707
708 static void dyn_lock(int mode,
709                      struct CRYPTO_dynlock_value* lock,
710                      const char*, int) {
711   if (lock != nullptr) {
712     if (mode & CRYPTO_LOCK) {
713       lock->mutex.lock();
714     } else {
715       lock->mutex.unlock();
716     }
717   }
718 }
719
720 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
721   delete lock;
722 }
723
724 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
725   lockTypes() = inLockTypes;
726 }
727
728 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
729 void SSLContext::enableFalseStart() {
730   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
731 }
732 #endif
733
734 void SSLContext::markInitialized() {
735   std::lock_guard<std::mutex> g(initMutex());
736   initialized_ = true;
737 }
738
739 void SSLContext::initializeOpenSSL() {
740   std::lock_guard<std::mutex> g(initMutex());
741   initializeOpenSSLLocked();
742 }
743
744 void SSLContext::initializeOpenSSLLocked() {
745   if (initialized_) {
746     return;
747   }
748   SSL_library_init();
749   SSL_load_error_strings();
750   ERR_load_crypto_strings();
751   // static locking
752   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
753   for (auto it: lockTypes()) {
754     locks()[it.first].lockType = it.second;
755   }
756   CRYPTO_set_id_callback(callbackThreadID);
757   CRYPTO_set_locking_callback(callbackLocking);
758   // dynamic locking
759   CRYPTO_set_dynlock_create_callback(dyn_create);
760   CRYPTO_set_dynlock_lock_callback(dyn_lock);
761   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
762   randomize();
763 #ifdef OPENSSL_NPN_NEGOTIATED
764   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
765       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
766 #endif
767   initialized_ = true;
768 }
769
770 void SSLContext::cleanupOpenSSL() {
771   std::lock_guard<std::mutex> g(initMutex());
772   cleanupOpenSSLLocked();
773 }
774
775 void SSLContext::cleanupOpenSSLLocked() {
776   if (!initialized_) {
777     return;
778   }
779
780   CRYPTO_set_id_callback(nullptr);
781   CRYPTO_set_locking_callback(nullptr);
782   CRYPTO_set_dynlock_create_callback(nullptr);
783   CRYPTO_set_dynlock_lock_callback(nullptr);
784   CRYPTO_set_dynlock_destroy_callback(nullptr);
785   CRYPTO_cleanup_all_ex_data();
786   ERR_free_strings();
787   EVP_cleanup();
788   ERR_remove_state(0);
789   locks().reset();
790   initialized_ = false;
791 }
792
793 void SSLContext::setOptions(long options) {
794   long newOpt = SSL_CTX_set_options(ctx_, options);
795   if ((newOpt & options) != options) {
796     throw std::runtime_error("SSL_CTX_set_options failed");
797   }
798 }
799
800 std::string SSLContext::getErrors(int errnoCopy) {
801   std::string errors;
802   unsigned long  errorCode;
803   char   message[256];
804
805   errors.reserve(512);
806   while ((errorCode = ERR_get_error()) != 0) {
807     if (!errors.empty()) {
808       errors += "; ";
809     }
810     const char* reason = ERR_reason_error_string(errorCode);
811     if (reason == nullptr) {
812       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
813       reason = message;
814     }
815     errors += reason;
816   }
817   if (errors.empty()) {
818     errors = "error code: " + folly::to<std::string>(errnoCopy);
819   }
820   return errors;
821 }
822
823 std::ostream&
824 operator<<(std::ostream& os, const PasswordCollector& collector) {
825   os << collector.describe();
826   return os;
827 }
828
829 } // folly