Update SSLContext to use discrete_distribution
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/Format.h>
25 #include <folly/SpinLock.h>
26
27 // ---------------------------------------------------------------------
28 // SSLContext implementation
29 // ---------------------------------------------------------------------
30
31 struct CRYPTO_dynlock_value {
32   std::mutex mutex;
33 };
34
35 namespace folly {
36
37 bool SSLContext::initialized_ = false;
38
39 namespace {
40
41 std::mutex& initMutex() {
42   static std::mutex m;
43   return m;
44 }
45
46 }  // anonymous namespace
47
48 #ifdef OPENSSL_NPN_NEGOTIATED
49 int SSLContext::sNextProtocolsExDataIndex_ = -1;
50 #endif
51
52 // SSLContext implementation
53 SSLContext::SSLContext(SSLVersion version) {
54   {
55     std::lock_guard<std::mutex> g(initMutex());
56     initializeOpenSSLLocked();
57   }
58
59   ctx_ = SSL_CTX_new(SSLv23_method());
60   if (ctx_ == nullptr) {
61     throw std::runtime_error("SSL_CTX_new: " + getErrors());
62   }
63
64   int opt = 0;
65   switch (version) {
66     case TLSv1:
67       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
68       break;
69     case SSLv3:
70       opt = SSL_OP_NO_SSLv2;
71       break;
72     default:
73       // do nothing
74       break;
75   }
76   int newOpt = SSL_CTX_set_options(ctx_, opt);
77   DCHECK((newOpt & opt) == opt);
78
79   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
80
81   checkPeerName_ = false;
82
83 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
84   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
85   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
86 #endif
87
88 #ifdef OPENSSL_NPN_NEGOTIATED
89   Random::seed(nextProtocolPicker_);
90 #endif
91 }
92
93 SSLContext::~SSLContext() {
94   if (ctx_ != nullptr) {
95     SSL_CTX_free(ctx_);
96     ctx_ = nullptr;
97   }
98
99 #ifdef OPENSSL_NPN_NEGOTIATED
100   deleteNextProtocolsStrings();
101 #endif
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   providedCiphersString_ = ciphers;
106   setCiphersOrThrow(ciphers);
107 }
108
109 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
110   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
111   if (ERR_peek_error() != 0) {
112     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
113   }
114   if (rc == 0) {
115     throw std::runtime_error("None of specified ciphers are supported");
116   }
117 }
118
119 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
120     verifyPeer) {
121   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
122   verifyPeer_ = verifyPeer;
123 }
124
125 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
126     verifyPeer) {
127   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
128   int mode = SSL_VERIFY_NONE;
129   switch(verifyPeer) {
130     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
131     // break;
132
133     case SSLVerifyPeerEnum::VERIFY:
134       mode = SSL_VERIFY_PEER;
135       break;
136
137     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
138       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
139       break;
140
141     case SSLVerifyPeerEnum::NO_VERIFY:
142       mode = SSL_VERIFY_NONE;
143       break;
144
145     default:
146       break;
147   }
148   return mode;
149 }
150
151 int SSLContext::getVerificationMode() {
152   return getVerificationMode(verifyPeer_);
153 }
154
155 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
156                               const std::string& peerName) {
157   int mode;
158   if (checkPeerCert) {
159     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
160     checkPeerName_ = checkPeerName;
161     peerFixedName_ = peerName;
162   } else {
163     mode = SSL_VERIFY_NONE;
164     checkPeerName_ = false; // can't check name without cert!
165     peerFixedName_.clear();
166   }
167   SSL_CTX_set_verify(ctx_, mode, nullptr);
168 }
169
170 void SSLContext::loadCertificate(const char* path, const char* format) {
171   if (path == nullptr || format == nullptr) {
172     throw std::invalid_argument(
173          "loadCertificateChain: either <path> or <format> is nullptr");
174   }
175   if (strcmp(format, "PEM") == 0) {
176     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
177       int errnoCopy = errno;
178       std::string reason("SSL_CTX_use_certificate_chain_file: ");
179       reason.append(path);
180       reason.append(": ");
181       reason.append(getErrors(errnoCopy));
182       throw std::runtime_error(reason);
183     }
184   } else {
185     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
186   }
187 }
188
189 void SSLContext::loadPrivateKey(const char* path, const char* format) {
190   if (path == nullptr || format == nullptr) {
191     throw std::invalid_argument(
192          "loadPrivateKey: either <path> or <format> is nullptr");
193   }
194   if (strcmp(format, "PEM") == 0) {
195     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
196       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
197     }
198   } else {
199     throw std::runtime_error("Unsupported private key format: " + std::string(format));
200   }
201 }
202
203 void SSLContext::loadTrustedCertificates(const char* path) {
204   if (path == nullptr) {
205     throw std::invalid_argument(
206          "loadTrustedCertificates: <path> is nullptr");
207   }
208   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
209     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
210   }
211 }
212
213 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
214   SSL_CTX_set_cert_store(ctx_, store);
215 }
216
217 void SSLContext::loadClientCAList(const char* path) {
218   auto clientCAs = SSL_load_client_CA_file(path);
219   if (clientCAs == nullptr) {
220     LOG(ERROR) << "Unable to load ca file: " << path;
221     return;
222   }
223   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
224 }
225
226 void SSLContext::randomize() {
227   RAND_poll();
228 }
229
230 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
231   if (collector == nullptr) {
232     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
233     return;
234   }
235   collector_ = collector;
236   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
237   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
238 }
239
240 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
241
242 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
243   serverNameCb_ = cb;
244 }
245
246 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
247   clientHelloCbs_.push_back(cb);
248 }
249
250 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
251   SSLContext* context = (SSLContext*)data;
252
253   if (context == nullptr) {
254     return SSL_TLSEXT_ERR_NOACK;
255   }
256
257   for (auto& cb : context->clientHelloCbs_) {
258     // Generic callbacks to happen after we receive the Client Hello.
259     // For example, we use one to switch which cipher we use depending
260     // on the user's TLS version.  Because the primary purpose of
261     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
262     // are side-uses, we ignore any possible failures other than just logging
263     // them.
264     cb(ssl);
265   }
266
267   if (!context->serverNameCb_) {
268     return SSL_TLSEXT_ERR_NOACK;
269   }
270
271   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
272   switch (ret) {
273     case SERVER_NAME_FOUND:
274       return SSL_TLSEXT_ERR_OK;
275     case SERVER_NAME_NOT_FOUND:
276       return SSL_TLSEXT_ERR_NOACK;
277     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
278       *al = TLS1_AD_UNRECOGNIZED_NAME;
279       return SSL_TLSEXT_ERR_ALERT_FATAL;
280     default:
281       CHECK(false);
282   }
283
284   return SSL_TLSEXT_ERR_NOACK;
285 }
286
287 void SSLContext::switchCiphersIfTLS11(
288     SSL* ssl,
289     const std::string& tls11CipherString) {
290
291   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
292
293   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
294     // We only do this for TLS v 1.1 and later
295     return;
296   }
297
298   // Prefer AES for TLS versions 1.1 and later since these are not
299   // vulnerable to BEAST attacks on AES.  Note that we're setting the
300   // cipher list on the SSL object, not the SSL_CTX object, so it will
301   // only last for this request.
302   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
303   if ((rc == 0) || ERR_peek_error() != 0) {
304     // This shouldn't happen since we checked for this when proxygen
305     // started up.
306     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
307     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
308   }
309 }
310 #endif
311
312 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
313 int SSLContext::alpnSelectCallback(SSL* ssl,
314                                    const unsigned char** out,
315                                    unsigned char* outlen,
316                                    const unsigned char* in,
317                                    unsigned int inlen,
318                                    void* data) {
319   SSLContext* context = (SSLContext*)data;
320   CHECK(context);
321   if (context->advertisedNextProtocols_.empty()) {
322     *out = nullptr;
323     *outlen = 0;
324   } else {
325     auto i = context->pickNextProtocols();
326     const auto& item = context->advertisedNextProtocols_[i];
327     if (SSL_select_next_proto((unsigned char**)out,
328                               outlen,
329                               item.protocols,
330                               item.length,
331                               in,
332                               inlen) != OPENSSL_NPN_NEGOTIATED) {
333       return SSL_TLSEXT_ERR_NOACK;
334     }
335   }
336   return SSL_TLSEXT_ERR_OK;
337 }
338 #endif
339
340 #ifdef OPENSSL_NPN_NEGOTIATED
341
342 bool SSLContext::setAdvertisedNextProtocols(
343     const std::list<std::string>& protocols, NextProtocolType protocolType) {
344   return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
345 }
346
347 bool SSLContext::setRandomizedAdvertisedNextProtocols(
348     const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
349   unsetNextProtocols();
350   if (items.size() == 0) {
351     return false;
352   }
353   int total_weight = 0;
354   for (const auto &item : items) {
355     if (item.protocols.size() == 0) {
356       continue;
357     }
358     AdvertisedNextProtocolsItem advertised_item;
359     advertised_item.length = 0;
360     for (const auto& proto : item.protocols) {
361       ++advertised_item.length;
362       unsigned protoLength = proto.length();
363       if (protoLength >= 256) {
364         deleteNextProtocolsStrings();
365         return false;
366       }
367       advertised_item.length += protoLength;
368     }
369     advertised_item.protocols = new unsigned char[advertised_item.length];
370     if (!advertised_item.protocols) {
371       throw std::runtime_error("alloc failure");
372     }
373     unsigned char* dst = advertised_item.protocols;
374     for (auto& proto : item.protocols) {
375       unsigned protoLength = proto.length();
376       *dst++ = (unsigned char)protoLength;
377       memcpy(dst, proto.data(), protoLength);
378       dst += protoLength;
379     }
380     total_weight += item.weight;
381     advertisedNextProtocols_.push_back(advertised_item);
382     advertisedNextProtocolWeights_.push_back(item.weight);
383   }
384   if (total_weight == 0) {
385     deleteNextProtocolsStrings();
386     return false;
387   }
388   nextProtocolDistribution_ =
389       std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
390                                    advertisedNextProtocolWeights_.end());
391   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
392     SSL_CTX_set_next_protos_advertised_cb(
393         ctx_, advertisedNextProtocolCallback, this);
394     SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
395   }
396 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
397   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
398     SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
399     // Client cannot really use randomized alpn
400     SSL_CTX_set_alpn_protos(ctx_,
401                             advertisedNextProtocols_[0].protocols,
402                             advertisedNextProtocols_[0].length);
403   }
404 #endif
405   return true;
406 }
407
408 void SSLContext::deleteNextProtocolsStrings() {
409   for (auto protocols : advertisedNextProtocols_) {
410     delete[] protocols.protocols;
411   }
412   advertisedNextProtocols_.clear();
413   advertisedNextProtocolWeights_.clear();
414 }
415
416 void SSLContext::unsetNextProtocols() {
417   deleteNextProtocolsStrings();
418   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
419   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
420 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
421   SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
422   SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
423 #endif
424 }
425
426 size_t SSLContext::pickNextProtocols() {
427   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
428   return nextProtocolDistribution_(nextProtocolPicker_);
429 }
430
431 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
432       const unsigned char** out, unsigned int* outlen, void* data) {
433   SSLContext* context = (SSLContext*)data;
434   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
435     *out = nullptr;
436     *outlen = 0;
437   } else if (context->advertisedNextProtocols_.size() == 1) {
438     *out = context->advertisedNextProtocols_[0].protocols;
439     *outlen = context->advertisedNextProtocols_[0].length;
440   } else {
441     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
442           sNextProtocolsExDataIndex_));
443     if (selected_index) {
444       --selected_index;
445       *out = context->advertisedNextProtocols_[selected_index].protocols;
446       *outlen = context->advertisedNextProtocols_[selected_index].length;
447     } else {
448       auto i = context->pickNextProtocols();
449       uintptr_t selected = i + 1;
450       SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
451       *out = context->advertisedNextProtocols_[i].protocols;
452       *outlen = context->advertisedNextProtocols_[i].length;
453     }
454   }
455   return SSL_TLSEXT_ERR_OK;
456 }
457
458 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
459   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
460 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
461   ciphers_{
462     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
463     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
464     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
465     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
466     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
467     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
468     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
469     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
470     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
471     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
472     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
473     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
474     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
475     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
476     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
477     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
478     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
479     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
480     TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
481     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
482     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
483     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
484     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
485     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
486     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
487     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
488   } {
489   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
490   width_ = sizeof(ciphers_[0]);
491   qsort(ciphers_, length_, width_, compare_ulong);
492 }
493
494 bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
495   const SSL_CIPHER *cipher) {
496   unsigned long cid = cipher->id;
497   unsigned long *r =
498     (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
499   return r != nullptr;
500 }
501
502 int
503 SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
504   if (*(unsigned long *)x < *(unsigned long *)y) {
505     return -1;
506   }
507   if (*(unsigned long *)x > *(unsigned long *)y) {
508     return 1;
509   }
510   return 0;
511 };
512
513 bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
514   return falseStartChecker_.canUseFalseStartWithCipher(cipher);
515 }
516 #endif
517
518 int SSLContext::selectNextProtocolCallback(
519   SSL* ssl, unsigned char **out, unsigned char *outlen,
520   const unsigned char *server, unsigned int server_len, void *data) {
521
522   SSLContext* ctx = (SSLContext*)data;
523   if (ctx->advertisedNextProtocols_.size() > 1) {
524     VLOG(3) << "SSLContext::selectNextProcolCallback() "
525             << "client should be deterministic in selecting protocols.";
526   }
527
528   unsigned char *client;
529   unsigned int client_len;
530   bool filtered = false;
531   auto cpf = ctx->getClientProtocolFilterCallback();
532   if (cpf) {
533     filtered = (*cpf)(&client, &client_len, server, server_len);
534   }
535
536   if (!filtered) {
537     if (ctx->advertisedNextProtocols_.empty()) {
538       client = (unsigned char *) "";
539       client_len = 0;
540     } else {
541       client = ctx->advertisedNextProtocols_[0].protocols;
542       client_len = ctx->advertisedNextProtocols_[0].length;
543     }
544   }
545
546   int retval = SSL_select_next_proto(out, outlen, server, server_len,
547                                      client, client_len);
548   if (retval != OPENSSL_NPN_NEGOTIATED) {
549     VLOG(3) << "SSLContext::selectNextProcolCallback() "
550             << "unable to pick a next protocol.";
551 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
552   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
553   } else {
554     const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
555     if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
556       SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
557     }
558 #endif
559   }
560   return SSL_TLSEXT_ERR_OK;
561 }
562 #endif // OPENSSL_NPN_NEGOTIATED
563
564 SSL* SSLContext::createSSL() const {
565   SSL* ssl = SSL_new(ctx_);
566   if (ssl == nullptr) {
567     throw std::runtime_error("SSL_new: " + getErrors());
568   }
569   return ssl;
570 }
571
572 /**
573  * Match a name with a pattern. The pattern may include wildcard. A single
574  * wildcard "*" can match up to one component in the domain name.
575  *
576  * @param  host    Host name, typically the name of the remote host
577  * @param  pattern Name retrieved from certificate
578  * @param  size    Size of "pattern"
579  * @return True, if "host" matches "pattern". False otherwise.
580  */
581 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
582   bool match = false;
583   int i = 0, j = 0;
584   while (i < size && host[j] != '\0') {
585     if (toupper(pattern[i]) == toupper(host[j])) {
586       i++;
587       j++;
588       continue;
589     }
590     if (pattern[i] == '*') {
591       while (host[j] != '.' && host[j] != '\0') {
592         j++;
593       }
594       i++;
595       continue;
596     }
597     break;
598   }
599   if (i == size && host[j] == '\0') {
600     match = true;
601   }
602   return match;
603 }
604
605 int SSLContext::passwordCallback(char* password,
606                                  int size,
607                                  int,
608                                  void* data) {
609   SSLContext* context = (SSLContext*)data;
610   if (context == nullptr || context->passwordCollector() == nullptr) {
611     return 0;
612   }
613   std::string userPassword;
614   // call user defined password collector to get password
615   context->passwordCollector()->getPassword(userPassword, size);
616   int length = userPassword.size();
617   if (length > size) {
618     length = size;
619   }
620   strncpy(password, userPassword.c_str(), length);
621   return length;
622 }
623
624 struct SSLLock {
625   explicit SSLLock(
626     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
627       lockType(inLockType) {
628   }
629
630   void lock() {
631     if (lockType == SSLContext::LOCK_MUTEX) {
632       mutex.lock();
633     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
634       spinLock.lock();
635     }
636     // lockType == LOCK_NONE, no-op
637   }
638
639   void unlock() {
640     if (lockType == SSLContext::LOCK_MUTEX) {
641       mutex.unlock();
642     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
643       spinLock.unlock();
644     }
645     // lockType == LOCK_NONE, no-op
646   }
647
648   SSLContext::SSLLockType lockType;
649   folly::SpinLock spinLock{};
650   std::mutex mutex;
651 };
652
653 // Statics are unsafe in environments that call exit().
654 // If one thread calls exit() while another thread is
655 // references a member of SSLContext, bad things can happen.
656 // SSLContext runs in such environments.
657 // Instead of declaring a static member we "new" the static
658 // member so that it won't be destructed on exit().
659 static std::unique_ptr<SSLLock[]>& locks() {
660   static auto locksInst = new std::unique_ptr<SSLLock[]>();
661   return *locksInst;
662 }
663
664 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
665   static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
666   return *lockTypesInst;
667 }
668
669 static void callbackLocking(int mode, int n, const char*, int) {
670   if (mode & CRYPTO_LOCK) {
671     locks()[n].lock();
672   } else {
673     locks()[n].unlock();
674   }
675 }
676
677 static unsigned long callbackThreadID() {
678   return static_cast<unsigned long>(
679 #ifdef __APPLE__
680     pthread_mach_thread_np(pthread_self())
681 #else
682     pthread_self()
683 #endif
684   );
685 }
686
687 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
688   return new CRYPTO_dynlock_value;
689 }
690
691 static void dyn_lock(int mode,
692                      struct CRYPTO_dynlock_value* lock,
693                      const char*, int) {
694   if (lock != nullptr) {
695     if (mode & CRYPTO_LOCK) {
696       lock->mutex.lock();
697     } else {
698       lock->mutex.unlock();
699     }
700   }
701 }
702
703 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
704   delete lock;
705 }
706
707 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
708   lockTypes() = inLockTypes;
709 }
710
711 void SSLContext::markInitialized() {
712   std::lock_guard<std::mutex> g(initMutex());
713   initialized_ = true;
714 }
715
716 void SSLContext::initializeOpenSSL() {
717   std::lock_guard<std::mutex> g(initMutex());
718   initializeOpenSSLLocked();
719 }
720
721 void SSLContext::initializeOpenSSLLocked() {
722   if (initialized_) {
723     return;
724   }
725   SSL_library_init();
726   SSL_load_error_strings();
727   ERR_load_crypto_strings();
728   // static locking
729   locks().reset(new SSLLock[::CRYPTO_num_locks()]);
730   for (auto it: lockTypes()) {
731     locks()[it.first].lockType = it.second;
732   }
733   CRYPTO_set_id_callback(callbackThreadID);
734   CRYPTO_set_locking_callback(callbackLocking);
735   // dynamic locking
736   CRYPTO_set_dynlock_create_callback(dyn_create);
737   CRYPTO_set_dynlock_lock_callback(dyn_lock);
738   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
739   randomize();
740 #ifdef OPENSSL_NPN_NEGOTIATED
741   sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
742       (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
743 #endif
744   initialized_ = true;
745 }
746
747 void SSLContext::cleanupOpenSSL() {
748   std::lock_guard<std::mutex> g(initMutex());
749   cleanupOpenSSLLocked();
750 }
751
752 void SSLContext::cleanupOpenSSLLocked() {
753   if (!initialized_) {
754     return;
755   }
756
757   CRYPTO_set_id_callback(nullptr);
758   CRYPTO_set_locking_callback(nullptr);
759   CRYPTO_set_dynlock_create_callback(nullptr);
760   CRYPTO_set_dynlock_lock_callback(nullptr);
761   CRYPTO_set_dynlock_destroy_callback(nullptr);
762   CRYPTO_cleanup_all_ex_data();
763   ERR_free_strings();
764   EVP_cleanup();
765   ERR_remove_state(0);
766   locks().reset();
767   initialized_ = false;
768 }
769
770 void SSLContext::setOptions(long options) {
771   long newOpt = SSL_CTX_set_options(ctx_, options);
772   if ((newOpt & options) != options) {
773     throw std::runtime_error("SSL_CTX_set_options failed");
774   }
775 }
776
777 std::string SSLContext::getErrors(int errnoCopy) {
778   std::string errors;
779   unsigned long  errorCode;
780   char   message[256];
781
782   errors.reserve(512);
783   while ((errorCode = ERR_get_error()) != 0) {
784     if (!errors.empty()) {
785       errors += "; ";
786     }
787     const char* reason = ERR_reason_error_string(errorCode);
788     if (reason == nullptr) {
789       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
790       reason = message;
791     }
792     errors += reason;
793   }
794   if (errors.empty()) {
795     errors = "error code: " + folly::to<std::string>(errnoCopy);
796   }
797   return errors;
798 }
799
800 std::ostream&
801 operator<<(std::ostream& os, const PasswordCollector& collector) {
802   os << collector.describe();
803   return os;
804 }
805
806 bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
807                                                   sockaddr_storage* addrStorage,
808                                                   socklen_t* addrLen) {
809   // Grab the ssl idx and then the ssl object so that we can get the peer
810   // name to compare against the ips in the subjectAltName
811   auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
812   auto ssl =
813     reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
814   int fd = SSL_get_fd(ssl);
815   if (fd < 0) {
816     LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
817     return false;
818   }
819
820   *addrLen = sizeof(*addrStorage);
821   if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
822     PLOG(ERROR) << "Unable to get peer name";
823     return false;
824   }
825   CHECK(*addrLen <= sizeof(*addrStorage));
826   return true;
827 }
828
829 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
830                                          const sockaddr* addr,
831                                          socklen_t addrLen) {
832   // Try to extract the names within the SAN extension from the certificate
833   auto altNames =
834     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
835         X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
836   SCOPE_EXIT {
837     if (altNames != nullptr) {
838       sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
839     }
840   };
841   if (altNames == nullptr) {
842     LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
843     return false;
844   }
845
846   const sockaddr_in* addr4 = nullptr;
847   const sockaddr_in6* addr6 = nullptr;
848   if (addr != nullptr) {
849     if (addr->sa_family == AF_INET) {
850       addr4 = reinterpret_cast<const sockaddr_in*>(addr);
851     } else if (addr->sa_family == AF_INET6) {
852       addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
853     } else {
854       LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
855     }
856   }
857
858
859   for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
860     auto name = sk_GENERAL_NAME_value(altNames, i);
861     if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
862       // Extra const-ness for paranoia
863       unsigned char const * const rawIpStr = name->d.iPAddress->data;
864       int const rawIpLen = name->d.iPAddress->length;
865
866       if (rawIpLen == 4 && addr4 != nullptr) {
867         if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
868           return true;
869         }
870       } else if (rawIpLen == 16 && addr6 != nullptr) {
871         if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
872           return true;
873         }
874       } else if (rawIpLen != 4 && rawIpLen != 16) {
875         LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
876       }
877     }
878   }
879
880   LOG(WARNING) << "Unable to match client cert against alt name ip";
881   return false;
882 }
883
884
885 } // folly