Copyright 2014->2015
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/SpinLock.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30
31 struct CRYPTO_dynlock_value {
32   std::mutex mutex;
33 };
34
35 namespace folly {
36
37 bool SSLContext::initialized_ = false;
38 std::mutex    SSLContext::mutex_;
39 #ifdef OPENSSL_NPN_NEGOTIATED
40 int SSLContext::sNextProtocolsExDataIndex_ = -1;
41 #endif
42
43 #ifndef SSLCONTEXT_NO_REFCOUNT
44 uint64_t SSLContext::count_ = 0;
45 #endif
46
47 // SSLContext implementation
48 SSLContext::SSLContext(SSLVersion version) {
49   {
50     std::lock_guard<std::mutex> g(mutex_);
51 #ifndef SSLCONTEXT_NO_REFCOUNT
52     count_++;
53 #endif
54     initializeOpenSSLLocked();
55   }
56
57   ctx_ = SSL_CTX_new(SSLv23_method());
58   if (ctx_ == nullptr) {
59     throw std::runtime_error("SSL_CTX_new: " + getErrors());
60   }
61
62   int opt = 0;
63   switch (version) {
64     case TLSv1:
65       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
66       break;
67     case SSLv3:
68       opt = SSL_OP_NO_SSLv2;
69       break;
70     default:
71       // do nothing
72       break;
73   }
74   int newOpt = SSL_CTX_set_options(ctx_, opt);
75   DCHECK((newOpt & opt) == opt);
76
77   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
78
79   checkPeerName_ = false;
80
81 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
82   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
83   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
84 #endif
85 }
86
87 SSLContext::~SSLContext() {
88   if (ctx_ != nullptr) {
89     SSL_CTX_free(ctx_);
90     ctx_ = nullptr;
91   }
92
93 #ifdef OPENSSL_NPN_NEGOTIATED
94   deleteNextProtocolsStrings();
95 #endif
96
97 #ifndef SSLCONTEXT_NO_REFCOUNT
98   {
99     std::lock_guard<std::mutex> g(mutex_);
100     if (!--count_) {
101       cleanupOpenSSLLocked();
102     }
103   }
104 #endif
105 }
106
107 void SSLContext::ciphers(const std::string& ciphers) {
108   providedCiphersString_ = ciphers;
109   setCiphersOrThrow(ciphers);
110 }
111
112 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
113   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
114   if (ERR_peek_error() != 0) {
115     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
116   }
117   if (rc == 0) {
118     throw std::runtime_error("None of specified ciphers are supported");
119   }
120 }
121
122 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
123     verifyPeer) {
124   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
125   verifyPeer_ = verifyPeer;
126 }
127
128 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
129     verifyPeer) {
130   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
131   int mode = SSL_VERIFY_NONE;
132   switch(verifyPeer) {
133     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
134     // break;
135
136     case SSLVerifyPeerEnum::VERIFY:
137       mode = SSL_VERIFY_PEER;
138       break;
139
140     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
141       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
142       break;
143
144     case SSLVerifyPeerEnum::NO_VERIFY:
145       mode = SSL_VERIFY_NONE;
146       break;
147
148     default:
149       break;
150   }
151   return mode;
152 }
153
154 int SSLContext::getVerificationMode() {
155   return getVerificationMode(verifyPeer_);
156 }
157
158 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
159                               const std::string& peerName) {
160   int mode;
161   if (checkPeerCert) {
162     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
163     checkPeerName_ = checkPeerName;
164     peerFixedName_ = peerName;
165   } else {
166     mode = SSL_VERIFY_NONE;
167     checkPeerName_ = false; // can't check name without cert!
168     peerFixedName_.clear();
169   }
170   SSL_CTX_set_verify(ctx_, mode, nullptr);
171 }
172
173 void SSLContext::loadCertificate(const char* path, const char* format) {
174   if (path == nullptr || format == nullptr) {
175     throw std::invalid_argument(
176          "loadCertificateChain: either <path> or <format> is nullptr");
177   }
178   if (strcmp(format, "PEM") == 0) {
179     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
180       int errnoCopy = errno;
181       std::string reason("SSL_CTX_use_certificate_chain_file: ");
182       reason.append(path);
183       reason.append(": ");
184       reason.append(getErrors(errnoCopy));
185       throw std::runtime_error(reason);
186     }
187   } else {
188     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
189   }
190 }
191
192 void SSLContext::loadPrivateKey(const char* path, const char* format) {
193   if (path == nullptr || format == nullptr) {
194     throw std::invalid_argument(
195          "loadPrivateKey: either <path> or <format> is nullptr");
196   }
197   if (strcmp(format, "PEM") == 0) {
198     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
199       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
200     }
201   } else {
202     throw std::runtime_error("Unsupported private key format: " + std::string(format));
203   }
204 }
205
206 void SSLContext::loadTrustedCertificates(const char* path) {
207   if (path == nullptr) {
208     throw std::invalid_argument(
209          "loadTrustedCertificates: <path> is nullptr");
210   }
211   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
212     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
213   }
214 }
215
216 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
217   SSL_CTX_set_cert_store(ctx_, store);
218 }
219
220 void SSLContext::loadClientCAList(const char* path) {
221   auto clientCAs = SSL_load_client_CA_file(path);
222   if (clientCAs == nullptr) {
223     LOG(ERROR) << "Unable to load ca file: " << path;
224     return;
225   }
226   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
227 }
228
229 void SSLContext::randomize() {
230   RAND_poll();
231 }
232
233 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
234   if (collector == nullptr) {
235     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
236     return;
237   }
238   collector_ = collector;
239   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
240   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
241 }
242
243 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
244
245 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
246   serverNameCb_ = cb;
247 }
248
249 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
250   clientHelloCbs_.push_back(cb);
251 }
252
253 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
254   SSLContext* context = (SSLContext*)data;
255
256   if (context == nullptr) {
257     return SSL_TLSEXT_ERR_NOACK;
258   }
259
260   for (auto& cb : context->clientHelloCbs_) {
261     // Generic callbacks to happen after we receive the Client Hello.
262     // For example, we use one to switch which cipher we use depending
263     // on the user's TLS version.  Because the primary purpose of
264     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
265     // are side-uses, we ignore any possible failures other than just logging
266     // them.
267     cb(ssl);
268   }
269
270   if (!context->serverNameCb_) {
271     return SSL_TLSEXT_ERR_NOACK;
272   }
273
274   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
275   switch (ret) {
276     case SERVER_NAME_FOUND:
277       return SSL_TLSEXT_ERR_OK;
278     case SERVER_NAME_NOT_FOUND:
279       return SSL_TLSEXT_ERR_NOACK;
280     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
281       *al = TLS1_AD_UNRECOGNIZED_NAME;
282       return SSL_TLSEXT_ERR_ALERT_FATAL;
283     default:
284       CHECK(false);
285   }
286
287   return SSL_TLSEXT_ERR_NOACK;
288 }
289
290 void SSLContext::switchCiphersIfTLS11(
291     SSL* ssl,
292     const std::string& tls11CipherString) {
293
294   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
295
296   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
297     // We only do this for TLS v 1.1 and later
298     return;
299   }
300
301   // Prefer AES for TLS versions 1.1 and later since these are not
302   // vulnerable to BEAST attacks on AES.  Note that we're setting the
303   // cipher list on the SSL object, not the SSL_CTX object, so it will
304   // only last for this request.
305   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
306   if ((rc == 0) || ERR_peek_error() != 0) {
307     // This shouldn't happen since we checked for this when proxygen
308     // started up.
309     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
310     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
311   }
312 }
313 #endif
314
315 #ifdef OPENSSL_NPN_NEGOTIATED
316 bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
317   return setRandomizedAdvertisedNextProtocols({{1, protocols}});
318 }
319
320 bool SSLContext::setRandomizedAdvertisedNextProtocols(
321     const std::list<NextProtocolsItem>& items) {
322   unsetNextProtocols();
323   if (items.size() == 0) {
324     return false;
325   }
326   int total_weight = 0;
327   for (const auto &item : items) {
328     if (item.protocols.size() == 0) {
329       continue;
330     }
331     AdvertisedNextProtocolsItem advertised_item;
332     advertised_item.length = 0;
333     for (const auto& proto : item.protocols) {
334       ++advertised_item.length;
335       unsigned protoLength = proto.length();
336       if (protoLength >= 256) {
337         deleteNextProtocolsStrings();
338         return false;
339       }
340       advertised_item.length += protoLength;
341     }
342     advertised_item.protocols = new unsigned char[advertised_item.length];
343     if (!advertised_item.protocols) {
344       throw std::runtime_error("alloc failure");
345     }
346     unsigned char* dst = advertised_item.protocols;
347     for (auto& proto : item.protocols) {
348       unsigned protoLength = proto.length();
349       *dst++ = (unsigned char)protoLength;
350       memcpy(dst, proto.data(), protoLength);
351       dst += protoLength;
352     }
353     total_weight += item.weight;
354     advertised_item.probability = item.weight;
355     advertisedNextProtocols_.push_back(advertised_item);
356   }
357   if (total_weight == 0) {
358     deleteNextProtocolsStrings();
359     return false;
360   }
361   for (auto &advertised_item : advertisedNextProtocols_) {
362     advertised_item.probability /= total_weight;
363   }
364   SSL_CTX_set_next_protos_advertised_cb(
365     ctx_, advertisedNextProtocolCallback, this);
366   SSL_CTX_set_next_proto_select_cb(
367     ctx_, selectNextProtocolCallback, this);
368   return true;
369 }
370
371 void SSLContext::deleteNextProtocolsStrings() {
372   for (auto protocols : advertisedNextProtocols_) {
373     delete[] protocols.protocols;
374   }
375   advertisedNextProtocols_.clear();
376 }
377
378 void SSLContext::unsetNextProtocols() {
379   deleteNextProtocolsStrings();
380   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
381   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
382 }
383
384 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
385       const unsigned char** out, unsigned int* outlen, void* data) {
386   SSLContext* context = (SSLContext*)data;
387   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
388     *out = nullptr;
389     *outlen = 0;
390   } else if (context->advertisedNextProtocols_.size() == 1) {
391     *out = context->advertisedNextProtocols_[0].protocols;
392     *outlen = context->advertisedNextProtocols_[0].length;
393   } else {
394     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
395           sNextProtocolsExDataIndex_));
396     if (selected_index) {
397       --selected_index;
398       *out = context->advertisedNextProtocols_[selected_index].protocols;
399       *outlen = context->advertisedNextProtocols_[selected_index].length;
400     } else {
401       unsigned char random_byte;
402       RAND_bytes(&random_byte, 1);
403       double random_value = random_byte / 255.0;
404       double sum = 0;
405       for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
406         sum += context->advertisedNextProtocols_[i].probability;
407         if (sum < random_value &&
408             i + 1 < context->advertisedNextProtocols_.size()) {
409           continue;
410         }
411         uintptr_t selected = i + 1;
412         SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
413         *out = context->advertisedNextProtocols_[i].protocols;
414         *outlen = context->advertisedNextProtocols_[i].length;
415         break;
416       }
417     }
418   }
419   return SSL_TLSEXT_ERR_OK;
420 }
421
422 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
423   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
424 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
425   /**
426    * The list was generated as follows:
427    * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 |
428    * while read A && read B && read C && read D && read E && read F; do
429    * echo $A $B $C $D $E; done |
430    * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\|
431    *         SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES |
432    * awk -F, '{ print $1"," }'
433    */
434   ciphers_{
435     TLS1_CK_DH_DSS_WITH_AES_128_SHA,
436     TLS1_CK_DH_RSA_WITH_AES_128_SHA,
437     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
438     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
439     TLS1_CK_DH_DSS_WITH_AES_256_SHA,
440     TLS1_CK_DH_RSA_WITH_AES_256_SHA,
441     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
442     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
443     TLS1_CK_DH_DSS_WITH_AES_128_SHA256,
444     TLS1_CK_DH_RSA_WITH_AES_128_SHA256,
445     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
446     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
447     TLS1_CK_DH_DSS_WITH_AES_256_SHA256,
448     TLS1_CK_DH_RSA_WITH_AES_256_SHA256,
449     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
450     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
451     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
452     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
453     TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256,
454     TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384,
455     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
456     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
457     TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256,
458     TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384,
459     TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
460     TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
461     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
462     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
463     TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA,
464     TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA,
465     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
466     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
467     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
468     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
469     TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
470     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
471     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
472     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
473     TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256,
474     TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384,
475     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
476     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
477     TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
478     TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
479     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
480     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
481     TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256,
482   } {
483   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
484   width_ = sizeof(ciphers_[0]);
485   qsort(ciphers_, length_, width_, compare_ulong);
486 }
487
488 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
489   const SSL_CIPHER *cipher) {
490   unsigned long cid = cipher->id;
491   unsigned long *r =
492     (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
493   return r != nullptr;
494 }
495
496 int
497 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
498   if (*(unsigned long *)x < *(unsigned long *)y) {
499     return -1;
500   }
501   if (*(unsigned long *)x > *(unsigned long *)y) {
502     return 1;
503   }
504   return 0;
505 };
506
507 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
508   return falseStartChecker_.canUseFalseStartWithCipher(cipher);
509 }
510 #endif
511
512 int SSLContext::selectNextProtocolCallback(
513   SSL* ssl, unsigned char **out, unsigned char *outlen,
514   const unsigned char *server, unsigned int server_len, void *data) {
515
516   SSLContext* ctx = (SSLContext*)data;
517   if (ctx->advertisedNextProtocols_.size() > 1) {
518     VLOG(3) << "SSLContext::selectNextProcolCallback() "
519             << "client should be deterministic in selecting protocols.";
520   }
521
522   unsigned char *client;
523   int client_len;
524   if (ctx->advertisedNextProtocols_.empty()) {
525     client = (unsigned char *) "";
526     client_len = 0;
527   } else {
528     client = ctx->advertisedNextProtocols_[0].protocols;
529     client_len = ctx->advertisedNextProtocols_[0].length;
530   }
531
532   int retval = SSL_select_next_proto(out, outlen, server, server_len,
533                                      client, client_len);
534   if (retval != OPENSSL_NPN_NEGOTIATED) {
535     VLOG(3) << "SSLContext::selectNextProcolCallback() "
536             << "unable to pick a next protocol.";
537 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
538   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
539   } else {
540     const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
541     if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
542       SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
543     }
544 #endif
545   }
546   return SSL_TLSEXT_ERR_OK;
547 }
548 #endif // OPENSSL_NPN_NEGOTIATED
549
550 SSL* SSLContext::createSSL() const {
551   SSL* ssl = SSL_new(ctx_);
552   if (ssl == nullptr) {
553     throw std::runtime_error("SSL_new: " + getErrors());
554   }
555   return ssl;
556 }
557
558 /**
559  * Match a name with a pattern. The pattern may include wildcard. A single
560  * wildcard "*" can match up to one component in the domain name.
561  *
562  * @param  host    Host name, typically the name of the remote host
563  * @param  pattern Name retrieved from certificate
564  * @param  size    Size of "pattern"
565  * @return True, if "host" matches "pattern". False otherwise.
566  */
567 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
568   bool match = false;
569   int i = 0, j = 0;
570   while (i < size && host[j] != '\0') {
571     if (toupper(pattern[i]) == toupper(host[j])) {
572       i++;
573       j++;
574       continue;
575     }
576     if (pattern[i] == '*') {
577       while (host[j] != '.' && host[j] != '\0') {
578         j++;
579       }
580       i++;
581       continue;
582     }
583     break;
584   }
585   if (i == size && host[j] == '\0') {
586     match = true;
587   }
588   return match;
589 }
590
591 int SSLContext::passwordCallback(char* password,
592                                  int size,
593                                  int,
594                                  void* data) {
595   SSLContext* context = (SSLContext*)data;
596   if (context == nullptr || context->passwordCollector() == nullptr) {
597     return 0;
598   }
599   std::string userPassword;
600   // call user defined password collector to get password
601   context->passwordCollector()->getPassword(userPassword, size);
602   int length = userPassword.size();
603   if (length > size) {
604     length = size;
605   }
606   strncpy(password, userPassword.c_str(), length);
607   return length;
608 }
609
610 struct SSLLock {
611   explicit SSLLock(
612     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
613       lockType(inLockType) {
614   }
615
616   void lock() {
617     if (lockType == SSLContext::LOCK_MUTEX) {
618       mutex.lock();
619     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
620       spinLock.lock();
621     }
622     // lockType == LOCK_NONE, no-op
623   }
624
625   void unlock() {
626     if (lockType == SSLContext::LOCK_MUTEX) {
627       mutex.unlock();
628     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
629       spinLock.unlock();
630     }
631     // lockType == LOCK_NONE, no-op
632   }
633
634   SSLContext::SSLLockType lockType;
635   folly::SpinLock spinLock{};
636   std::mutex mutex;
637 };
638
639 // Statics are unsafe in environments that call exit().
640 // If one thread calls exit() while another thread is
641 // references a member of SSLContext, bad things can happen.
642 // SSLContext runs in such environments.
643 // Instead of declaring a static member we "new" the static
644 // member so that it won't be destructed on exit().
645 static std::map<int, SSLContext::SSLLockType>* lockTypesInst =
646   new std::map<int, SSLContext::SSLLockType>();
647
648 static std::unique_ptr<SSLLock[]>* locksInst =
649   new std::unique_ptr<SSLLock[]>();
650
651 static std::unique_ptr<SSLLock[]>& locks() {
652   return *locksInst;
653 }
654
655 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
656   return *lockTypesInst;
657 }
658
659 static void callbackLocking(int mode, int n, const char*, int) {
660   if (mode & CRYPTO_LOCK) {
661     locks()[n].lock();
662   } else {
663     locks()[n].unlock();
664   }
665 }
666
667 static unsigned long callbackThreadID() {
668   return static_cast<unsigned long>(
669 #ifdef __APPLE__
670     pthread_mach_thread_np(pthread_self())
671 #else
672     pthread_self()
673 #endif
674   );
675 }
676
677 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
678   return new CRYPTO_dynlock_value;
679 }
680
681 static void dyn_lock(int mode,
682                      struct CRYPTO_dynlock_value* lock,
683                      const char*, int) {
684   if (lock != nullptr) {
685     if (mode & CRYPTO_LOCK) {
686       lock->mutex.lock();
687     } else {
688       lock->mutex.unlock();
689     }
690   }
691 }
692
693 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
694   delete lock;
695 }
696
697 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
698   lockTypes() = inLockTypes;
699 }
700
701 void SSLContext::initializeOpenSSL() {
702   std::lock_guard<std::mutex> g(mutex_);
703   initializeOpenSSLLocked();
704 }
705
706 void SSLContext::initializeOpenSSLLocked() {
707   if (initialized_) {
708     return;
709   }
710   SSL_library_init();
711   SSL_load_error_strings();
712   ERR_load_crypto_strings();
713   // static locking
714   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
715   for (auto it: lockTypes()) {
716     locks()[it.first].lockType = it.second;
717   }
718   CRYPTO_set_id_callback(callbackThreadID);
719   CRYPTO_set_locking_callback(callbackLocking);
720   // dynamic locking
721   CRYPTO_set_dynlock_create_callback(dyn_create);
722   CRYPTO_set_dynlock_lock_callback(dyn_lock);
723   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
724   randomize();
725 #ifdef OPENSSL_NPN_NEGOTIATED
726   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
727       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
728 #endif
729   initialized_ = true;
730 }
731
732 void SSLContext::cleanupOpenSSL() {
733   std::lock_guard<std::mutex> g(mutex_);
734   cleanupOpenSSLLocked();
735 }
736
737 void SSLContext::cleanupOpenSSLLocked() {
738   if (!initialized_) {
739     return;
740   }
741
742   CRYPTO_set_id_callback(nullptr);
743   CRYPTO_set_locking_callback(nullptr);
744   CRYPTO_set_dynlock_create_callback(nullptr);
745   CRYPTO_set_dynlock_lock_callback(nullptr);
746   CRYPTO_set_dynlock_destroy_callback(nullptr);
747   CRYPTO_cleanup_all_ex_data();
748   ERR_free_strings();
749   EVP_cleanup();
750   ERR_remove_state(0);
751   locks().reset();
752   initialized_ = false;
753 }
754
755 void SSLContext::setOptions(long options) {
756   long newOpt = SSL_CTX_set_options(ctx_, options);
757   if ((newOpt & options) != options) {
758     throw std::runtime_error("SSL_CTX_set_options failed");
759   }
760 }
761
762 std::string SSLContext::getErrors(int errnoCopy) {
763   std::string errors;
764   unsigned long  errorCode;
765   char   message[256];
766
767   errors.reserve(512);
768   while ((errorCode = ERR_get_error()) != 0) {
769     if (!errors.empty()) {
770       errors += "; ";
771     }
772     const char* reason = ERR_reason_error_string(errorCode);
773     if (reason == nullptr) {
774       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
775       reason = message;
776     }
777     errors += reason;
778   }
779   if (errors.empty()) {
780     errors = "error code: " + folly::to<std::string>(errnoCopy);
781   }
782   return errors;
783 }
784
785 std::ostream&
786 operator<<(std::ostream& os, const PasswordCollector& collector) {
787   os << collector.describe();
788   return os;
789 }
790
791 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
792                                                   sockaddr_storage* addrStorage,
793                                                   socklen_t* addrLen) {
794   // Grab the ssl idx and then the ssl object so that we can get the peer
795   // name to compare against the ips in the subjectAltName
796   auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
797   auto ssl =
798     reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
799   int fd = SSL_get_fd(ssl);
800   if (fd < 0) {
801     LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
802     return false;
803   }
804
805   *addrLen = sizeof(*addrStorage);
806   if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
807     PLOG(ERROR) << "Unable to get peer name";
808     return false;
809   }
810   CHECK(*addrLen <= sizeof(*addrStorage));
811   return true;
812 }
813
814 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
815                                          const sockaddr* addr,
816                                          socklen_t addrLen) {
817   // Try to extract the names within the SAN extension from the certificate
818   auto altNames =
819     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
820         X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
821   SCOPE_EXIT {
822     if (altNames != nullptr) {
823       sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
824     }
825   };
826   if (altNames == nullptr) {
827     LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
828     return false;
829   }
830
831   const sockaddr_in* addr4 = nullptr;
832   const sockaddr_in6* addr6 = nullptr;
833   if (addr != nullptr) {
834     if (addr->sa_family == AF_INET) {
835       addr4 = reinterpret_cast<const sockaddr_in*>(addr);
836     } else if (addr->sa_family == AF_INET6) {
837       addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
838     } else {
839       LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
840     }
841   }
842
843
844   for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
845     auto name = sk_GENERAL_NAME_value(altNames, i);
846     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
847       // Extra const-ness for paranoia
848       unsigned char const * const rawIpStr = name->d.iPAddress->data;
849       int const rawIpLen = name->d.iPAddress->length;
850
851       if (rawIpLen == 4 && addr4 != nullptr) {
852         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
853           return true;
854         }
855       } else if (rawIpLen == 16 && addr6 != nullptr) {
856         if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
857           return true;
858         }
859       } else if (rawIpLen != 4 && rawIpLen != 16) {
860         LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
861       }
862     }
863   }
864
865   LOG(WARNING) << "Unable to match client cert against alt name ip";
866   return false;
867 }
868
869
870 } // folly