Add static method to skip SSL init
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/SpinLock.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30
31 struct CRYPTO_dynlock_value {
32   std::mutex mutex;
33 };
34
35 namespace folly {
36
37 bool SSLContext::initialized_ = false;
38 std::mutex    SSLContext::mutex_;
39 #ifdef OPENSSL_NPN_NEGOTIATED
40 int SSLContext::sNextProtocolsExDataIndex_ = -1;
41 #endif
42
43 // SSLContext implementation
44 SSLContext::SSLContext(SSLVersion version) {
45   {
46     std::lock_guard<std::mutex> g(mutex_);
47     initializeOpenSSLLocked();
48   }
49
50   ctx_ = SSL_CTX_new(SSLv23_method());
51   if (ctx_ == nullptr) {
52     throw std::runtime_error("SSL_CTX_new: " + getErrors());
53   }
54
55   int opt = 0;
56   switch (version) {
57     case TLSv1:
58       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
59       break;
60     case SSLv3:
61       opt = SSL_OP_NO_SSLv2;
62       break;
63     default:
64       // do nothing
65       break;
66   }
67   int newOpt = SSL_CTX_set_options(ctx_, opt);
68   DCHECK((newOpt & opt) == opt);
69
70   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
71
72   checkPeerName_ = false;
73
74 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
75   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
76   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
77 #endif
78 }
79
80 SSLContext::~SSLContext() {
81   if (ctx_ != nullptr) {
82     SSL_CTX_free(ctx_);
83     ctx_ = nullptr;
84   }
85
86 #ifdef OPENSSL_NPN_NEGOTIATED
87   deleteNextProtocolsStrings();
88 #endif
89 }
90
91 void SSLContext::ciphers(const std::string& ciphers) {
92   providedCiphersString_ = ciphers;
93   setCiphersOrThrow(ciphers);
94 }
95
96 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
97   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
98   if (ERR_peek_error() != 0) {
99     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
100   }
101   if (rc == 0) {
102     throw std::runtime_error("None of specified ciphers are supported");
103   }
104 }
105
106 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
107     verifyPeer) {
108   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
109   verifyPeer_ = verifyPeer;
110 }
111
112 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
113     verifyPeer) {
114   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
115   int mode = SSL_VERIFY_NONE;
116   switch(verifyPeer) {
117     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
118     // break;
119
120     case SSLVerifyPeerEnum::VERIFY:
121       mode = SSL_VERIFY_PEER;
122       break;
123
124     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
125       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
126       break;
127
128     case SSLVerifyPeerEnum::NO_VERIFY:
129       mode = SSL_VERIFY_NONE;
130       break;
131
132     default:
133       break;
134   }
135   return mode;
136 }
137
138 int SSLContext::getVerificationMode() {
139   return getVerificationMode(verifyPeer_);
140 }
141
142 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
143                               const std::string& peerName) {
144   int mode;
145   if (checkPeerCert) {
146     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
147     checkPeerName_ = checkPeerName;
148     peerFixedName_ = peerName;
149   } else {
150     mode = SSL_VERIFY_NONE;
151     checkPeerName_ = false; // can't check name without cert!
152     peerFixedName_.clear();
153   }
154   SSL_CTX_set_verify(ctx_, mode, nullptr);
155 }
156
157 void SSLContext::loadCertificate(const char* path, const char* format) {
158   if (path == nullptr || format == nullptr) {
159     throw std::invalid_argument(
160          "loadCertificateChain: either <path> or <format> is nullptr");
161   }
162   if (strcmp(format, "PEM") == 0) {
163     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
164       int errnoCopy = errno;
165       std::string reason("SSL_CTX_use_certificate_chain_file: ");
166       reason.append(path);
167       reason.append(": ");
168       reason.append(getErrors(errnoCopy));
169       throw std::runtime_error(reason);
170     }
171   } else {
172     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
173   }
174 }
175
176 void SSLContext::loadPrivateKey(const char* path, const char* format) {
177   if (path == nullptr || format == nullptr) {
178     throw std::invalid_argument(
179          "loadPrivateKey: either <path> or <format> is nullptr");
180   }
181   if (strcmp(format, "PEM") == 0) {
182     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
183       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
184     }
185   } else {
186     throw std::runtime_error("Unsupported private key format: " + std::string(format));
187   }
188 }
189
190 void SSLContext::loadTrustedCertificates(const char* path) {
191   if (path == nullptr) {
192     throw std::invalid_argument(
193          "loadTrustedCertificates: <path> is nullptr");
194   }
195   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
196     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
197   }
198 }
199
200 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
201   SSL_CTX_set_cert_store(ctx_, store);
202 }
203
204 void SSLContext::loadClientCAList(const char* path) {
205   auto clientCAs = SSL_load_client_CA_file(path);
206   if (clientCAs == nullptr) {
207     LOG(ERROR) << "Unable to load ca file: " << path;
208     return;
209   }
210   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
211 }
212
213 void SSLContext::randomize() {
214   RAND_poll();
215 }
216
217 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
218   if (collector == nullptr) {
219     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
220     return;
221   }
222   collector_ = collector;
223   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
224   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
225 }
226
227 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
228
229 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
230   serverNameCb_ = cb;
231 }
232
233 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
234   clientHelloCbs_.push_back(cb);
235 }
236
237 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
238   SSLContext* context = (SSLContext*)data;
239
240   if (context == nullptr) {
241     return SSL_TLSEXT_ERR_NOACK;
242   }
243
244   for (auto& cb : context->clientHelloCbs_) {
245     // Generic callbacks to happen after we receive the Client Hello.
246     // For example, we use one to switch which cipher we use depending
247     // on the user's TLS version.  Because the primary purpose of
248     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
249     // are side-uses, we ignore any possible failures other than just logging
250     // them.
251     cb(ssl);
252   }
253
254   if (!context->serverNameCb_) {
255     return SSL_TLSEXT_ERR_NOACK;
256   }
257
258   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
259   switch (ret) {
260     case SERVER_NAME_FOUND:
261       return SSL_TLSEXT_ERR_OK;
262     case SERVER_NAME_NOT_FOUND:
263       return SSL_TLSEXT_ERR_NOACK;
264     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
265       *al = TLS1_AD_UNRECOGNIZED_NAME;
266       return SSL_TLSEXT_ERR_ALERT_FATAL;
267     default:
268       CHECK(false);
269   }
270
271   return SSL_TLSEXT_ERR_NOACK;
272 }
273
274 void SSLContext::switchCiphersIfTLS11(
275     SSL* ssl,
276     const std::string& tls11CipherString) {
277
278   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
279
280   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
281     // We only do this for TLS v 1.1 and later
282     return;
283   }
284
285   // Prefer AES for TLS versions 1.1 and later since these are not
286   // vulnerable to BEAST attacks on AES.  Note that we're setting the
287   // cipher list on the SSL object, not the SSL_CTX object, so it will
288   // only last for this request.
289   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
290   if ((rc == 0) || ERR_peek_error() != 0) {
291     // This shouldn't happen since we checked for this when proxygen
292     // started up.
293     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
294     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
295   }
296 }
297 #endif
298
299 #ifdef OPENSSL_NPN_NEGOTIATED
300 bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
301   return setRandomizedAdvertisedNextProtocols({{1, protocols}});
302 }
303
304 bool SSLContext::setRandomizedAdvertisedNextProtocols(
305     const std::list<NextProtocolsItem>& items) {
306   unsetNextProtocols();
307   if (items.size() == 0) {
308     return false;
309   }
310   int total_weight = 0;
311   for (const auto &item : items) {
312     if (item.protocols.size() == 0) {
313       continue;
314     }
315     AdvertisedNextProtocolsItem advertised_item;
316     advertised_item.length = 0;
317     for (const auto& proto : item.protocols) {
318       ++advertised_item.length;
319       unsigned protoLength = proto.length();
320       if (protoLength >= 256) {
321         deleteNextProtocolsStrings();
322         return false;
323       }
324       advertised_item.length += protoLength;
325     }
326     advertised_item.protocols = new unsigned char[advertised_item.length];
327     if (!advertised_item.protocols) {
328       throw std::runtime_error("alloc failure");
329     }
330     unsigned char* dst = advertised_item.protocols;
331     for (auto& proto : item.protocols) {
332       unsigned protoLength = proto.length();
333       *dst++ = (unsigned char)protoLength;
334       memcpy(dst, proto.data(), protoLength);
335       dst += protoLength;
336     }
337     total_weight += item.weight;
338     advertised_item.probability = item.weight;
339     advertisedNextProtocols_.push_back(advertised_item);
340   }
341   if (total_weight == 0) {
342     deleteNextProtocolsStrings();
343     return false;
344   }
345   for (auto &advertised_item : advertisedNextProtocols_) {
346     advertised_item.probability /= total_weight;
347   }
348   SSL_CTX_set_next_protos_advertised_cb(
349     ctx_, advertisedNextProtocolCallback, this);
350   SSL_CTX_set_next_proto_select_cb(
351     ctx_, selectNextProtocolCallback, this);
352   return true;
353 }
354
355 void SSLContext::deleteNextProtocolsStrings() {
356   for (auto protocols : advertisedNextProtocols_) {
357     delete[] protocols.protocols;
358   }
359   advertisedNextProtocols_.clear();
360 }
361
362 void SSLContext::unsetNextProtocols() {
363   deleteNextProtocolsStrings();
364   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
365   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
366 }
367
368 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
369       const unsigned char** out, unsigned int* outlen, void* data) {
370   SSLContext* context = (SSLContext*)data;
371   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
372     *out = nullptr;
373     *outlen = 0;
374   } else if (context->advertisedNextProtocols_.size() == 1) {
375     *out = context->advertisedNextProtocols_[0].protocols;
376     *outlen = context->advertisedNextProtocols_[0].length;
377   } else {
378     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
379           sNextProtocolsExDataIndex_));
380     if (selected_index) {
381       --selected_index;
382       *out = context->advertisedNextProtocols_[selected_index].protocols;
383       *outlen = context->advertisedNextProtocols_[selected_index].length;
384     } else {
385       unsigned char random_byte;
386       RAND_bytes(&random_byte, 1);
387       double random_value = random_byte / 255.0;
388       double sum = 0;
389       for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
390         sum += context->advertisedNextProtocols_[i].probability;
391         if (sum < random_value &&
392             i + 1 < context->advertisedNextProtocols_.size()) {
393           continue;
394         }
395         uintptr_t selected = i + 1;
396         SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
397         *out = context->advertisedNextProtocols_[i].protocols;
398         *outlen = context->advertisedNextProtocols_[i].length;
399         break;
400       }
401     }
402   }
403   return SSL_TLSEXT_ERR_OK;
404 }
405
406 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
407   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
408 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
409   /**
410    * The list was generated as follows:
411    * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 |
412    * while read A && read B && read C && read D && read E && read F; do
413    * echo $A $B $C $D $E; done |
414    * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\|
415    *         SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES |
416    * awk -F, '{ print $1"," }'
417    */
418   ciphers_{
419     TLS1_CK_DH_DSS_WITH_AES_128_SHA,
420     TLS1_CK_DH_RSA_WITH_AES_128_SHA,
421     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
422     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
423     TLS1_CK_DH_DSS_WITH_AES_256_SHA,
424     TLS1_CK_DH_RSA_WITH_AES_256_SHA,
425     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
426     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
427     TLS1_CK_DH_DSS_WITH_AES_128_SHA256,
428     TLS1_CK_DH_RSA_WITH_AES_128_SHA256,
429     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
430     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
431     TLS1_CK_DH_DSS_WITH_AES_256_SHA256,
432     TLS1_CK_DH_RSA_WITH_AES_256_SHA256,
433     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
434     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
435     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
436     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
437     TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256,
438     TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384,
439     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
440     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
441     TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256,
442     TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384,
443     TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
444     TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
445     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
446     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
447     TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA,
448     TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA,
449     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
450     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
451     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
452     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
453     TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
454     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
455     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
456     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
457     TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256,
458     TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384,
459     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
460     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
461     TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
462     TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
463     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
464     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
465     TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256,
466   } {
467   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
468   width_ = sizeof(ciphers_[0]);
469   qsort(ciphers_, length_, width_, compare_ulong);
470 }
471
472 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
473   const SSL_CIPHER *cipher) {
474   unsigned long cid = cipher->id;
475   unsigned long *r =
476     (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
477   return r != nullptr;
478 }
479
480 int
481 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
482   if (*(unsigned long *)x < *(unsigned long *)y) {
483     return -1;
484   }
485   if (*(unsigned long *)x > *(unsigned long *)y) {
486     return 1;
487   }
488   return 0;
489 };
490
491 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
492   return falseStartChecker_.canUseFalseStartWithCipher(cipher);
493 }
494 #endif
495
496 int SSLContext::selectNextProtocolCallback(
497   SSL* ssl, unsigned char **out, unsigned char *outlen,
498   const unsigned char *server, unsigned int server_len, void *data) {
499
500   SSLContext* ctx = (SSLContext*)data;
501   if (ctx->advertisedNextProtocols_.size() > 1) {
502     VLOG(3) << "SSLContext::selectNextProcolCallback() "
503             << "client should be deterministic in selecting protocols.";
504   }
505
506   unsigned char *client;
507   int client_len;
508   if (ctx->advertisedNextProtocols_.empty()) {
509     client = (unsigned char *) "";
510     client_len = 0;
511   } else {
512     client = ctx->advertisedNextProtocols_[0].protocols;
513     client_len = ctx->advertisedNextProtocols_[0].length;
514   }
515
516   int retval = SSL_select_next_proto(out, outlen, server, server_len,
517                                      client, client_len);
518   if (retval != OPENSSL_NPN_NEGOTIATED) {
519     VLOG(3) << "SSLContext::selectNextProcolCallback() "
520             << "unable to pick a next protocol.";
521 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
522   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
523   } else {
524     const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
525     if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
526       SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
527     }
528 #endif
529   }
530   return SSL_TLSEXT_ERR_OK;
531 }
532 #endif // OPENSSL_NPN_NEGOTIATED
533
534 SSL* SSLContext::createSSL() const {
535   SSL* ssl = SSL_new(ctx_);
536   if (ssl == nullptr) {
537     throw std::runtime_error("SSL_new: " + getErrors());
538   }
539   return ssl;
540 }
541
542 /**
543  * Match a name with a pattern. The pattern may include wildcard. A single
544  * wildcard "*" can match up to one component in the domain name.
545  *
546  * @param  host    Host name, typically the name of the remote host
547  * @param  pattern Name retrieved from certificate
548  * @param  size    Size of "pattern"
549  * @return True, if "host" matches "pattern". False otherwise.
550  */
551 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
552   bool match = false;
553   int i = 0, j = 0;
554   while (i < size && host[j] != '\0') {
555     if (toupper(pattern[i]) == toupper(host[j])) {
556       i++;
557       j++;
558       continue;
559     }
560     if (pattern[i] == '*') {
561       while (host[j] != '.' && host[j] != '\0') {
562         j++;
563       }
564       i++;
565       continue;
566     }
567     break;
568   }
569   if (i == size && host[j] == '\0') {
570     match = true;
571   }
572   return match;
573 }
574
575 int SSLContext::passwordCallback(char* password,
576                                  int size,
577                                  int,
578                                  void* data) {
579   SSLContext* context = (SSLContext*)data;
580   if (context == nullptr || context->passwordCollector() == nullptr) {
581     return 0;
582   }
583   std::string userPassword;
584   // call user defined password collector to get password
585   context->passwordCollector()->getPassword(userPassword, size);
586   int length = userPassword.size();
587   if (length > size) {
588     length = size;
589   }
590   strncpy(password, userPassword.c_str(), length);
591   return length;
592 }
593
594 struct SSLLock {
595   explicit SSLLock(
596     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
597       lockType(inLockType) {
598   }
599
600   void lock() {
601     if (lockType == SSLContext::LOCK_MUTEX) {
602       mutex.lock();
603     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
604       spinLock.lock();
605     }
606     // lockType == LOCK_NONE, no-op
607   }
608
609   void unlock() {
610     if (lockType == SSLContext::LOCK_MUTEX) {
611       mutex.unlock();
612     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
613       spinLock.unlock();
614     }
615     // lockType == LOCK_NONE, no-op
616   }
617
618   SSLContext::SSLLockType lockType;
619   folly::SpinLock spinLock{};
620   std::mutex mutex;
621 };
622
623 // Statics are unsafe in environments that call exit().
624 // If one thread calls exit() while another thread is
625 // references a member of SSLContext, bad things can happen.
626 // SSLContext runs in such environments.
627 // Instead of declaring a static member we "new" the static
628 // member so that it won't be destructed on exit().
629 static std::map<int, SSLContext::SSLLockType>* lockTypesInst =
630   new std::map<int, SSLContext::SSLLockType>();
631
632 static std::unique_ptr<SSLLock[]>* locksInst =
633   new std::unique_ptr<SSLLock[]>();
634
635 static std::unique_ptr<SSLLock[]>& locks() {
636   return *locksInst;
637 }
638
639 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
640   return *lockTypesInst;
641 }
642
643 static void callbackLocking(int mode, int n, const char*, int) {
644   if (mode & CRYPTO_LOCK) {
645     locks()[n].lock();
646   } else {
647     locks()[n].unlock();
648   }
649 }
650
651 static unsigned long callbackThreadID() {
652   return static_cast<unsigned long>(
653 #ifdef __APPLE__
654     pthread_mach_thread_np(pthread_self())
655 #else
656     pthread_self()
657 #endif
658   );
659 }
660
661 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
662   return new CRYPTO_dynlock_value;
663 }
664
665 static void dyn_lock(int mode,
666                      struct CRYPTO_dynlock_value* lock,
667                      const char*, int) {
668   if (lock != nullptr) {
669     if (mode & CRYPTO_LOCK) {
670       lock->mutex.lock();
671     } else {
672       lock->mutex.unlock();
673     }
674   }
675 }
676
677 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
678   delete lock;
679 }
680
681 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
682   lockTypes() = inLockTypes;
683 }
684
685 void SSLContext::markInitialized() {
686   std::lock_guard<std::mutex> g(mutex_);
687   initialized_ = true;
688 }
689
690 void SSLContext::initializeOpenSSL() {
691   std::lock_guard<std::mutex> g(mutex_);
692   initializeOpenSSLLocked();
693 }
694
695 void SSLContext::initializeOpenSSLLocked() {
696   if (initialized_) {
697     return;
698   }
699   SSL_library_init();
700   SSL_load_error_strings();
701   ERR_load_crypto_strings();
702   // static locking
703   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
704   for (auto it: lockTypes()) {
705     locks()[it.first].lockType = it.second;
706   }
707   CRYPTO_set_id_callback(callbackThreadID);
708   CRYPTO_set_locking_callback(callbackLocking);
709   // dynamic locking
710   CRYPTO_set_dynlock_create_callback(dyn_create);
711   CRYPTO_set_dynlock_lock_callback(dyn_lock);
712   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
713   randomize();
714 #ifdef OPENSSL_NPN_NEGOTIATED
715   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
716       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
717 #endif
718   initialized_ = true;
719 }
720
721 void SSLContext::cleanupOpenSSL() {
722   std::lock_guard<std::mutex> g(mutex_);
723   cleanupOpenSSLLocked();
724 }
725
726 void SSLContext::cleanupOpenSSLLocked() {
727   if (!initialized_) {
728     return;
729   }
730
731   CRYPTO_set_id_callback(nullptr);
732   CRYPTO_set_locking_callback(nullptr);
733   CRYPTO_set_dynlock_create_callback(nullptr);
734   CRYPTO_set_dynlock_lock_callback(nullptr);
735   CRYPTO_set_dynlock_destroy_callback(nullptr);
736   CRYPTO_cleanup_all_ex_data();
737   ERR_free_strings();
738   EVP_cleanup();
739   ERR_remove_state(0);
740   locks().reset();
741   initialized_ = false;
742 }
743
744 void SSLContext::setOptions(long options) {
745   long newOpt = SSL_CTX_set_options(ctx_, options);
746   if ((newOpt & options) != options) {
747     throw std::runtime_error("SSL_CTX_set_options failed");
748   }
749 }
750
751 std::string SSLContext::getErrors(int errnoCopy) {
752   std::string errors;
753   unsigned long  errorCode;
754   char   message[256];
755
756   errors.reserve(512);
757   while ((errorCode = ERR_get_error()) != 0) {
758     if (!errors.empty()) {
759       errors += "; ";
760     }
761     const char* reason = ERR_reason_error_string(errorCode);
762     if (reason == nullptr) {
763       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
764       reason = message;
765     }
766     errors += reason;
767   }
768   if (errors.empty()) {
769     errors = "error code: " + folly::to<std::string>(errnoCopy);
770   }
771   return errors;
772 }
773
774 std::ostream&
775 operator<<(std::ostream& os, const PasswordCollector& collector) {
776   os << collector.describe();
777   return os;
778 }
779
780 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
781                                                   sockaddr_storage* addrStorage,
782                                                   socklen_t* addrLen) {
783   // Grab the ssl idx and then the ssl object so that we can get the peer
784   // name to compare against the ips in the subjectAltName
785   auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
786   auto ssl =
787     reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
788   int fd = SSL_get_fd(ssl);
789   if (fd < 0) {
790     LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
791     return false;
792   }
793
794   *addrLen = sizeof(*addrStorage);
795   if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
796     PLOG(ERROR) << "Unable to get peer name";
797     return false;
798   }
799   CHECK(*addrLen <= sizeof(*addrStorage));
800   return true;
801 }
802
803 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
804                                          const sockaddr* addr,
805                                          socklen_t addrLen) {
806   // Try to extract the names within the SAN extension from the certificate
807   auto altNames =
808     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
809         X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
810   SCOPE_EXIT {
811     if (altNames != nullptr) {
812       sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
813     }
814   };
815   if (altNames == nullptr) {
816     LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
817     return false;
818   }
819
820   const sockaddr_in* addr4 = nullptr;
821   const sockaddr_in6* addr6 = nullptr;
822   if (addr != nullptr) {
823     if (addr->sa_family == AF_INET) {
824       addr4 = reinterpret_cast<const sockaddr_in*>(addr);
825     } else if (addr->sa_family == AF_INET6) {
826       addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
827     } else {
828       LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
829     }
830   }
831
832
833   for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
834     auto name = sk_GENERAL_NAME_value(altNames, i);
835     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
836       // Extra const-ness for paranoia
837       unsigned char const * const rawIpStr = name->d.iPAddress->data;
838       int const rawIpLen = name->d.iPAddress->length;
839
840       if (rawIpLen == 4 && addr4 != nullptr) {
841         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
842           return true;
843         }
844       } else if (rawIpLen == 16 && addr6 != nullptr) {
845         if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
846           return true;
847         }
848       } else if (rawIpLen != 4 && rawIpLen != 16) {
849         LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
850       }
851     }
852   }
853
854   LOG(WARNING) << "Unable to match client cert against alt name ip";
855   return false;
856 }
857
858
859 } // folly