Fix case where ssl cert does not match key
[folly.git] / folly / io / async / SSLContext.cpp
index 95ae99a12edff49cdaa921f2cf243f769b0503b8..4a95a5723c267e5f715567454416de25675061b8 100644 (file)
@@ -21,8 +21,8 @@
 #include <folly/Random.h>
 #include <folly/SharedMutex.h>
 #include <folly/SpinLock.h>
-#include <folly/ThreadId.h>
 #include <folly/ssl/Init.h>
+#include <folly/system/ThreadId.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
@@ -49,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) {
     case SSLv3:
       opt = SSL_OP_NO_SSLv2;
       break;
+    case TLSv1_2:
+      opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
+          SSL_OP_NO_TLSv1_1;
+      break;
     default:
       // do nothing
       break;
@@ -83,30 +87,6 @@ void SSLContext::ciphers(const std::string& ciphers) {
   setCiphersOrThrow(ciphers);
 }
 
-void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
-  if (ciphers.size() == 0) {
-    return;
-  }
-  std::string opensslCipherList;
-  join(":", ciphers, opensslCipherList);
-  setCiphersOrThrow(opensslCipherList);
-}
-
-void SSLContext::setSignatureAlgorithms(
-    const std::vector<std::string>& sigalgs) {
-  if (sigalgs.size() == 0) {
-    return;
-  }
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
-  std::string opensslSigAlgsList;
-  join(":", sigalgs, opensslSigAlgsList);
-  int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
-  if (rc == 0) {
-    throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
-  }
-#endif
-}
-
 void SSLContext::setClientECCurvesList(
     const std::vector<std::string>& ecCurves) {
   if (ecCurves.size() == 0) {
@@ -226,7 +206,7 @@ void SSLContext::loadCertificate(const char* path, const char* format) {
          "loadCertificateChain: either <path> or <format> is nullptr");
   }
   if (strcmp(format, "PEM") == 0) {
-    if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
+    if (SSL_CTX_use_certificate_chain_file(ctx_, path) != 1) {
       int errnoCopy = errno;
       std::string reason("SSL_CTX_use_certificate_chain_file: ");
       reason.append(path);
@@ -307,6 +287,32 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
   }
 }
 
+void SSLContext::loadCertKeyPairFromBufferPEM(
+    folly::StringPiece cert,
+    folly::StringPiece pkey) {
+  loadCertificateFromBufferPEM(cert);
+  loadPrivateKeyFromBufferPEM(pkey);
+  if (!isCertKeyPairValid()) {
+    throw std::runtime_error("SSL certificate and private key do not match");
+  }
+}
+
+void SSLContext::loadCertKeyPairFromFiles(
+    const char* certPath,
+    const char* keyPath,
+    const char* certFormat,
+    const char* keyFormat) {
+  loadCertificate(certPath, certFormat);
+  loadPrivateKey(keyPath, keyFormat);
+  if (!isCertKeyPairValid()) {
+    throw std::runtime_error("SSL certificate and private key do not match");
+  }
+}
+
+bool SSLContext::isCertKeyPairValid() const {
+  return SSL_CTX_check_private_key(ctx_) == 1;
+}
+
 void SSLContext::loadTrustedCertificates(const char* path) {
   if (path == nullptr) {
     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
@@ -387,55 +393,6 @@ int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
 
   return SSL_TLSEXT_ERR_NOACK;
 }
-
-void SSLContext::switchCiphersIfTLS11(
-    SSL* ssl,
-    const std::string& tls11CipherString,
-    const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
-  CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
-      << "Shouldn't call if empty ciphers / alt ciphers";
-
-  if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
-    // We only do this for TLS v 1.1 and later
-    return;
-  }
-
-  const std::string* ciphers = &tls11CipherString;
-  if (!tls11AltCipherlist.empty()) {
-    if (!cipherListPicker_) {
-      std::vector<int> weights;
-      std::for_each(
-          tls11AltCipherlist.begin(),
-          tls11AltCipherlist.end(),
-          [&](const std::pair<std::string, int>& e) {
-            weights.push_back(e.second);
-          });
-      cipherListPicker_.reset(
-          new std::discrete_distribution<int>(weights.begin(), weights.end()));
-    }
-    auto rng = ThreadLocalPRNG();
-    auto index = (*cipherListPicker_)(rng);
-    if ((size_t)index >= tls11AltCipherlist.size()) {
-      LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
-                 << ", but tls11AltCipherlist is of length "
-                 << tls11AltCipherlist.size();
-    } else {
-      ciphers = &tls11AltCipherlist[size_t(index)].first;
-    }
-  }
-
-  // Prefer AES for TLS versions 1.1 and later since these are not
-  // vulnerable to BEAST attacks on AES.  Note that we're setting the
-  // cipher list on the SSL object, not the SSL_CTX object, so it will
-  // only last for this request.
-  int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
-  if ((rc == 0) || ERR_peek_error() != 0) {
-    // This shouldn't happen since we checked for this when proxygen
-    // started up.
-    LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
-    SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
-  }
-}
 #endif // FOLLY_OPENSSL_HAS_SNI
 
 #if FOLLY_OPENSSL_HAS_ALPN
@@ -642,8 +599,7 @@ void SSLContext::setSessionCacheContext(const std::string& context) {
       ctx_,
       reinterpret_cast<const unsigned char*>(context.data()),
       std::min<unsigned int>(
-          static_cast<unsigned int>(context.length()),
-          SSL_MAX_SSL_SESSION_ID_LENGTH));
+          static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH));
 }
 
 /**
@@ -690,16 +646,9 @@ int SSLContext::passwordCallback(char* password,
   std::string userPassword;
   // call user defined password collector to get password
   context->passwordCollector()->getPassword(userPassword, size);
-  auto length = int(userPassword.size());
-  if (length > size) {
-    length = size;
-  }
-  strncpy(password, userPassword.c_str(), size_t(length));
-  return length;
-}
-
-void SSLContext::setSSLLockTypes(std::map<int, LockType> inLockTypes) {
-  folly::ssl::setLockTypes(inLockTypes);
+  auto const length = std::min(userPassword.size(), size_t(size));
+  std::memcpy(password, userPassword.data(), length);
+  return int(length);
 }
 
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
@@ -748,4 +697,4 @@ operator<<(std::ostream& os, const PasswordCollector& collector) {
   return os;
 }
 
-} // folly
+} // namespace folly