Fix case where ssl cert does not match key
[folly.git] / folly / io / async / SSLContext.cpp
index 3abb5d0b8a307439a2ec09d3a7c19ac32f014c06..4a95a5723c267e5f715567454416de25675061b8 100644 (file)
 #include <folly/Format.h>
 #include <folly/Memory.h>
 #include <folly/Random.h>
+#include <folly/SharedMutex.h>
 #include <folly/SpinLock.h>
-#include <folly/ThreadId.h>
+#include <folly/ssl/Init.h>
+#include <folly/system/ThreadId.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
 // ---------------------------------------------------------------------
-
-struct CRYPTO_dynlock_value {
-  std::mutex mutex;
-};
-
 namespace folly {
 //
 // For OpenSSL portability API
 using namespace folly::ssl;
 
-bool SSLContext::initialized_ = false;
-
-namespace {
-
-std::mutex& initMutex() {
-  static std::mutex m;
-  return m;
-}
-
-} // anonymous namespace
-
-#ifdef OPENSSL_NPN_NEGOTIATED
-int SSLContext::sNextProtocolsExDataIndex_ = -1;
-#endif
-
 // SSLContext implementation
 SSLContext::SSLContext(SSLVersion version) {
-  {
-    std::lock_guard<std::mutex> g(initMutex());
-    initializeOpenSSLLocked();
-  }
+  folly::ssl::init();
 
   ctx_ = SSL_CTX_new(SSLv23_method());
   if (ctx_ == nullptr) {
@@ -70,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) {
     case SSLv3:
       opt = SSL_OP_NO_SSLv2;
       break;
+    case TLSv1_2:
+      opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
+          SSL_OP_NO_TLSv1_1;
+      break;
     default:
       // do nothing
       break;
@@ -101,34 +84,9 @@ SSLContext::~SSLContext() {
 }
 
 void SSLContext::ciphers(const std::string& ciphers) {
-  providedCiphersString_ = ciphers;
   setCiphersOrThrow(ciphers);
 }
 
-void SSLContext::setCipherList(const std::vector<std::string>& ciphers) {
-  if (ciphers.size() == 0) {
-    return;
-  }
-  std::string opensslCipherList;
-  join(":", ciphers, opensslCipherList);
-  setCiphersOrThrow(opensslCipherList);
-}
-
-void SSLContext::setSignatureAlgorithms(
-    const std::vector<std::string>& sigalgs) {
-  if (sigalgs.size() == 0) {
-    return;
-  }
-#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
-  std::string opensslSigAlgsList;
-  join(":", sigalgs, opensslSigAlgsList);
-  int rc = SSL_CTX_set1_sigalgs_list(ctx_, opensslSigAlgsList.c_str());
-  if (rc == 0) {
-    throw std::runtime_error("SSL_CTX_set1_sigalgs_list " + getErrors());
-  }
-#endif
-}
-
 void SSLContext::setClientECCurvesList(
     const std::vector<std::string>& ecCurves) {
   if (ecCurves.size() == 0) {
@@ -187,6 +145,7 @@ void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
   if (rc == 0) {
     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
   }
+  providedCiphersString_ = ciphers;
 }
 
 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
@@ -229,7 +188,8 @@ void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
                               const std::string& peerName) {
   int mode;
   if (checkPeerCert) {
-    mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
+    mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
+        SSL_VERIFY_CLIENT_ONCE;
     checkPeerName_ = checkPeerName;
     peerFixedName_ = peerName;
   } else {
@@ -246,7 +206,7 @@ void SSLContext::loadCertificate(const char* path, const char* format) {
          "loadCertificateChain: either <path> or <format> is nullptr");
   }
   if (strcmp(format, "PEM") == 0) {
-    if (SSL_CTX_use_certificate_chain_file(ctx_, path) == 0) {
+    if (SSL_CTX_use_certificate_chain_file(ctx_, path) != 1) {
       int errnoCopy = errno;
       std::string reason("SSL_CTX_use_certificate_chain_file: ");
       reason.append(path);
@@ -255,7 +215,8 @@ void SSLContext::loadCertificate(const char* path, const char* format) {
       throw std::runtime_error(reason);
     }
   } else {
-    throw std::runtime_error("Unsupported certificate format: " + std::string(format));
+    throw std::runtime_error(
+        "Unsupported certificate format: " + std::string(format));
   }
 }
 
@@ -295,7 +256,8 @@ void SSLContext::loadPrivateKey(const char* path, const char* format) {
       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
     }
   } else {
-    throw std::runtime_error("Unsupported private key format: " + std::string(format));
+    throw std::runtime_error(
+        "Unsupported private key format: " + std::string(format));
   }
 }
 
@@ -325,6 +287,32 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
   }
 }
 
+void SSLContext::loadCertKeyPairFromBufferPEM(
+    folly::StringPiece cert,
+    folly::StringPiece pkey) {
+  loadCertificateFromBufferPEM(cert);
+  loadPrivateKeyFromBufferPEM(pkey);
+  if (!isCertKeyPairValid()) {
+    throw std::runtime_error("SSL certificate and private key do not match");
+  }
+}
+
+void SSLContext::loadCertKeyPairFromFiles(
+    const char* certPath,
+    const char* keyPath,
+    const char* certFormat,
+    const char* keyFormat) {
+  loadCertificate(certPath, certFormat);
+  loadPrivateKey(keyPath, keyFormat);
+  if (!isCertKeyPairValid()) {
+    throw std::runtime_error("SSL certificate and private key do not match");
+  }
+}
+
+bool SSLContext::isCertKeyPairValid() const {
+  return SSL_CTX_check_private_key(ctx_) == 1;
+}
+
 void SSLContext::loadTrustedCertificates(const char* path) {
   if (path == nullptr) {
     throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
@@ -348,11 +336,8 @@ void SSLContext::loadClientCAList(const char* path) {
   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
 }
 
-void SSLContext::randomize() {
-  RAND_poll();
-}
-
-void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
+void SSLContext::passwordCollector(
+    std::shared_ptr<PasswordCollector> collector) {
   if (collector == nullptr) {
     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
     return;
@@ -408,55 +393,6 @@ int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
 
   return SSL_TLSEXT_ERR_NOACK;
 }
-
-void SSLContext::switchCiphersIfTLS11(
-    SSL* ssl,
-    const std::string& tls11CipherString,
-    const std::vector<std::pair<std::string, int>>& tls11AltCipherlist) {
-  CHECK(!(tls11CipherString.empty() && tls11AltCipherlist.empty()))
-      << "Shouldn't call if empty ciphers / alt ciphers";
-
-  if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
-    // We only do this for TLS v 1.1 and later
-    return;
-  }
-
-  const std::string* ciphers = &tls11CipherString;
-  if (!tls11AltCipherlist.empty()) {
-    if (!cipherListPicker_) {
-      std::vector<int> weights;
-      std::for_each(
-          tls11AltCipherlist.begin(),
-          tls11AltCipherlist.end(),
-          [&](const std::pair<std::string, int>& e) {
-            weights.push_back(e.second);
-          });
-      cipherListPicker_.reset(
-          new std::discrete_distribution<int>(weights.begin(), weights.end()));
-    }
-    auto rng = ThreadLocalPRNG();
-    auto index = (*cipherListPicker_)(rng);
-    if ((size_t)index >= tls11AltCipherlist.size()) {
-      LOG(ERROR) << "Trying to pick alt TLS11 cipher index " << index
-                 << ", but tls11AltCipherlist is of length "
-                 << tls11AltCipherlist.size();
-    } else {
-      ciphers = &tls11AltCipherlist[size_t(index)].first;
-    }
-  }
-
-  // Prefer AES for TLS versions 1.1 and later since these are not
-  // vulnerable to BEAST attacks on AES.  Note that we're setting the
-  // cipher list on the SSL object, not the SSL_CTX object, so it will
-  // only last for this request.
-  int rc = SSL_set_cipher_list(ssl, ciphers->c_str());
-  if ((rc == 0) || ERR_peek_error() != 0) {
-    // This shouldn't happen since we checked for this when proxygen
-    // started up.
-    LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
-    SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
-  }
-}
 #endif // FOLLY_OPENSSL_HAS_SNI
 
 #if FOLLY_OPENSSL_HAS_ALPN
@@ -581,6 +517,9 @@ size_t SSLContext::pickNextProtocols() {
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       const unsigned char** out, unsigned int* outlen, void* data) {
+  static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
+      0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+
   SSLContext* context = (SSLContext*)data;
   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
     *out = nullptr;
@@ -589,8 +528,8 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
     *out = context->advertisedNextProtocols_[0].protocols;
     *outlen = context->advertisedNextProtocols_[0].length;
   } else {
-    uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
-          sNextProtocolsExDataIndex_));
+    uintptr_t selected_index = reinterpret_cast<uintptr_t>(
+        SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
     if (selected_index) {
       --selected_index;
       *out = context->advertisedNextProtocols_[selected_index].protocols;
@@ -598,7 +537,7 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
     } else {
       auto i = context->pickNextProtocols();
       uintptr_t selected = i + 1;
-      SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+      SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
       *out = context->advertisedNextProtocols_[i].protocols;
       *outlen = context->advertisedNextProtocols_[i].length;
     }
@@ -660,8 +599,7 @@ void SSLContext::setSessionCacheContext(const std::string& context) {
       ctx_,
       reinterpret_cast<const unsigned char*>(context.data()),
       std::min<unsigned int>(
-          static_cast<unsigned int>(context.length()),
-          SSL_MAX_SSL_SESSION_ID_LENGTH));
+          static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH));
 }
 
 /**
@@ -708,107 +646,9 @@ int SSLContext::passwordCallback(char* password,
   std::string userPassword;
   // call user defined password collector to get password
   context->passwordCollector()->getPassword(userPassword, size);
-  auto length = int(userPassword.size());
-  if (length > size) {
-    length = size;
-  }
-  strncpy(password, userPassword.c_str(), size_t(length));
-  return length;
-}
-
-struct SSLLock {
-  explicit SSLLock(
-    SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
-      lockType(inLockType) {
-  }
-
-  void lock() {
-    if (lockType == SSLContext::LOCK_MUTEX) {
-      mutex.lock();
-    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
-      spinLock.lock();
-    }
-    // lockType == LOCK_NONE, no-op
-  }
-
-  void unlock() {
-    if (lockType == SSLContext::LOCK_MUTEX) {
-      mutex.unlock();
-    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
-      spinLock.unlock();
-    }
-    // lockType == LOCK_NONE, no-op
-  }
-
-  SSLContext::SSLLockType lockType;
-  folly::SpinLock spinLock{};
-  std::mutex mutex;
-};
-
-// Statics are unsafe in environments that call exit().
-// If one thread calls exit() while another thread is
-// references a member of SSLContext, bad things can happen.
-// SSLContext runs in such environments.
-// Instead of declaring a static member we "new" the static
-// member so that it won't be destructed on exit().
-static std::unique_ptr<SSLLock[]>& locks() {
-  static auto locksInst = new std::unique_ptr<SSLLock[]>();
-  return *locksInst;
-}
-
-static std::map<int, SSLContext::SSLLockType>& lockTypes() {
-  static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
-  return *lockTypesInst;
-}
-
-static void callbackLocking(int mode, int n, const char*, int) {
-  if (mode & CRYPTO_LOCK) {
-    locks()[size_t(n)].lock();
-  } else {
-    locks()[size_t(n)].unlock();
-  }
-}
-
-static unsigned long callbackThreadID() {
-  return static_cast<unsigned long>(folly::getCurrentThreadID());
-}
-
-static CRYPTO_dynlock_value* dyn_create(const char*, int) {
-  return new CRYPTO_dynlock_value;
-}
-
-static void dyn_lock(int mode,
-                     struct CRYPTO_dynlock_value* lock,
-                     const char*, int) {
-  if (lock != nullptr) {
-    if (mode & CRYPTO_LOCK) {
-      lock->mutex.lock();
-    } else {
-      lock->mutex.unlock();
-    }
-  }
-}
-
-static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
-  delete lock;
-}
-
-void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
-  if (initialized_) {
-    // We set the locks on initialization, so if we are already initialized
-    // this would have no affect.
-    LOG(INFO) << "Ignoring setSSLLockTypes after initialization";
-    return;
-  }
-
-  lockTypes() = inLockTypes;
-}
-
-bool SSLContext::isSSLLockDisabled(int lockId) {
-  const auto& sslLocks = lockTypes();
-  const auto it = sslLocks.find(lockId);
-  return it != sslLocks.end() &&
-      it->second == SSLContext::SSLLockType::LOCK_NONE;
+  auto const length = std::min(userPassword.size(), size_t(size));
+  std::memcpy(password, userPassword.data(), length);
+  return int(length);
 }
 
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
@@ -817,63 +657,8 @@ void SSLContext::enableFalseStart() {
 }
 #endif
 
-void SSLContext::markInitialized() {
-  std::lock_guard<std::mutex> g(initMutex());
-  initialized_ = true;
-}
-
 void SSLContext::initializeOpenSSL() {
-  std::lock_guard<std::mutex> g(initMutex());
-  initializeOpenSSLLocked();
-}
-
-void SSLContext::initializeOpenSSLLocked() {
-  if (initialized_) {
-    return;
-  }
-  SSL_library_init();
-  SSL_load_error_strings();
-  ERR_load_crypto_strings();
-  // static locking
-  locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
-  for (auto it: lockTypes()) {
-    locks()[size_t(it.first)].lockType = it.second;
-  }
-  CRYPTO_set_id_callback(callbackThreadID);
-  CRYPTO_set_locking_callback(callbackLocking);
-  // dynamic locking
-  CRYPTO_set_dynlock_create_callback(dyn_create);
-  CRYPTO_set_dynlock_lock_callback(dyn_lock);
-  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
-  randomize();
-#ifdef OPENSSL_NPN_NEGOTIATED
-  sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
-      (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
-#endif
-  initialized_ = true;
-}
-
-void SSLContext::cleanupOpenSSL() {
-  std::lock_guard<std::mutex> g(initMutex());
-  cleanupOpenSSLLocked();
-}
-
-void SSLContext::cleanupOpenSSLLocked() {
-  if (!initialized_) {
-    return;
-  }
-
-  CRYPTO_set_id_callback(nullptr);
-  CRYPTO_set_locking_callback(nullptr);
-  CRYPTO_set_dynlock_create_callback(nullptr);
-  CRYPTO_set_dynlock_lock_callback(nullptr);
-  CRYPTO_set_dynlock_destroy_callback(nullptr);
-  CRYPTO_cleanup_all_ex_data();
-  ERR_free_strings();
-  EVP_cleanup();
-  ERR_clear_error();
-  locks().reset();
-  initialized_ = false;
+  folly::ssl::init();
 }
 
 void SSLContext::setOptions(long options) {
@@ -912,4 +697,4 @@ operator<<(std::ostream& os, const PasswordCollector& collector) {
   return os;
 }
 
-} // folly
+} // namespace folly