add callback to specify a client next protocol filter
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/SpinLock.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30
31 struct CRYPTO_dynlock_value {
32   std::mutex mutex;
33 };
34
35 namespace folly {
36
37 bool SSLContext::initialized_ = false;
38
39 namespace {
40
41 std::mutex& initMutex() {
42   static std::mutex m;
43   return m;
44 }
45
46 }  // anonymous namespace
47
48 #ifdef OPENSSL_NPN_NEGOTIATED
49 int SSLContext::sNextProtocolsExDataIndex_ = -1;
50 #endif
51
52 // SSLContext implementation
53 SSLContext::SSLContext(SSLVersion version) {
54   {
55     std::lock_guard<std::mutex> g(initMutex());
56     initializeOpenSSLLocked();
57   }
58
59   ctx_ = SSL_CTX_new(SSLv23_method());
60   if (ctx_ == nullptr) {
61     throw std::runtime_error("SSL_CTX_new: " + getErrors());
62   }
63
64   int opt = 0;
65   switch (version) {
66     case TLSv1:
67       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
68       break;
69     case SSLv3:
70       opt = SSL_OP_NO_SSLv2;
71       break;
72     default:
73       // do nothing
74       break;
75   }
76   int newOpt = SSL_CTX_set_options(ctx_, opt);
77   DCHECK((newOpt & opt) == opt);
78
79   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
80
81   checkPeerName_ = false;
82
83 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
84   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
85   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
86 #endif
87 }
88
89 SSLContext::~SSLContext() {
90   if (ctx_ != nullptr) {
91     SSL_CTX_free(ctx_);
92     ctx_ = nullptr;
93   }
94
95 #ifdef OPENSSL_NPN_NEGOTIATED
96   deleteNextProtocolsStrings();
97 #endif
98 }
99
100 void SSLContext::ciphers(const std::string& ciphers) {
101   providedCiphersString_ = ciphers;
102   setCiphersOrThrow(ciphers);
103 }
104
105 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
106   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
107   if (ERR_peek_error() != 0) {
108     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
109   }
110   if (rc == 0) {
111     throw std::runtime_error("None of specified ciphers are supported");
112   }
113 }
114
115 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
116     verifyPeer) {
117   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
118   verifyPeer_ = verifyPeer;
119 }
120
121 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
122     verifyPeer) {
123   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
124   int mode = SSL_VERIFY_NONE;
125   switch(verifyPeer) {
126     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
127     // break;
128
129     case SSLVerifyPeerEnum::VERIFY:
130       mode = SSL_VERIFY_PEER;
131       break;
132
133     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
134       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
135       break;
136
137     case SSLVerifyPeerEnum::NO_VERIFY:
138       mode = SSL_VERIFY_NONE;
139       break;
140
141     default:
142       break;
143   }
144   return mode;
145 }
146
147 int SSLContext::getVerificationMode() {
148   return getVerificationMode(verifyPeer_);
149 }
150
151 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
152                               const std::string& peerName) {
153   int mode;
154   if (checkPeerCert) {
155     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
156     checkPeerName_ = checkPeerName;
157     peerFixedName_ = peerName;
158   } else {
159     mode = SSL_VERIFY_NONE;
160     checkPeerName_ = false; // can't check name without cert!
161     peerFixedName_.clear();
162   }
163   SSL_CTX_set_verify(ctx_, mode, nullptr);
164 }
165
166 void SSLContext::loadCertificate(const char* path, const char* format) {
167   if (path == nullptr || format == nullptr) {
168     throw std::invalid_argument(
169          "loadCertificateChain: either <path> or <format> is nullptr");
170   }
171   if (strcmp(format, "PEM") == 0) {
172     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
173       int errnoCopy = errno;
174       std::string reason("SSL_CTX_use_certificate_chain_file: ");
175       reason.append(path);
176       reason.append(": ");
177       reason.append(getErrors(errnoCopy));
178       throw std::runtime_error(reason);
179     }
180   } else {
181     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
182   }
183 }
184
185 void SSLContext::loadPrivateKey(const char* path, const char* format) {
186   if (path == nullptr || format == nullptr) {
187     throw std::invalid_argument(
188          "loadPrivateKey: either <path> or <format> is nullptr");
189   }
190   if (strcmp(format, "PEM") == 0) {
191     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
192       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
193     }
194   } else {
195     throw std::runtime_error("Unsupported private key format: " + std::string(format));
196   }
197 }
198
199 void SSLContext::loadTrustedCertificates(const char* path) {
200   if (path == nullptr) {
201     throw std::invalid_argument(
202          "loadTrustedCertificates: <path> is nullptr");
203   }
204   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
205     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
206   }
207 }
208
209 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
210   SSL_CTX_set_cert_store(ctx_, store);
211 }
212
213 void SSLContext::loadClientCAList(const char* path) {
214   auto clientCAs = SSL_load_client_CA_file(path);
215   if (clientCAs == nullptr) {
216     LOG(ERROR) << "Unable to load ca file: " << path;
217     return;
218   }
219   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
220 }
221
222 void SSLContext::randomize() {
223   RAND_poll();
224 }
225
226 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
227   if (collector == nullptr) {
228     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
229     return;
230   }
231   collector_ = collector;
232   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
233   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
234 }
235
236 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
237
238 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
239   serverNameCb_ = cb;
240 }
241
242 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
243   clientHelloCbs_.push_back(cb);
244 }
245
246 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
247   SSLContext* context = (SSLContext*)data;
248
249   if (context == nullptr) {
250     return SSL_TLSEXT_ERR_NOACK;
251   }
252
253   for (auto& cb : context->clientHelloCbs_) {
254     // Generic callbacks to happen after we receive the Client Hello.
255     // For example, we use one to switch which cipher we use depending
256     // on the user's TLS version.  Because the primary purpose of
257     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
258     // are side-uses, we ignore any possible failures other than just logging
259     // them.
260     cb(ssl);
261   }
262
263   if (!context->serverNameCb_) {
264     return SSL_TLSEXT_ERR_NOACK;
265   }
266
267   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
268   switch (ret) {
269     case SERVER_NAME_FOUND:
270       return SSL_TLSEXT_ERR_OK;
271     case SERVER_NAME_NOT_FOUND:
272       return SSL_TLSEXT_ERR_NOACK;
273     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
274       *al = TLS1_AD_UNRECOGNIZED_NAME;
275       return SSL_TLSEXT_ERR_ALERT_FATAL;
276     default:
277       CHECK(false);
278   }
279
280   return SSL_TLSEXT_ERR_NOACK;
281 }
282
283 void SSLContext::switchCiphersIfTLS11(
284     SSL* ssl,
285     const std::string& tls11CipherString) {
286
287   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
288
289   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
290     // We only do this for TLS v 1.1 and later
291     return;
292   }
293
294   // Prefer AES for TLS versions 1.1 and later since these are not
295   // vulnerable to BEAST attacks on AES.  Note that we're setting the
296   // cipher list on the SSL object, not the SSL_CTX object, so it will
297   // only last for this request.
298   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
299   if ((rc == 0) || ERR_peek_error() != 0) {
300     // This shouldn't happen since we checked for this when proxygen
301     // started up.
302     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
303     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
304   }
305 }
306 #endif
307
308 #ifdef OPENSSL_NPN_NEGOTIATED
309 bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
310   return setRandomizedAdvertisedNextProtocols({{1, protocols}});
311 }
312
313 bool SSLContext::setRandomizedAdvertisedNextProtocols(
314     const std::list<NextProtocolsItem>& items) {
315   unsetNextProtocols();
316   if (items.size() == 0) {
317     return false;
318   }
319   int total_weight = 0;
320   for (const auto &item : items) {
321     if (item.protocols.size() == 0) {
322       continue;
323     }
324     AdvertisedNextProtocolsItem advertised_item;
325     advertised_item.length = 0;
326     for (const auto& proto : item.protocols) {
327       ++advertised_item.length;
328       unsigned protoLength = proto.length();
329       if (protoLength >= 256) {
330         deleteNextProtocolsStrings();
331         return false;
332       }
333       advertised_item.length += protoLength;
334     }
335     advertised_item.protocols = new unsigned char[advertised_item.length];
336     if (!advertised_item.protocols) {
337       throw std::runtime_error("alloc failure");
338     }
339     unsigned char* dst = advertised_item.protocols;
340     for (auto& proto : item.protocols) {
341       unsigned protoLength = proto.length();
342       *dst++ = (unsigned char)protoLength;
343       memcpy(dst, proto.data(), protoLength);
344       dst += protoLength;
345     }
346     total_weight += item.weight;
347     advertised_item.probability = item.weight;
348     advertisedNextProtocols_.push_back(advertised_item);
349   }
350   if (total_weight == 0) {
351     deleteNextProtocolsStrings();
352     return false;
353   }
354   for (auto &advertised_item : advertisedNextProtocols_) {
355     advertised_item.probability /= total_weight;
356   }
357   SSL_CTX_set_next_protos_advertised_cb(
358     ctx_, advertisedNextProtocolCallback, this);
359   SSL_CTX_set_next_proto_select_cb(
360     ctx_, selectNextProtocolCallback, this);
361   return true;
362 }
363
364 void SSLContext::deleteNextProtocolsStrings() {
365   for (auto protocols : advertisedNextProtocols_) {
366     delete[] protocols.protocols;
367   }
368   advertisedNextProtocols_.clear();
369 }
370
371 void SSLContext::unsetNextProtocols() {
372   deleteNextProtocolsStrings();
373   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
374   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
375 }
376
377 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
378       const unsigned char** out, unsigned int* outlen, void* data) {
379   SSLContext* context = (SSLContext*)data;
380   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
381     *out = nullptr;
382     *outlen = 0;
383   } else if (context->advertisedNextProtocols_.size() == 1) {
384     *out = context->advertisedNextProtocols_[0].protocols;
385     *outlen = context->advertisedNextProtocols_[0].length;
386   } else {
387     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
388           sNextProtocolsExDataIndex_));
389     if (selected_index) {
390       --selected_index;
391       *out = context->advertisedNextProtocols_[selected_index].protocols;
392       *outlen = context->advertisedNextProtocols_[selected_index].length;
393     } else {
394       unsigned char random_byte;
395       RAND_bytes(&random_byte, 1);
396       double random_value = random_byte / 255.0;
397       double sum = 0;
398       for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
399         sum += context->advertisedNextProtocols_[i].probability;
400         if (sum < random_value &&
401             i + 1 < context->advertisedNextProtocols_.size()) {
402           continue;
403         }
404         uintptr_t selected = i + 1;
405         SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
406         *out = context->advertisedNextProtocols_[i].protocols;
407         *outlen = context->advertisedNextProtocols_[i].length;
408         break;
409       }
410     }
411   }
412   return SSL_TLSEXT_ERR_OK;
413 }
414
415 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
416   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
417 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
418   /**
419    * The list was generated as follows:
420    * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 |
421    * while read A && read B && read C && read D && read E && read F; do
422    * echo $A $B $C $D $E; done |
423    * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\|
424    *         SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES |
425    * awk -F, '{ print $1"," }'
426    */
427   ciphers_{
428     TLS1_CK_DH_DSS_WITH_AES_128_SHA,
429     TLS1_CK_DH_RSA_WITH_AES_128_SHA,
430     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
431     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
432     TLS1_CK_DH_DSS_WITH_AES_256_SHA,
433     TLS1_CK_DH_RSA_WITH_AES_256_SHA,
434     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
435     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
436     TLS1_CK_DH_DSS_WITH_AES_128_SHA256,
437     TLS1_CK_DH_RSA_WITH_AES_128_SHA256,
438     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
439     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
440     TLS1_CK_DH_DSS_WITH_AES_256_SHA256,
441     TLS1_CK_DH_RSA_WITH_AES_256_SHA256,
442     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
443     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
444     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
445     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
446     TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256,
447     TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384,
448     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
449     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
450     TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256,
451     TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384,
452     TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
453     TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
454     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
455     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
456     TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA,
457     TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA,
458     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
459     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
460     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
461     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
462     TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
463     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
464     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
465     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
466     TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256,
467     TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384,
468     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
469     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
470     TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
471     TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
472     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
473     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
474     TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256,
475   } {
476   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
477   width_ = sizeof(ciphers_[0]);
478   qsort(ciphers_, length_, width_, compare_ulong);
479 }
480
481 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
482   const SSL_CIPHER *cipher) {
483   unsigned long cid = cipher->id;
484   unsigned long *r =
485     (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
486   return r != nullptr;
487 }
488
489 int
490 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
491   if (*(unsigned long *)x < *(unsigned long *)y) {
492     return -1;
493   }
494   if (*(unsigned long *)x > *(unsigned long *)y) {
495     return 1;
496   }
497   return 0;
498 };
499
500 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
501   return falseStartChecker_.canUseFalseStartWithCipher(cipher);
502 }
503 #endif
504
505 int SSLContext::selectNextProtocolCallback(
506   SSL* ssl, unsigned char **out, unsigned char *outlen,
507   const unsigned char *server, unsigned int server_len, void *data) {
508
509   SSLContext* ctx = (SSLContext*)data;
510   if (ctx->advertisedNextProtocols_.size() > 1) {
511     VLOG(3) << "SSLContext::selectNextProcolCallback() "
512             << "client should be deterministic in selecting protocols.";
513   }
514
515   unsigned char *client;
516   unsigned int client_len;
517   bool filtered = false;
518   auto cpf = ctx->getClientProtocolFilterCallback();
519   if (cpf) {
520     filtered = (*cpf)(&client, &client_len, server, server_len);
521   }
522
523   if (!filtered) {
524     if (ctx->advertisedNextProtocols_.empty()) {
525       client = (unsigned char *) "";
526       client_len = 0;
527     } else {
528       client = ctx->advertisedNextProtocols_[0].protocols;
529       client_len = ctx->advertisedNextProtocols_[0].length;
530     }
531   }
532
533   int retval = SSL_select_next_proto(out, outlen, server, server_len,
534                                      client, client_len);
535   if (retval != OPENSSL_NPN_NEGOTIATED) {
536     VLOG(3) << "SSLContext::selectNextProcolCallback() "
537             << "unable to pick a next protocol.";
538 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
539   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
540   } else {
541     const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
542     if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
543       SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
544     }
545 #endif
546   }
547   return SSL_TLSEXT_ERR_OK;
548 }
549 #endif // OPENSSL_NPN_NEGOTIATED
550
551 SSL* SSLContext::createSSL() const {
552   SSL* ssl = SSL_new(ctx_);
553   if (ssl == nullptr) {
554     throw std::runtime_error("SSL_new: " + getErrors());
555   }
556   return ssl;
557 }
558
559 /**
560  * Match a name with a pattern. The pattern may include wildcard. A single
561  * wildcard "*" can match up to one component in the domain name.
562  *
563  * @param  host    Host name, typically the name of the remote host
564  * @param  pattern Name retrieved from certificate
565  * @param  size    Size of "pattern"
566  * @return True, if "host" matches "pattern". False otherwise.
567  */
568 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
569   bool match = false;
570   int i = 0, j = 0;
571   while (i < size && host[j] != '\0') {
572     if (toupper(pattern[i]) == toupper(host[j])) {
573       i++;
574       j++;
575       continue;
576     }
577     if (pattern[i] == '*') {
578       while (host[j] != '.' && host[j] != '\0') {
579         j++;
580       }
581       i++;
582       continue;
583     }
584     break;
585   }
586   if (i == size && host[j] == '\0') {
587     match = true;
588   }
589   return match;
590 }
591
592 int SSLContext::passwordCallback(char* password,
593                                  int size,
594                                  int,
595                                  void* data) {
596   SSLContext* context = (SSLContext*)data;
597   if (context == nullptr || context->passwordCollector() == nullptr) {
598     return 0;
599   }
600   std::string userPassword;
601   // call user defined password collector to get password
602   context->passwordCollector()->getPassword(userPassword, size);
603   int length = userPassword.size();
604   if (length > size) {
605     length = size;
606   }
607   strncpy(password, userPassword.c_str(), length);
608   return length;
609 }
610
611 struct SSLLock {
612   explicit SSLLock(
613     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
614       lockType(inLockType) {
615   }
616
617   void lock() {
618     if (lockType == SSLContext::LOCK_MUTEX) {
619       mutex.lock();
620     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
621       spinLock.lock();
622     }
623     // lockType == LOCK_NONE, no-op
624   }
625
626   void unlock() {
627     if (lockType == SSLContext::LOCK_MUTEX) {
628       mutex.unlock();
629     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
630       spinLock.unlock();
631     }
632     // lockType == LOCK_NONE, no-op
633   }
634
635   SSLContext::SSLLockType lockType;
636   folly::SpinLock spinLock{};
637   std::mutex mutex;
638 };
639
640 // Statics are unsafe in environments that call exit().
641 // If one thread calls exit() while another thread is
642 // references a member of SSLContext, bad things can happen.
643 // SSLContext runs in such environments.
644 // Instead of declaring a static member we "new" the static
645 // member so that it won't be destructed on exit().
646 static std::unique_ptr<SSLLock[]>& locks() {
647   static auto locksInst = new std::unique_ptr<SSLLock[]>();
648   return *locksInst;
649 }
650
651 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
652   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
653   return *lockTypesInst;
654 }
655
656 static void callbackLocking(int mode, int n, const char*, int) {
657   if (mode & CRYPTO_LOCK) {
658     locks()[n].lock();
659   } else {
660     locks()[n].unlock();
661   }
662 }
663
664 static unsigned long callbackThreadID() {
665   return static_cast<unsigned long>(
666 #ifdef __APPLE__
667     pthread_mach_thread_np(pthread_self())
668 #else
669     pthread_self()
670 #endif
671   );
672 }
673
674 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
675   return new CRYPTO_dynlock_value;
676 }
677
678 static void dyn_lock(int mode,
679                      struct CRYPTO_dynlock_value* lock,
680                      const char*, int) {
681   if (lock != nullptr) {
682     if (mode & CRYPTO_LOCK) {
683       lock->mutex.lock();
684     } else {
685       lock->mutex.unlock();
686     }
687   }
688 }
689
690 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
691   delete lock;
692 }
693
694 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
695   lockTypes() = inLockTypes;
696 }
697
698 void SSLContext::markInitialized() {
699   std::lock_guard<std::mutex> g(initMutex());
700   initialized_ = true;
701 }
702
703 void SSLContext::initializeOpenSSL() {
704   std::lock_guard<std::mutex> g(initMutex());
705   initializeOpenSSLLocked();
706 }
707
708 void SSLContext::initializeOpenSSLLocked() {
709   if (initialized_) {
710     return;
711   }
712   SSL_library_init();
713   SSL_load_error_strings();
714   ERR_load_crypto_strings();
715   // static locking
716   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
717   for (auto it: lockTypes()) {
718     locks()[it.first].lockType = it.second;
719   }
720   CRYPTO_set_id_callback(callbackThreadID);
721   CRYPTO_set_locking_callback(callbackLocking);
722   // dynamic locking
723   CRYPTO_set_dynlock_create_callback(dyn_create);
724   CRYPTO_set_dynlock_lock_callback(dyn_lock);
725   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
726   randomize();
727 #ifdef OPENSSL_NPN_NEGOTIATED
728   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
729       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
730 #endif
731   initialized_ = true;
732 }
733
734 void SSLContext::cleanupOpenSSL() {
735   std::lock_guard<std::mutex> g(initMutex());
736   cleanupOpenSSLLocked();
737 }
738
739 void SSLContext::cleanupOpenSSLLocked() {
740   if (!initialized_) {
741     return;
742   }
743
744   CRYPTO_set_id_callback(nullptr);
745   CRYPTO_set_locking_callback(nullptr);
746   CRYPTO_set_dynlock_create_callback(nullptr);
747   CRYPTO_set_dynlock_lock_callback(nullptr);
748   CRYPTO_set_dynlock_destroy_callback(nullptr);
749   CRYPTO_cleanup_all_ex_data();
750   ERR_free_strings();
751   EVP_cleanup();
752   ERR_remove_state(0);
753   locks().reset();
754   initialized_ = false;
755 }
756
757 void SSLContext::setOptions(long options) {
758   long newOpt = SSL_CTX_set_options(ctx_, options);
759   if ((newOpt & options) != options) {
760     throw std::runtime_error("SSL_CTX_set_options failed");
761   }
762 }
763
764 std::string SSLContext::getErrors(int errnoCopy) {
765   std::string errors;
766   unsigned long  errorCode;
767   char   message[256];
768
769   errors.reserve(512);
770   while ((errorCode = ERR_get_error()) != 0) {
771     if (!errors.empty()) {
772       errors += "; ";
773     }
774     const char* reason = ERR_reason_error_string(errorCode);
775     if (reason == nullptr) {
776       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
777       reason = message;
778     }
779     errors += reason;
780   }
781   if (errors.empty()) {
782     errors = "error code: " + folly::to<std::string>(errnoCopy);
783   }
784   return errors;
785 }
786
787 std::ostream&
788 operator<<(std::ostream& os, const PasswordCollector& collector) {
789   os << collector.describe();
790   return os;
791 }
792
793 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
794                                                   sockaddr_storage* addrStorage,
795                                                   socklen_t* addrLen) {
796   // Grab the ssl idx and then the ssl object so that we can get the peer
797   // name to compare against the ips in the subjectAltName
798   auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
799   auto ssl =
800     reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
801   int fd = SSL_get_fd(ssl);
802   if (fd < 0) {
803     LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
804     return false;
805   }
806
807   *addrLen = sizeof(*addrStorage);
808   if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
809     PLOG(ERROR) << "Unable to get peer name";
810     return false;
811   }
812   CHECK(*addrLen <= sizeof(*addrStorage));
813   return true;
814 }
815
816 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
817                                          const sockaddr* addr,
818                                          socklen_t addrLen) {
819   // Try to extract the names within the SAN extension from the certificate
820   auto altNames =
821     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
822         X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
823   SCOPE_EXIT {
824     if (altNames != nullptr) {
825       sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
826     }
827   };
828   if (altNames == nullptr) {
829     LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
830     return false;
831   }
832
833   const sockaddr_in* addr4 = nullptr;
834   const sockaddr_in6* addr6 = nullptr;
835   if (addr != nullptr) {
836     if (addr->sa_family == AF_INET) {
837       addr4 = reinterpret_cast<const sockaddr_in*>(addr);
838     } else if (addr->sa_family == AF_INET6) {
839       addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
840     } else {
841       LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
842     }
843   }
844
845
846   for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
847     auto name = sk_GENERAL_NAME_value(altNames, i);
848     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
849       // Extra const-ness for paranoia
850       unsigned char const * const rawIpStr = name->d.iPAddress->data;
851       int const rawIpLen = name->d.iPAddress->length;
852
853       if (rawIpLen == 4 && addr4 != nullptr) {
854         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
855           return true;
856         }
857       } else if (rawIpLen == 16 && addr6 != nullptr) {
858         if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
859           return true;
860         }
861       } else if (rawIpLen != 4 && rawIpLen != 16) {
862         LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
863       }
864     }
865   }
866
867   LOG(WARNING) << "Unable to match client cert against alt name ip";
868   return false;
869 }
870
871
872 } // folly