Break long lines in SSLContext.
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <folly/Format.h>
20 #include <folly/Memory.h>
21 #include <folly/Random.h>
22 #include <folly/SharedMutex.h>
23 #include <folly/SpinLock.h>
24 #include <folly/ThreadId.h>
25
26 // ---------------------------------------------------------------------
27 // SSLContext implementation
28 // ---------------------------------------------------------------------
29
30 struct CRYPTO_dynlock_value {
31   std::mutex mutex;
32 };
33
34 namespace folly {
35 //
36 // For OpenSSL portability API
37 using namespace folly::ssl;
38
39 bool SSLContext::initialized_ = false;
40
41 namespace {
42
43 std::mutex& initMutex() {
44   static std::mutex m;
45   return m;
46 }
47
48 } // anonymous namespace
49
50 #ifdef OPENSSL_NPN_NEGOTIATED
51 int SSLContext::sNextProtocolsExDataIndex_ = -1;
52 #endif
53
54 // SSLContext implementation
55 SSLContext::SSLContext(SSLVersion version) {
56   {
57     std::lock_guard<std::mutex> g(initMutex());
58     initializeOpenSSLLocked();
59   }
60
61   ctx_ = SSL_CTX_new(SSLv23_method());
62   if (ctx_ == nullptr) {
63     throw std::runtime_error("SSL_CTX_new: " + getErrors());
64   }
65
66   int opt = 0;
67   switch (version) {
68     case TLSv1:
69       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
70       break;
71     case SSLv3:
72       opt = SSL_OP_NO_SSLv2;
73       break;
74     default:
75       // do nothing
76       break;
77   }
78   int newOpt = SSL_CTX_set_options(ctx_, opt);
79   DCHECK((newOpt & opt) == opt);
80
81   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
82
83   checkPeerName_ = false;
84
85   SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
86
87 #if FOLLY_OPENSSL_HAS_SNI
88   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
89   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
90 #endif
91 }
92
93 SSLContext::~SSLContext() {
94   if (ctx_ != nullptr) {
95     SSL_CTX_free(ctx_);
96     ctx_ = nullptr;
97   }
98
99 #ifdef OPENSSL_NPN_NEGOTIATED
100   deleteNextProtocolsStrings();
101 #endif
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   setCiphersOrThrow(ciphers);
106 }
107
108 void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
109   if (ciphers.size() == 0) {
110     return;
111   }
112   std::string opensslCipherList;
113   join(":", ciphers, opensslCipherList);
114   setCiphersOrThrow(opensslCipherList);
115 }
116
117 void SSLContext::setSignatureAlgorithms(
118     const std::vector<std::string>& sigalgs) {
119   if (sigalgs.size() == 0) {
120     return;
121   }
122 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
123   std::string opensslSigAlgsList;
124   join(":", sigalgs, opensslSigAlgsList);
125   int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
126   if (rc == 0) {
127     throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
128   }
129 #endif
130 }
131
132 void SSLContext::setClientECCurvesList(
133     const std::vector<std::string>& ecCurves) {
134   if (ecCurves.size() == 0) {
135     return;
136   }
137 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL
138   std::string ecCurvesList;
139   join(":", ecCurves, ecCurvesList);
140   int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
141   if (rc == 0) {
142     throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
143   }
144 #endif
145 }
146
147 void SSLContext::setServerECCurve(const std::string& curveName) {
148 #if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
149   EC_KEY* ecdh = nullptr;
150   int nid;
151
152   /*
153    * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
154    * from RFC 4492 section 5.1.1, or explicitly described curves over
155    * binary fields. OpenSSL only supports the "named curves", which provide
156    * maximum interoperability.
157    */
158
159   nid = OBJ_sn2nid(curveName.c_str());
160   if (nid == 0) {
161     LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
162   }
163   ecdh = EC_KEY_new_by_curve_name(nid);
164   if (ecdh == nullptr) {
165     LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
166   }
167
168   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
169   EC_KEY_free(ecdh);
170 #else
171   throw std::runtime_error("Elliptic curve encryption not allowed");
172 #endif
173 }
174
175 void SSLContext::setX509VerifyParam(
176     const ssl::X509VerifyParam& x509VerifyParam) {
177   if (!x509VerifyParam) {
178     return;
179   }
180   if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
181     throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
182   }
183 }
184
185 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
186   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
187   if (rc == 0) {
188     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
189   }
190   providedCiphersString_ = ciphers;
191 }
192
193 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
194     verifyPeer) {
195   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
196   verifyPeer_ = verifyPeer;
197 }
198
199 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
200     verifyPeer) {
201   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
202   int mode = SSL_VERIFY_NONE;
203   switch(verifyPeer) {
204     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
205     // break;
206
207     case SSLVerifyPeerEnum::VERIFY:
208       mode = SSL_VERIFY_PEER;
209       break;
210
211     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
212       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
213       break;
214
215     case SSLVerifyPeerEnum::NO_VERIFY:
216       mode = SSL_VERIFY_NONE;
217       break;
218
219     default:
220       break;
221   }
222   return mode;
223 }
224
225 int SSLContext::getVerificationMode() {
226   return getVerificationMode(verifyPeer_);
227 }
228
229 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
230                               const std::string& peerName) {
231   int mode;
232   if (checkPeerCert) {
233     mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
234         SSL_VERIFY_CLIENT_ONCE;
235     checkPeerName_ = checkPeerName;
236     peerFixedName_ = peerName;
237   } else {
238     mode = SSL_VERIFY_NONE;
239     checkPeerName_ = false; // can't check name without cert!
240     peerFixedName_.clear();
241   }
242   SSL_CTX_set_verify(ctx_, mode, nullptr);
243 }
244
245 void SSLContext::loadCertificate(const char* path, const char* format) {
246   if (path == nullptr || format == nullptr) {
247     throw std::invalid_argument(
248          "loadCertificateChain: either <path> or <format> is nullptr");
249   }
250   if (strcmp(format, "PEM") == 0) {
251     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
252       int errnoCopy = errno;
253       std::string reason("SSL_CTX_use_certificate_chain_file: ");
254       reason.append(path);
255       reason.append(": ");
256       reason.append(getErrors(errnoCopy));
257       throw std::runtime_error(reason);
258     }
259   } else {
260     throw std::runtime_error(
261         "Unsupported certificate format: " + std::string(format));
262   }
263 }
264
265 void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
266   if (cert.data() == nullptr) {
267     throw std::invalid_argument("loadCertificate: <cert> is nullptr");
268   }
269
270   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
271   if (bio == nullptr) {
272     throw std::runtime_error("BIO_new: " + getErrors());
273   }
274
275   int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
276   if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
277     throw std::runtime_error("BIO_write: " + getErrors());
278   }
279
280   ssl::X509UniquePtr x509(
281       PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
282   if (x509 == nullptr) {
283     throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
284   }
285
286   if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
287     throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
288   }
289 }
290
291 void SSLContext::loadPrivateKey(const char* path, const char* format) {
292   if (path == nullptr || format == nullptr) {
293     throw std::invalid_argument(
294         "loadPrivateKey: either <path> or <format> is nullptr");
295   }
296   if (strcmp(format, "PEM") == 0) {
297     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
298       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
299     }
300   } else {
301     throw std::runtime_error(
302         "Unsupported private key format: " + std::string(format));
303   }
304 }
305
306 void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
307   if (pkey.data() == nullptr) {
308     throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
309   }
310
311   ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
312   if (bio == nullptr) {
313     throw std::runtime_error("BIO_new: " + getErrors());
314   }
315
316   int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
317   if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
318     throw std::runtime_error("BIO_write: " + getErrors());
319   }
320
321   ssl::EvpPkeyUniquePtr key(
322       PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
323   if (key == nullptr) {
324     throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
325   }
326
327   if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
328     throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
329   }
330 }
331
332 void SSLContext::loadTrustedCertificates(const char* path) {
333   if (path == nullptr) {
334     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
335   }
336   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
337     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
338   }
339   ERR_clear_error();
340 }
341
342 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
343   SSL_CTX_set_cert_store(ctx_, store);
344 }
345
346 void SSLContext::loadClientCAList(const char* path) {
347   auto clientCAs = SSL_load_client_CA_file(path);
348   if (clientCAs == nullptr) {
349     LOG(ERROR) << "Unable to load ca file: " << path;
350     return;
351   }
352   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
353 }
354
355 void SSLContext::randomize() {
356   RAND_poll();
357 }
358
359 void SSLContext::passwordCollector(
360     std::shared_ptr<PasswordCollector> collector) {
361   if (collector == nullptr) {
362     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
363     return;
364   }
365   collector_ = collector;
366   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
367   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
368 }
369
370 #if FOLLY_OPENSSL_HAS_SNI
371
372 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
373   serverNameCb_ = cb;
374 }
375
376 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
377   clientHelloCbs_.push_back(cb);
378 }
379
380 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
381   SSLContext* context = (SSLContext*)data;
382
383   if (context == nullptr) {
384     return SSL_TLSEXT_ERR_NOACK;
385   }
386
387   for (auto& cb : context->clientHelloCbs_) {
388     // Generic callbacks to happen after we receive the Client Hello.
389     // For example, we use one to switch which cipher we use depending
390     // on the user's TLS version.  Because the primary purpose of
391     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
392     // are side-uses, we ignore any possible failures other than just logging
393     // them.
394     cb(ssl);
395   }
396
397   if (!context->serverNameCb_) {
398     return SSL_TLSEXT_ERR_NOACK;
399   }
400
401   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
402   switch (ret) {
403     case SERVER_NAME_FOUND:
404       return SSL_TLSEXT_ERR_OK;
405     case SERVER_NAME_NOT_FOUND:
406       return SSL_TLSEXT_ERR_NOACK;
407     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
408       *al = TLS1_AD_UNRECOGNIZED_NAME;
409       return SSL_TLSEXT_ERR_ALERT_FATAL;
410     default:
411       CHECK(false);
412   }
413
414   return SSL_TLSEXT_ERR_NOACK;
415 }
416
417 void SSLContext::switchCiphersIfTLS11(
418     SSL* ssl,
419     const std::string& tls11CipherString,
420     const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
421   CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
422       << "Shouldn't call if empty ciphers / alt ciphers";
423
424   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
425     // We only do this for TLS v 1.1 and later
426     return;
427   }
428
429   const std::string* ciphers = &tls11CipherString;
430   if (!tls11AltCipherlist.empty()) {
431     if (!cipherListPicker_) {
432       std::vector<int> weights;
433       std::for_each(
434           tls11AltCipherlist.begin(),
435           tls11AltCipherlist.end(),
436           [&](const std::pair<std::string, int>& e) {
437             weights.push_back(e.second);
438           });
439       cipherListPicker_.reset(
440           new std::discrete_distribution<int>(weights.begin(), weights.end()));
441     }
442     auto rng = ThreadLocalPRNG();
443     auto index = (*cipherListPicker_)(rng);
444     if ((size_t)index >= tls11AltCipherlist.size()) {
445       LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
446                  << ", but tls11AltCipherlist is of length "
447                  << tls11AltCipherlist.size();
448     } else {
449       ciphers = &tls11AltCipherlist[size_t(index)].first;
450     }
451   }
452
453   // Prefer AES for TLS versions 1.1 and later since these are not
454   // vulnerable to BEAST attacks on AES.  Note that we're setting the
455   // cipher list on the SSL object, not the SSL_CTX object, so it will
456   // only last for this request.
457   int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
458   if ((rc == 0) || ERR_peek_error() != 0) {
459     // This shouldn't happen since we checked for this when proxygen
460     // started up.
461     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
462     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
463   }
464 }
465 #endif // FOLLY_OPENSSL_HAS_SNI
466
467 #if FOLLY_OPENSSL_HAS_ALPN
468 int SSLContext::alpnSelectCallback(SSL* /* ssl */,
469                                    const unsigned char** out,
470                                    unsigned char* outlen,
471                                    const unsigned char* in,
472                                    unsigned int inlen,
473                                    void* data) {
474   SSLContext* context = (SSLContext*)data;
475   CHECK(context);
476   if (context->advertisedNextProtocols_.empty()) {
477     *out = nullptr;
478     *outlen = 0;
479   } else {
480     auto i = context->pickNextProtocols();
481     const auto& item = context->advertisedNextProtocols_[i];
482     if (SSL_select_next_proto((unsigned char**)out,
483                               outlen,
484                               item.protocols,
485                               item.length,
486                               in,
487                               inlen) != OPENSSL_NPN_NEGOTIATED) {
488       return SSL_TLSEXT_ERR_NOACK;
489     }
490   }
491   return SSL_TLSEXT_ERR_OK;
492 }
493 #endif // FOLLY_OPENSSL_HAS_ALPN
494
495 #ifdef OPENSSL_NPN_NEGOTIATED
496
497 bool SSLContext::setAdvertisedNextProtocols(
498     const std::list<std::string>& protocols, NextProtocolType protocolType) {
499   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
500 }
501
502 bool SSLContext::setRandomizedAdvertisedNextProtocols(
503     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
504   unsetNextProtocols();
505   if (items.size() == 0) {
506     return false;
507   }
508   int total_weight = 0;
509   for (const auto &item : items) {
510     if (item.protocols.size() == 0) {
511       continue;
512     }
513     AdvertisedNextProtocolsItem advertised_item;
514     advertised_item.length = 0;
515     for (const auto& proto : item.protocols) {
516       ++advertised_item.length;
517       auto protoLength = proto.length();
518       if (protoLength >= 256) {
519         deleteNextProtocolsStrings();
520         return false;
521       }
522       advertised_item.length += unsigned(protoLength);
523     }
524     advertised_item.protocols = new unsigned char[advertised_item.length];
525     if (!advertised_item.protocols) {
526       throw std::runtime_error("alloc failure");
527     }
528     unsigned char* dst = advertised_item.protocols;
529     for (auto& proto : item.protocols) {
530       uint8_t protoLength = uint8_t(proto.length());
531       *dst++ = (unsigned char)protoLength;
532       memcpy(dst, proto.data(), protoLength);
533       dst += protoLength;
534     }
535     total_weight += item.weight;
536     advertisedNextProtocols_.push_back(advertised_item);
537     advertisedNextProtocolWeights_.push_back(item.weight);
538   }
539   if (total_weight == 0) {
540     deleteNextProtocolsStrings();
541     return false;
542   }
543   nextProtocolDistribution_ =
544       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
545                                    advertisedNextProtocolWeights_.end());
546   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
547     SSL_CTX_set_next_protos_advertised_cb(
548         ctx_, advertisedNextProtocolCallback, this);
549     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
550   }
551 #if FOLLY_OPENSSL_HAS_ALPN
552   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
553     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
554     // Client cannot really use randomized alpn
555     SSL_CTX_set_alpn_protos(ctx_,
556                             advertisedNextProtocols_[0].protocols,
557                             advertisedNextProtocols_[0].length);
558   }
559 #endif
560   return true;
561 }
562
563 void SSLContext::deleteNextProtocolsStrings() {
564   for (auto protocols : advertisedNextProtocols_) {
565     delete[] protocols.protocols;
566   }
567   advertisedNextProtocols_.clear();
568   advertisedNextProtocolWeights_.clear();
569 }
570
571 void SSLContext::unsetNextProtocols() {
572   deleteNextProtocolsStrings();
573   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
574   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
575 #if FOLLY_OPENSSL_HAS_ALPN
576   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
577   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
578 #endif
579 }
580
581 size_t SSLContext::pickNextProtocols() {
582   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
583   auto rng = ThreadLocalPRNG();
584   return size_t(nextProtocolDistribution_(rng));
585 }
586
587 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
588       const unsigned char** out, unsigned int* outlen, void* data) {
589   SSLContext* context = (SSLContext*)data;
590   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
591     *out = nullptr;
592     *outlen = 0;
593   } else if (context->advertisedNextProtocols_.size() == 1) {
594     *out = context->advertisedNextProtocols_[0].protocols;
595     *outlen = context->advertisedNextProtocols_[0].length;
596   } else {
597     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
598           sNextProtocolsExDataIndex_));
599     if (selected_index) {
600       --selected_index;
601       *out = context->advertisedNextProtocols_[selected_index].protocols;
602       *outlen = context->advertisedNextProtocols_[selected_index].length;
603     } else {
604       auto i = context->pickNextProtocols();
605       uintptr_t selected = i + 1;
606       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
607       *out = context->advertisedNextProtocols_[i].protocols;
608       *outlen = context->advertisedNextProtocols_[i].length;
609     }
610   }
611   return SSL_TLSEXT_ERR_OK;
612 }
613
614 int SSLContext::selectNextProtocolCallback(SSL* ssl,
615                                            unsigned char** out,
616                                            unsigned char* outlen,
617                                            const unsigned char* server,
618                                            unsigned int server_len,
619                                            void* data) {
620   (void)ssl; // Make -Wunused-parameters happy
621   SSLContext* ctx = (SSLContext*)data;
622   if (ctx->advertisedNextProtocols_.size() > 1) {
623     VLOG(3) << "SSLContext::selectNextProcolCallback() "
624             << "client should be deterministic in selecting protocols.";
625   }
626
627   unsigned char* client = nullptr;
628   unsigned int client_len = 0;
629   bool filtered = false;
630   auto cpf = ctx->getClientProtocolFilterCallback();
631   if (cpf) {
632     filtered = (*cpf)(&client, &client_len, server, server_len);
633   }
634
635   if (!filtered) {
636     if (ctx->advertisedNextProtocols_.empty()) {
637       client = (unsigned char *) "";
638       client_len = 0;
639     } else {
640       client = ctx->advertisedNextProtocols_[0].protocols;
641       client_len = ctx->advertisedNextProtocols_[0].length;
642     }
643   }
644
645   int retval = SSL_select_next_proto(out, outlen, server, server_len,
646                                      client, client_len);
647   if (retval != OPENSSL_NPN_NEGOTIATED) {
648     VLOG(3) << "SSLContext::selectNextProcolCallback() "
649             << "unable to pick a next protocol.";
650   }
651   return SSL_TLSEXT_ERR_OK;
652 }
653 #endif // OPENSSL_NPN_NEGOTIATED
654
655 SSL* SSLContext::createSSL() const {
656   SSL* ssl = SSL_new(ctx_);
657   if (ssl == nullptr) {
658     throw std::runtime_error("SSL_new: " + getErrors());
659   }
660   return ssl;
661 }
662
663 void SSLContext::setSessionCacheContext(const std::string& context) {
664   SSL_CTX_set_session_id_context(
665       ctx_,
666       reinterpret_cast<const unsigned char*>(context.data()),
667       std::min<unsigned int>(
668           static_cast<unsigned int>(context.length()),
669           SSL_MAX_SSL_SESSION_ID_LENGTH));
670 }
671
672 /**
673  * Match a name with a pattern. The pattern may include wildcard. A single
674  * wildcard "*" can match up to one component in the domain name.
675  *
676  * @param  host    Host name, typically the name of the remote host
677  * @param  pattern Name retrieved from certificate
678  * @param  size    Size of "pattern"
679  * @return True, if "host" matches "pattern". False otherwise.
680  */
681 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
682   bool match = false;
683   int i = 0, j = 0;
684   while (i < size && host[j] != '\0') {
685     if (toupper(pattern[i]) == toupper(host[j])) {
686       i++;
687       j++;
688       continue;
689     }
690     if (pattern[i] == '*') {
691       while (host[j] != '.' && host[j] != '\0') {
692         j++;
693       }
694       i++;
695       continue;
696     }
697     break;
698   }
699   if (i == size && host[j] == '\0') {
700     match = true;
701   }
702   return match;
703 }
704
705 int SSLContext::passwordCallback(char* password,
706                                  int size,
707                                  int,
708                                  void* data) {
709   SSLContext* context = (SSLContext*)data;
710   if (context == nullptr || context->passwordCollector() == nullptr) {
711     return 0;
712   }
713   std::string userPassword;
714   // call user defined password collector to get password
715   context->passwordCollector()->getPassword(userPassword, size);
716   auto length = int(userPassword.size());
717   if (length > size) {
718     length = size;
719   }
720   strncpy(password, userPassword.c_str(), size_t(length));
721   return length;
722 }
723
724 struct SSLLock {
725   explicit SSLLock(
726     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
727       lockType(inLockType) {
728   }
729
730   void lock(bool read) {
731     if (lockType == SSLContext::LOCK_MUTEX) {
732       mutex.lock();
733     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
734       spinLock.lock();
735     } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
736       if (read) {
737         sharedMutex.lock_shared();
738       } else {
739         sharedMutex.lock();
740       }
741     }
742     // lockType == LOCK_NONE, no-op
743   }
744
745   void unlock(bool read) {
746     if (lockType == SSLContext::LOCK_MUTEX) {
747       mutex.unlock();
748     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
749       spinLock.unlock();
750     } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
751       if (read) {
752         sharedMutex.unlock_shared();
753       } else {
754         sharedMutex.unlock();
755       }
756     }
757     // lockType == LOCK_NONE, no-op
758   }
759
760   SSLContext::SSLLockType lockType;
761   folly::SpinLock spinLock{};
762   std::mutex mutex;
763   SharedMutex sharedMutex;
764 };
765
766 // Statics are unsafe in environments that call exit().
767 // If one thread calls exit() while another thread is
768 // references a member of SSLContext, bad things can happen.
769 // SSLContext runs in such environments.
770 // Instead of declaring a static member we "new" the static
771 // member so that it won't be destructed on exit().
772 static std::unique_ptr<SSLLock[]>& locks() {
773   static auto locksInst = new std::unique_ptr<SSLLock[]>();
774   return *locksInst;
775 }
776
777 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
778   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
779   return *lockTypesInst;
780 }
781
782 static void callbackLocking(int mode, int n, const char*, int) {
783   if (mode & CRYPTO_LOCK) {
784     locks()[size_t(n)].lock(mode & CRYPTO_READ);
785   } else {
786     locks()[size_t(n)].unlock(mode & CRYPTO_READ);
787   }
788 }
789
790 static unsigned long callbackThreadID() {
791   return static_cast<unsigned long>(folly::getCurrentThreadID());
792 }
793
794 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
795   return new CRYPTO_dynlock_value;
796 }
797
798 static void dyn_lock(int mode,
799                      struct CRYPTO_dynlock_value* lock,
800                      const char*, int) {
801   if (lock != nullptr) {
802     if (mode & CRYPTO_LOCK) {
803       lock->mutex.lock();
804     } else {
805       lock->mutex.unlock();
806     }
807   }
808 }
809
810 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
811   delete lock;
812 }
813
814 void SSLContext::setSSLLockTypesLocked(std::map<int, SSLLockType> inLockTypes) {
815   lockTypes() = inLockTypes;
816 }
817
818 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
819   std::lock_guard<std::mutex> g(initMutex());
820   if (initialized_) {
821     // We set the locks on initialization, so if we are already initialized
822     // this would have no affect.
823     LOG(INFO) << "Ignoring setSSLLockTypes after initialization";
824     return;
825   }
826   setSSLLockTypesLocked(std::move(inLockTypes));
827 }
828
829 void SSLContext::setSSLLockTypesAndInitOpenSSL(
830     std::map<int, SSLLockType> inLockTypes) {
831   std::lock_guard<std::mutex> g(initMutex());
832   CHECK(!initialized_) << "OpenSSL is already initialized";
833   setSSLLockTypesLocked(std::move(inLockTypes));
834   initializeOpenSSLLocked();
835 }
836
837 bool SSLContext::isSSLLockDisabled(int lockId) {
838   std::lock_guard<std::mutex> g(initMutex());
839   CHECK(initialized_) << "OpenSSL is not initialized yet";
840   const auto& sslLocks = lockTypes();
841   const auto it = sslLocks.find(lockId);
842   return it != sslLocks.end() &&
843       it->second == SSLContext::SSLLockType::LOCK_NONE;
844 }
845
846 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
847 void SSLContext::enableFalseStart() {
848   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
849 }
850 #endif
851
852 void SSLContext::markInitialized() {
853   std::lock_guard<std::mutex> g(initMutex());
854   initialized_ = true;
855 }
856
857 void SSLContext::initializeOpenSSL() {
858   std::lock_guard<std::mutex> g(initMutex());
859   initializeOpenSSLLocked();
860 }
861
862 void SSLContext::initializeOpenSSLLocked() {
863   if (initialized_) {
864     return;
865   }
866   SSL_library_init();
867   SSL_load_error_strings();
868   ERR_load_crypto_strings();
869   // static locking
870   locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
871   for (auto it: lockTypes()) {
872     locks()[size_t(it.first)].lockType = it.second;
873   }
874   CRYPTO_set_id_callback(callbackThreadID);
875   CRYPTO_set_locking_callback(callbackLocking);
876   // dynamic locking
877   CRYPTO_set_dynlock_create_callback(dyn_create);
878   CRYPTO_set_dynlock_lock_callback(dyn_lock);
879   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
880   randomize();
881 #ifdef OPENSSL_NPN_NEGOTIATED
882   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
883       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
884 #endif
885   initialized_ = true;
886 }
887
888 void SSLContext::cleanupOpenSSL() {
889   std::lock_guard<std::mutex> g(initMutex());
890   cleanupOpenSSLLocked();
891 }
892
893 void SSLContext::cleanupOpenSSLLocked() {
894   if (!initialized_) {
895     return;
896   }
897
898   CRYPTO_set_id_callback(nullptr);
899   CRYPTO_set_locking_callback(nullptr);
900   CRYPTO_set_dynlock_create_callback(nullptr);
901   CRYPTO_set_dynlock_lock_callback(nullptr);
902   CRYPTO_set_dynlock_destroy_callback(nullptr);
903   CRYPTO_cleanup_all_ex_data();
904   ERR_free_strings();
905   EVP_cleanup();
906   ERR_clear_error();
907   locks().reset();
908   initialized_ = false;
909 }
910
911 void SSLContext::setOptions(long options) {
912   long newOpt = SSL_CTX_set_options(ctx_, options);
913   if ((newOpt & options) != options) {
914     throw std::runtime_error("SSL_CTX_set_options failed");
915   }
916 }
917
918 std::string SSLContext::getErrors(int errnoCopy) {
919   std::string errors;
920   unsigned long  errorCode;
921   char   message[256];
922
923   errors.reserve(512);
924   while ((errorCode = ERR_get_error()) != 0) {
925     if (!errors.empty()) {
926       errors += "; ";
927     }
928     const char* reason = ERR_reason_error_string(errorCode);
929     if (reason == nullptr) {
930       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
931       reason = message;
932     }
933     errors += reason;
934   }
935   if (errors.empty()) {
936     errors = "error code: " + folly::to<std::string>(errnoCopy);
937   }
938   return errors;
939 }
940
941 std::ostream&
942 operator<<(std::ostream& os, const PasswordCollector& collector) {
943   os << collector.describe();
944   return os;
945 }
946
947 } // folly