SSLContext
[folly.git] / folly / io / async / SSLContext.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "SSLContext.h"
18
19 #include <openssl/err.h>
20 #include <openssl/rand.h>
21 #include <openssl/ssl.h>
22 #include <openssl/x509v3.h>
23
24 #include <folly/SmallLocks.h>
25 #include <folly/Format.h>
26 #include <folly/io/PortableSpinLock.h>
27
28 // ---------------------------------------------------------------------
29 // SSLContext implementation
30 // ---------------------------------------------------------------------
31
32 struct CRYPTO_dynlock_value {
33   std::mutex mutex;
34 };
35
36 namespace folly {
37
38 uint64_t SSLContext::count_ = 0;
39 std::mutex    SSLContext::mutex_;
40 #ifdef OPENSSL_NPN_NEGOTIATED
41 int SSLContext::sNextProtocolsExDataIndex_ = -1;
42
43 #endif
44 // SSLContext implementation
45 SSLContext::SSLContext(SSLVersion version) {
46   {
47     std::lock_guard<std::mutex> g(mutex_);
48     if (!count_++) {
49       initializeOpenSSL();
50       randomize();
51 #ifdef OPENSSL_NPN_NEGOTIATED
52       sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
53           (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
54 #endif
55     }
56   }
57
58   ctx_ = SSL_CTX_new(SSLv23_method());
59   if (ctx_ == nullptr) {
60     throw std::runtime_error("SSL_CTX_new: " + getErrors());
61   }
62
63   int opt = 0;
64   switch (version) {
65     case TLSv1:
66       opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3;
67       break;
68     case SSLv3:
69       opt = SSL_OP_NO_SSLv2;
70       break;
71     default:
72       // do nothing
73       break;
74   }
75   int newOpt = SSL_CTX_set_options(ctx_, opt);
76   DCHECK((newOpt & opt) == opt);
77
78   SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
79
80   checkPeerName_ = false;
81
82 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
83   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
84   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
85 #endif
86 }
87
88 SSLContext::~SSLContext() {
89   if (ctx_ != nullptr) {
90     SSL_CTX_free(ctx_);
91     ctx_ = nullptr;
92   }
93
94 #ifdef OPENSSL_NPN_NEGOTIATED
95   deleteNextProtocolsStrings();
96 #endif
97
98   std::lock_guard<std::mutex> g(mutex_);
99   if (!--count_) {
100     cleanupOpenSSL();
101   }
102 }
103
104 void SSLContext::ciphers(const std::string& ciphers) {
105   providedCiphersString_ = ciphers;
106   setCiphersOrThrow(ciphers);
107 }
108
109 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
110   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
111   if (ERR_peek_error() != 0) {
112     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
113   }
114   if (rc == 0) {
115     throw std::runtime_error("None of specified ciphers are supported");
116   }
117 }
118
119 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
120     verifyPeer) {
121   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX); // dont recurse
122   verifyPeer_ = verifyPeer;
123 }
124
125 int SSLContext::getVerificationMode(const SSLContext::SSLVerifyPeerEnum&
126     verifyPeer) {
127   CHECK(verifyPeer != SSLVerifyPeerEnum::USE_CTX);
128   int mode = SSL_VERIFY_NONE;
129   switch(verifyPeer) {
130     // case SSLVerifyPeerEnum::USE_CTX: // can't happen
131     // break;
132
133     case SSLVerifyPeerEnum::VERIFY:
134       mode = SSL_VERIFY_PEER;
135       break;
136
137     case SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT:
138       mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
139       break;
140
141     case SSLVerifyPeerEnum::NO_VERIFY:
142       mode = SSL_VERIFY_NONE;
143       break;
144
145     default:
146       break;
147   }
148   return mode;
149 }
150
151 int SSLContext::getVerificationMode() {
152   return getVerificationMode(verifyPeer_);
153 }
154
155 void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
156                               const std::string& peerName) {
157   int mode;
158   if (checkPeerCert) {
159     mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
160     checkPeerName_ = checkPeerName;
161     peerFixedName_ = peerName;
162   } else {
163     mode = SSL_VERIFY_NONE;
164     checkPeerName_ = false; // can't check name without cert!
165     peerFixedName_.clear();
166   }
167   SSL_CTX_set_verify(ctx_, mode, nullptr);
168 }
169
170 void SSLContext::loadCertificate(const char* path, const char* format) {
171   if (path == nullptr || format == nullptr) {
172     throw std::invalid_argument(
173          "loadCertificateChain: either <path> or <format> is nullptr");
174   }
175   if (strcmp(format, "PEM") == 0) {
176     if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
177       int errnoCopy = errno;
178       std::string reason("SSL_CTX_use_certificate_chain_file: ");
179       reason.append(path);
180       reason.append(": ");
181       reason.append(getErrors(errnoCopy));
182       throw std::runtime_error(reason);
183     }
184   } else {
185     throw std::runtime_error("Unsupported certificate format: " + std::string(format));
186   }
187 }
188
189 void SSLContext::loadPrivateKey(const char* path, const char* format) {
190   if (path == nullptr || format == nullptr) {
191     throw std::invalid_argument(
192          "loadPrivateKey: either <path> or <format> is nullptr");
193   }
194   if (strcmp(format, "PEM") == 0) {
195     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
196       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
197     }
198   } else {
199     throw std::runtime_error("Unsupported private key format: " + std::string(format));
200   }
201 }
202
203 void SSLContext::loadTrustedCertificates(const char* path) {
204   if (path == nullptr) {
205     throw std::invalid_argument(
206          "loadTrustedCertificates: <path> is nullptr");
207   }
208   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
209     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
210   }
211 }
212
213 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
214   SSL_CTX_set_cert_store(ctx_, store);
215 }
216
217 void SSLContext::loadClientCAList(const char* path) {
218   auto clientCAs = SSL_load_client_CA_file(path);
219   if (clientCAs == nullptr) {
220     LOG(ERROR) << "Unable to load ca file: " << path;
221     return;
222   }
223   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
224 }
225
226 void SSLContext::randomize() {
227   RAND_poll();
228 }
229
230 void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
231   if (collector == nullptr) {
232     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
233     return;
234   }
235   collector_ = collector;
236   SSL_CTX_set_default_passwd_cb(ctx_, passwordCallback);
237   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
238 }
239
240 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
241
242 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
243   serverNameCb_ = cb;
244 }
245
246 void SSLContext::addClientHelloCallback(const ClientHelloCallback& cb) {
247   clientHelloCbs_.push_back(cb);
248 }
249
250 int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
251   SSLContext* context = (SSLContext*)data;
252
253   if (context == nullptr) {
254     return SSL_TLSEXT_ERR_NOACK;
255   }
256
257   for (auto& cb : context->clientHelloCbs_) {
258     // Generic callbacks to happen after we receive the Client Hello.
259     // For example, we use one to switch which cipher we use depending
260     // on the user's TLS version.  Because the primary purpose of
261     // baseServerNameOpenSSLCallback is for SNI support, and these callbacks
262     // are side-uses, we ignore any possible failures other than just logging
263     // them.
264     cb(ssl);
265   }
266
267   if (!context->serverNameCb_) {
268     return SSL_TLSEXT_ERR_NOACK;
269   }
270
271   ServerNameCallbackResult ret = context->serverNameCb_(ssl);
272   switch (ret) {
273     case SERVER_NAME_FOUND:
274       return SSL_TLSEXT_ERR_OK;
275     case SERVER_NAME_NOT_FOUND:
276       return SSL_TLSEXT_ERR_NOACK;
277     case SERVER_NAME_NOT_FOUND_ALERT_FATAL:
278       *al = TLS1_AD_UNRECOGNIZED_NAME;
279       return SSL_TLSEXT_ERR_ALERT_FATAL;
280     default:
281       CHECK(false);
282   }
283
284   return SSL_TLSEXT_ERR_NOACK;
285 }
286
287 void SSLContext::switchCiphersIfTLS11(
288     SSL* ssl,
289     const std::string& tls11CipherString) {
290
291   CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
292
293   if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
294     // We only do this for TLS v 1.1 and later
295     return;
296   }
297
298   // Prefer AES for TLS versions 1.1 and later since these are not
299   // vulnerable to BEAST attacks on AES.  Note that we're setting the
300   // cipher list on the SSL object, not the SSL_CTX object, so it will
301   // only last for this request.
302   int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
303   if ((rc == 0) || ERR_peek_error() != 0) {
304     // This shouldn't happen since we checked for this when proxygen
305     // started up.
306     LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
307     SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
308   }
309 }
310 #endif
311
312 #ifdef OPENSSL_NPN_NEGOTIATED
313 bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
314   return setRandomizedAdvertisedNextProtocols({{1, protocols}});
315 }
316
317 bool SSLContext::setRandomizedAdvertisedNextProtocols(
318     const std::list<NextProtocolsItem>& items) {
319   unsetNextProtocols();
320   if (items.size() == 0) {
321     return false;
322   }
323   int total_weight = 0;
324   for (const auto &item : items) {
325     if (item.protocols.size() == 0) {
326       continue;
327     }
328     AdvertisedNextProtocolsItem advertised_item;
329     advertised_item.length = 0;
330     for (const auto& proto : item.protocols) {
331       ++advertised_item.length;
332       unsigned protoLength = proto.length();
333       if (protoLength >= 256) {
334         deleteNextProtocolsStrings();
335         return false;
336       }
337       advertised_item.length += protoLength;
338     }
339     advertised_item.protocols = new unsigned char[advertised_item.length];
340     if (!advertised_item.protocols) {
341       throw std::runtime_error("alloc failure");
342     }
343     unsigned char* dst = advertised_item.protocols;
344     for (auto& proto : item.protocols) {
345       unsigned protoLength = proto.length();
346       *dst++ = (unsigned char)protoLength;
347       memcpy(dst, proto.data(), protoLength);
348       dst += protoLength;
349     }
350     total_weight += item.weight;
351     advertised_item.probability = item.weight;
352     advertisedNextProtocols_.push_back(advertised_item);
353   }
354   if (total_weight == 0) {
355     deleteNextProtocolsStrings();
356     return false;
357   }
358   for (auto &advertised_item : advertisedNextProtocols_) {
359     advertised_item.probability /= total_weight;
360   }
361   SSL_CTX_set_next_protos_advertised_cb(
362     ctx_, advertisedNextProtocolCallback, this);
363   SSL_CTX_set_next_proto_select_cb(
364     ctx_, selectNextProtocolCallback, this);
365   return true;
366 }
367
368 void SSLContext::deleteNextProtocolsStrings() {
369   for (auto protocols : advertisedNextProtocols_) {
370     delete[] protocols.protocols;
371   }
372   advertisedNextProtocols_.clear();
373 }
374
375 void SSLContext::unsetNextProtocols() {
376   deleteNextProtocolsStrings();
377   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
378   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
379 }
380
381 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
382       const unsigned char** out, unsigned int* outlen, void* data) {
383   SSLContext* context = (SSLContext*)data;
384   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
385     *out = nullptr;
386     *outlen = 0;
387   } else if (context->advertisedNextProtocols_.size() == 1) {
388     *out = context->advertisedNextProtocols_[0].protocols;
389     *outlen = context->advertisedNextProtocols_[0].length;
390   } else {
391     uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
392           sNextProtocolsExDataIndex_));
393     if (selected_index) {
394       --selected_index;
395       *out = context->advertisedNextProtocols_[selected_index].protocols;
396       *outlen = context->advertisedNextProtocols_[selected_index].length;
397     } else {
398       unsigned char random_byte;
399       RAND_bytes(&random_byte, 1);
400       double random_value = random_byte / 255.0;
401       double sum = 0;
402       for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
403         sum += context->advertisedNextProtocols_[i].probability;
404         if (sum < random_value &&
405             i + 1 < context->advertisedNextProtocols_.size()) {
406           continue;
407         }
408         uintptr_t selected = i + 1;
409         SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
410         *out = context->advertisedNextProtocols_[i].protocols;
411         *outlen = context->advertisedNextProtocols_[i].length;
412         break;
413       }
414     }
415   }
416   return SSL_TLSEXT_ERR_OK;
417 }
418
419 int SSLContext::selectNextProtocolCallback(
420   SSL* ssl, unsigned char **out, unsigned char *outlen,
421   const unsigned char *server, unsigned int server_len, void *data) {
422
423   SSLContext* ctx = (SSLContext*)data;
424   if (ctx->advertisedNextProtocols_.size() > 1) {
425     VLOG(3) << "SSLContext::selectNextProcolCallback() "
426             << "client should be deterministic in selecting protocols.";
427   }
428
429   unsigned char *client;
430   int client_len;
431   if (ctx->advertisedNextProtocols_.empty()) {
432     client = (unsigned char *) "";
433     client_len = 0;
434   } else {
435     client = ctx->advertisedNextProtocols_[0].protocols;
436     client_len = ctx->advertisedNextProtocols_[0].length;
437   }
438
439   int retval = SSL_select_next_proto(out, outlen, server, server_len,
440                                      client, client_len);
441   if (retval != OPENSSL_NPN_NEGOTIATED) {
442     VLOG(3) << "SSLContext::selectNextProcolCallback() "
443             << "unable to pick a next protocol.";
444   }
445   return SSL_TLSEXT_ERR_OK;
446 }
447 #endif // OPENSSL_NPN_NEGOTIATED
448
449 SSL* SSLContext::createSSL() const {
450   SSL* ssl = SSL_new(ctx_);
451   if (ssl == nullptr) {
452     throw std::runtime_error("SSL_new: " + getErrors());
453   }
454   return ssl;
455 }
456
457 /**
458  * Match a name with a pattern. The pattern may include wildcard. A single
459  * wildcard "*" can match up to one component in the domain name.
460  *
461  * @param  host    Host name, typically the name of the remote host
462  * @param  pattern Name retrieved from certificate
463  * @param  size    Size of "pattern"
464  * @return True, if "host" matches "pattern". False otherwise.
465  */
466 bool SSLContext::matchName(const char* host, const char* pattern, int size) {
467   bool match = false;
468   int i = 0, j = 0;
469   while (i < size && host[j] != '\0') {
470     if (toupper(pattern[i]) == toupper(host[j])) {
471       i++;
472       j++;
473       continue;
474     }
475     if (pattern[i] == '*') {
476       while (host[j] != '.' && host[j] != '\0') {
477         j++;
478       }
479       i++;
480       continue;
481     }
482     break;
483   }
484   if (i == size && host[j] == '\0') {
485     match = true;
486   }
487   return match;
488 }
489
490 int SSLContext::passwordCallback(char* password,
491                                  int size,
492                                  int,
493                                  void* data) {
494   SSLContext* context = (SSLContext*)data;
495   if (context == nullptr || context->passwordCollector() == nullptr) {
496     return 0;
497   }
498   std::string userPassword;
499   // call user defined password collector to get password
500   context->passwordCollector()->getPassword(userPassword, size);
501   int length = userPassword.size();
502   if (length > size) {
503     length = size;
504   }
505   strncpy(password, userPassword.c_str(), length);
506   return length;
507 }
508
509 struct SSLLock {
510   explicit SSLLock(
511     SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
512       lockType(inLockType) {
513   }
514
515   void lock() {
516     if (lockType == SSLContext::LOCK_MUTEX) {
517       mutex.lock();
518     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
519       spinLock.lock();
520     }
521     // lockType == LOCK_NONE, no-op
522   }
523
524   void unlock() {
525     if (lockType == SSLContext::LOCK_MUTEX) {
526       mutex.unlock();
527     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
528       spinLock.unlock();
529     }
530     // lockType == LOCK_NONE, no-op
531   }
532
533   SSLContext::SSLLockType lockType;
534   folly::io::PortableSpinLock spinLock{};
535   std::mutex mutex;
536 };
537
538 static std::map<int, SSLContext::SSLLockType> lockTypes;
539 static std::unique_ptr<SSLLock[]> locks;
540
541 static void callbackLocking(int mode, int n, const char*, int) {
542   if (mode & CRYPTO_LOCK) {
543     locks[n].lock();
544   } else {
545     locks[n].unlock();
546   }
547 }
548
549 static unsigned long callbackThreadID() {
550   return static_cast<unsigned long>(pthread_self());
551 }
552
553 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
554   return new CRYPTO_dynlock_value;
555 }
556
557 static void dyn_lock(int mode,
558                      struct CRYPTO_dynlock_value* lock,
559                      const char*, int) {
560   if (lock != nullptr) {
561     if (mode & CRYPTO_LOCK) {
562       lock->mutex.lock();
563     } else {
564       lock->mutex.unlock();
565     }
566   }
567 }
568
569 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
570   delete lock;
571 }
572
573 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
574   lockTypes = inLockTypes;
575 }
576
577 void SSLContext::initializeOpenSSL() {
578   SSL_library_init();
579   SSL_load_error_strings();
580   ERR_load_crypto_strings();
581   // static locking
582   locks.reset(new SSLLock[::CRYPTO_num_locks()]);
583   for (auto it: lockTypes) {
584     locks[it.first].lockType = it.second;
585   }
586   CRYPTO_set_id_callback(callbackThreadID);
587   CRYPTO_set_locking_callback(callbackLocking);
588   // dynamic locking
589   CRYPTO_set_dynlock_create_callback(dyn_create);
590   CRYPTO_set_dynlock_lock_callback(dyn_lock);
591   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
592 }
593
594 void SSLContext::cleanupOpenSSL() {
595   CRYPTO_set_id_callback(nullptr);
596   CRYPTO_set_locking_callback(nullptr);
597   CRYPTO_set_dynlock_create_callback(nullptr);
598   CRYPTO_set_dynlock_lock_callback(nullptr);
599   CRYPTO_set_dynlock_destroy_callback(nullptr);
600   CRYPTO_cleanup_all_ex_data();
601   ERR_free_strings();
602   EVP_cleanup();
603   ERR_remove_state(0);
604   locks.reset();
605 }
606
607 void SSLContext::setOptions(long options) {
608   long newOpt = SSL_CTX_set_options(ctx_, options);
609   if ((newOpt & options) != options) {
610     throw std::runtime_error("SSL_CTX_set_options failed");
611   }
612 }
613
614 std::string SSLContext::getErrors(int errnoCopy) {
615   std::string errors;
616   unsigned long  errorCode;
617   char   message[256];
618
619   errors.reserve(512);
620   while ((errorCode = ERR_get_error()) != 0) {
621     if (!errors.empty()) {
622       errors += "; ";
623     }
624     const char* reason = ERR_reason_error_string(errorCode);
625     if (reason == nullptr) {
626       snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
627       reason = message;
628     }
629     errors += reason;
630   }
631   if (errors.empty()) {
632     errors = "error code: " + folly::to<std::string>(errnoCopy);
633   }
634   return errors;
635 }
636
637 std::ostream&
638 operator<<(std::ostream& os, const PasswordCollector& collector) {
639   os << collector.describe();
640   return os;
641 }
642
643 } // folly