folly: build with -Wunused-parameter
[folly.git] / folly / io / async / SSLContext.cpp
index c05fe330f9a0a95c8f264950cb712e987c8f371e..788491c0ec1369de2684c93f5b77d2b2ab36d01c 100644 (file)
@@ -22,7 +22,9 @@
 #include <openssl/x509v3.h>
 
 #include <folly/Format.h>
+#include <folly/Memory.h>
 #include <folly/SpinLock.h>
+#include <folly/io/async/OpenSSLPtrTypes.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
@@ -35,7 +37,19 @@ struct CRYPTO_dynlock_value {
 namespace folly {
 
 bool SSLContext::initialized_ = false;
-std::mutex    SSLContext::mutex_;
+
+namespace {
+
+std::mutex& initMutex() {
+  static std::mutex m;
+  return m;
+}
+
+inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
+using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
+
+} // anonymous namespace
+
 #ifdef OPENSSL_NPN_NEGOTIATED
 int SSLContext::sNextProtocolsExDataIndex_ = -1;
 #endif
@@ -43,7 +57,7 @@ int SSLContext::sNextProtocolsExDataIndex_ = -1;
 // SSLContext implementation
 SSLContext::SSLContext(SSLVersion version) {
   {
-    std::lock_guard<std::mutex> g(mutex_);
+    std::lock_guard<std::mutex> g(initMutex());
     initializeOpenSSLLocked();
   }
 
@@ -75,6 +89,10 @@ SSLContext::SSLContext(SSLVersion version) {
   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
 #endif
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+  Random::seed(nextProtocolPicker_);
+#endif
 }
 
 SSLContext::~SSLContext() {
@@ -173,10 +191,35 @@ void SSLContext::loadCertificate(const char* path, const char* format) {
   }
 }
 
+void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
+  if (cert.data() == nullptr) {
+    throw std::invalid_argument("loadCertificate: <cert> is nullptr");
+  }
+
+  std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
+  if (bio == nullptr) {
+    throw std::runtime_error("BIO_new: " + getErrors());
+  }
+
+  int written = BIO_write(bio.get(), cert.data(), cert.size());
+  if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
+    throw std::runtime_error("BIO_write: " + getErrors());
+  }
+
+  X509_UniquePtr x509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+  if (x509 == nullptr) {
+    throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
+  }
+
+  if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
+    throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
+  }
+}
+
 void SSLContext::loadPrivateKey(const char* path, const char* format) {
   if (path == nullptr || format == nullptr) {
     throw std::invalid_argument(
-         "loadPrivateKey: either <path> or <format> is nullptr");
+        "loadPrivateKey: either <path> or <format> is nullptr");
   }
   if (strcmp(format, "PEM") == 0) {
     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
@@ -187,10 +230,35 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) {
   }
 }
 
+void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
+  if (pkey.data() == nullptr) {
+    throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
+  }
+
+  std::unique_ptr<BIO, BIO_deleter> bio(BIO_new(BIO_s_mem()));
+  if (bio == nullptr) {
+    throw std::runtime_error("BIO_new: " + getErrors());
+  }
+
+  int written = BIO_write(bio.get(), pkey.data(), pkey.size());
+  if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
+    throw std::runtime_error("BIO_write: " + getErrors());
+  }
+
+  EVP_PKEY_UniquePtr key(
+      PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
+  if (key == nullptr) {
+    throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
+  }
+
+  if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
+    throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
+  }
+}
+
 void SSLContext::loadTrustedCertificates(const char* path) {
   if (path == nullptr) {
-    throw std::invalid_argument(
-         "loadTrustedCertificates: <path> is nullptr");
+    throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
   }
   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
@@ -296,13 +364,43 @@ void SSLContext::switchCiphersIfTLS11(
 }
 #endif
 
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+int SSLContext::alpnSelectCallback(SSL* /* ssl */,
+                                   const unsigned char** out,
+                                   unsigned char* outlen,
+                                   const unsigned char* in,
+                                   unsigned int inlen,
+                                   void* data) {
+  SSLContext* context = (SSLContext*)data;
+  CHECK(context);
+  if (context->advertisedNextProtocols_.empty()) {
+    *out = nullptr;
+    *outlen = 0;
+  } else {
+    auto i = context->pickNextProtocols();
+    const auto& item = context->advertisedNextProtocols_[i];
+    if (SSL_select_next_proto((unsigned char**)out,
+                              outlen,
+                              item.protocols,
+                              item.length,
+                              in,
+                              inlen) != OPENSSL_NPN_NEGOTIATED) {
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+  }
+  return SSL_TLSEXT_ERR_OK;
+}
+#endif
+
 #ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
-  return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+    const std::list<std::string>& protocols, NextProtocolType protocolType) {
+  return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
 }
 
 bool SSLContext::setRandomizedAdvertisedNextProtocols(
-    const std::list<NextProtocolsItem>& items) {
+    const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
   unsetNextProtocols();
   if (items.size() == 0) {
     return false;
@@ -335,20 +433,30 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
       dst += protoLength;
     }
     total_weight += item.weight;
-    advertised_item.probability = item.weight;
     advertisedNextProtocols_.push_back(advertised_item);
+    advertisedNextProtocolWeights_.push_back(item.weight);
   }
   if (total_weight == 0) {
     deleteNextProtocolsStrings();
     return false;
   }
-  for (auto &advertised_item : advertisedNextProtocols_) {
-    advertised_item.probability /= total_weight;
+  nextProtocolDistribution_ =
+      std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
+                                   advertisedNextProtocolWeights_.end());
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+    SSL_CTX_set_next_protos_advertised_cb(
+        ctx_, advertisedNextProtocolCallback, this);
+    SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+  }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+    SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+    // Client cannot really use randomized alpn
+    SSL_CTX_set_alpn_protos(ctx_,
+                            advertisedNextProtocols_[0].protocols,
+                            advertisedNextProtocols_[0].length);
   }
-  SSL_CTX_set_next_protos_advertised_cb(
-    ctx_, advertisedNextProtocolCallback, this);
-  SSL_CTX_set_next_proto_select_cb(
-    ctx_, selectNextProtocolCallback, this);
+#endif
   return true;
 }
 
@@ -357,12 +465,22 @@ void SSLContext::deleteNextProtocolsStrings() {
     delete[] protocols.protocols;
   }
   advertisedNextProtocols_.clear();
+  advertisedNextProtocolWeights_.clear();
 }
 
 void SSLContext::unsetNextProtocols() {
   deleteNextProtocolsStrings();
   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+  SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+  CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
+  return nextProtocolDistribution_(nextProtocolPicker_);
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
@@ -382,22 +500,11 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       *out = context->advertisedNextProtocols_[selected_index].protocols;
       *outlen = context->advertisedNextProtocols_[selected_index].length;
     } else {
-      unsigned char random_byte;
-      RAND_bytes(&random_byte, 1);
-      double random_value = random_byte / 255.0;
-      double sum = 0;
-      for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
-        sum += context->advertisedNextProtocols_[i].probability;
-        if (sum < random_value &&
-            i + 1 < context->advertisedNextProtocols_.size()) {
-          continue;
-        }
-        uintptr_t selected = i + 1;
-        SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
-        *out = context->advertisedNextProtocols_[i].protocols;
-        *outlen = context->advertisedNextProtocols_[i].length;
-        break;
-      }
+      auto i = context->pickNextProtocols();
+      uintptr_t selected = i + 1;
+      SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+      *out = context->advertisedNextProtocols_[i].protocols;
+      *outlen = context->advertisedNextProtocols_[i].length;
     }
   }
   return SSL_TLSEXT_ERR_OK;
@@ -406,46 +513,21 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
-  /**
-   * The list was generated as follows:
-   * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 |
-   * while read A && read B && read C && read D && read E && read F; do
-   * echo $A $B $C $D $E; done |
-   * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\|
-   *         SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES |
-   * awk -F, '{ print $1"," }'
-   */
   ciphers_{
-    TLS1_CK_DH_DSS_WITH_AES_128_SHA,
-    TLS1_CK_DH_RSA_WITH_AES_128_SHA,
     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
-    TLS1_CK_DH_DSS_WITH_AES_256_SHA,
-    TLS1_CK_DH_RSA_WITH_AES_256_SHA,
     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
-    TLS1_CK_DH_DSS_WITH_AES_128_SHA256,
-    TLS1_CK_DH_RSA_WITH_AES_128_SHA256,
     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
-    TLS1_CK_DH_DSS_WITH_AES_256_SHA256,
-    TLS1_CK_DH_RSA_WITH_AES_256_SHA256,
     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256,
-    TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384,
     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256,
-    TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
-    TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA,
-    TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA,
     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
@@ -454,15 +536,10 @@ SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
-    TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256,
-    TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256,
   } {
   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
   width_ = sizeof(ciphers_[0]);
@@ -493,9 +570,12 @@ bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
 }
 #endif
 
-int SSLContext::selectNextProtocolCallback(
-  SSL* ssl, unsigned char **out, unsigned char *outlen,
-  const unsigned char *server, unsigned int server_len, void *data) {
+int SSLContext::selectNextProtocolCallback(SSL* /* ssl */,
+                                           unsigned char** out,
+                                           unsigned char* outlen,
+                                           const unsigned char* server,
+                                           unsigned int server_len,
+                                           void* data) {
 
   SSLContext* ctx = (SSLContext*)data;
   if (ctx->advertisedNextProtocols_.size() > 1) {
@@ -504,13 +584,21 @@ int SSLContext::selectNextProtocolCallback(
   }
 
   unsigned char *client;
-  int client_len;
-  if (ctx->advertisedNextProtocols_.empty()) {
-    client = (unsigned char *) "";
-    client_len = 0;
-  } else {
-    client = ctx->advertisedNextProtocols_[0].protocols;
-    client_len = ctx->advertisedNextProtocols_[0].length;
+  unsigned int client_len;
+  bool filtered = false;
+  auto cpf = ctx->getClientProtocolFilterCallback();
+  if (cpf) {
+    filtered = (*cpf)(&client, &client_len, server, server_len);
+  }
+
+  if (!filtered) {
+    if (ctx->advertisedNextProtocols_.empty()) {
+      client = (unsigned char *) "";
+      client_len = 0;
+    } else {
+      client = ctx->advertisedNextProtocols_[0].protocols;
+      client_len = ctx->advertisedNextProtocols_[0].length;
+    }
   }
 
   int retval = SSL_select_next_proto(out, outlen, server, server_len,
@@ -626,17 +714,13 @@ struct SSLLock {
 // SSLContext runs in such environments.
 // Instead of declaring a static member we "new" the static
 // member so that it won't be destructed on exit().
-static std::map<int, SSLContext::SSLLockType>* lockTypesInst =
-  new std::map<int, SSLContext::SSLLockType>();
-
-static std::unique_ptr<SSLLock[]>* locksInst =
-  new std::unique_ptr<SSLLock[]>();
-
 static std::unique_ptr<SSLLock[]>& locks() {
+  static auto locksInst = new std::unique_ptr<SSLLock[]>();
   return *locksInst;
 }
 
 static std::map<int, SSLContext::SSLLockType>& lockTypes() {
+  static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
   return *lockTypesInst;
 }
 
@@ -682,8 +766,13 @@ void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
   lockTypes() = inLockTypes;
 }
 
+void SSLContext::markInitialized() {
+  std::lock_guard<std::mutex> g(initMutex());
+  initialized_ = true;
+}
+
 void SSLContext::initializeOpenSSL() {
-  std::lock_guard<std::mutex> g(mutex_);
+  std::lock_guard<std::mutex> g(initMutex());
   initializeOpenSSLLocked();
 }
 
@@ -714,7 +803,7 @@ void SSLContext::initializeOpenSSLLocked() {
 }
 
 void SSLContext::cleanupOpenSSL() {
-  std::lock_guard<std::mutex> g(mutex_);
+  std::lock_guard<std::mutex> g(initMutex());
   cleanupOpenSSLLocked();
 }
 
@@ -797,7 +886,7 @@ bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
 
 bool OpenSSLUtils::validatePeerCertNames(X509* cert,
                                          const sockaddr* addr,
-                                         socklen_t addrLen) {
+                                         socklen_t /* addrLen */) {
   // Try to extract the names within the SAN extension from the certificate
   auto altNames =
     reinterpret_cast<STACK_OF(GENERAL_NAME)*>(