Revise API to load cert/key in SSLContext.
[folly.git] / folly / io / async / SSLContext.cpp
index 80d414b0ea9bee14d9dcd099e61b0d7460cff6f6..7ce05f44a7a062c974ffe8a65dd17c1bc1293f40 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2014 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #include "SSLContext.h"
 
-#include <openssl/err.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-#include <openssl/x509v3.h>
-
 #include <folly/Format.h>
-#include <folly/io/PortableSpinLock.h>
+#include <folly/Memory.h>
+#include <folly/Random.h>
+#include <folly/SharedMutex.h>
+#include <folly/SpinLock.h>
+#include <folly/ssl/Init.h>
+#include <folly/system/ThreadId.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
 // ---------------------------------------------------------------------
-
-struct CRYPTO_dynlock_value {
-  std::mutex mutex;
-};
-
 namespace folly {
-
-bool SSLContext::initialized_ = false;
-std::mutex    SSLContext::mutex_;
-#ifdef OPENSSL_NPN_NEGOTIATED
-int SSLContext::sNextProtocolsExDataIndex_ = -1;
-#endif
-
-#ifndef SSLCONTEXT_NO_REFCOUNT
-uint64_t SSLContext::count_ = 0;
-#endif
+//
+// For OpenSSL portability API
+using namespace folly::ssl;
 
 // SSLContext implementation
 SSLContext::SSLContext(SSLVersion version) {
-  {
-    std::lock_guard<std::mutex> g(mutex_);
-#ifndef SSLCONTEXT_NO_REFCOUNT
-    count_++;
-#endif
-    initializeOpenSSLLocked();
-  }
+  folly::ssl::init();
 
   ctx_ = SSL_CTX_new(SSLv23_method());
   if (ctx_ == nullptr) {
@@ -67,6 +49,10 @@ SSLContext::SSLContext(SSLVersion version) {
     case SSLv3:
       opt = SSL_OP_NO_SSLv2;
       break;
+    case TLSv1_2:
+      opt = SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_TLSv1 |
+          SSL_OP_NO_TLSv1_1;
+      break;
     default:
       // do nothing
       break;
@@ -78,7 +64,9 @@ SSLContext::SSLContext(SSLVersion version) {
 
   checkPeerName_ = false;
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
+
+#if FOLLY_OPENSSL_HAS_SNI
   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
 #endif
@@ -93,30 +81,71 @@ SSLContext::~SSLContext() {
 #ifdef OPENSSL_NPN_NEGOTIATED
   deleteNextProtocolsStrings();
 #endif
+}
 
-#ifndef SSLCONTEXT_NO_REFCOUNT
-  {
-    std::lock_guard<std::mutex> g(mutex_);
-    if (!--count_) {
-      cleanupOpenSSLLocked();
-    }
+void SSLContext::ciphers(const std::string& ciphers) {
+  setCiphersOrThrow(ciphers);
+}
+
+void SSLContext::setClientECCurvesList(
+    const std::vector<std::string>& ecCurves) {
+  if (ecCurves.size() == 0) {
+    return;
+  }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL
+  std::string ecCurvesList;
+  join(":", ecCurves, ecCurvesList);
+  int rc = SSL_CTX_set1_curves_list(ctx_, ecCurvesList.c_str());
+  if (rc == 0) {
+    throw std::runtime_error("SSL_CTX_set1_curves_list " + getErrors());
   }
 #endif
 }
 
-void SSLContext::ciphers(const std::string& ciphers) {
-  providedCiphersString_ = ciphers;
-  setCiphersOrThrow(ciphers);
+void SSLContext::setServerECCurve(const std::string& curveName) {
+#if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
+  EC_KEY* ecdh = nullptr;
+  int nid;
+
+  /*
+   * Elliptic-Curve Diffie-Hellman parameters are either "named curves"
+   * from RFC 4492 section 5.1.1, or explicitly described curves over
+   * binary fields. OpenSSL only supports the "named curves", which provide
+   * maximum interoperability.
+   */
+
+  nid = OBJ_sn2nid(curveName.c_str());
+  if (nid == 0) {
+    LOG(FATAL) << "Unknown curve name:" << curveName.c_str();
+  }
+  ecdh = EC_KEY_new_by_curve_name(nid);
+  if (ecdh == nullptr) {
+    LOG(FATAL) << "Unable to create curve:" << curveName.c_str();
+  }
+
+  SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
+  EC_KEY_free(ecdh);
+#else
+  throw std::runtime_error("Elliptic curve encryption not allowed");
+#endif
+}
+
+void SSLContext::setX509VerifyParam(
+    const ssl::X509VerifyParam& x509VerifyParam) {
+  if (!x509VerifyParam) {
+    return;
+  }
+  if (SSL_CTX_set1_param(ctx_, x509VerifyParam.get()) != 1) {
+    throw std::runtime_error("SSL_CTX_set1_param " + getErrors());
+  }
 }
 
 void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
   int rc = SSL_CTX_set_cipher_list(ctx_, ciphers.c_str());
-  if (ERR_peek_error() != 0) {
-    throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
-  }
   if (rc == 0) {
-    throw std::runtime_error("None of specified ciphers are supported");
+    throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
   }
+  providedCiphersString_ = ciphers;
 }
 
 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
@@ -159,7 +188,8 @@ void SSLContext::authenticate(bool checkPeerCert, bool checkPeerName,
                               const std::string& peerName) {
   int mode;
   if (checkPeerCert) {
-    mode  = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
+    mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT |
+        SSL_VERIFY_CLIENT_ONCE;
     checkPeerName_ = checkPeerName;
     peerFixedName_ = peerName;
   } else {
@@ -185,32 +215,106 @@ void SSLContext::loadCertificate(const char* path, const char* format) {
       throw std::runtime_error(reason);
     }
   } else {
-    throw std::runtime_error("Unsupported certificate format: " + std::string(format));
+    throw std::runtime_error(
+        "Unsupported certificate format: " + std::string(format));
+  }
+}
+
+void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
+  if (cert.data() == nullptr) {
+    throw std::invalid_argument("loadCertificate: <cert> is nullptr");
+  }
+
+  ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
+  if (bio == nullptr) {
+    throw std::runtime_error("BIO_new: " + getErrors());
+  }
+
+  int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
+  if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
+    throw std::runtime_error("BIO_write: " + getErrors());
+  }
+
+  ssl::X509UniquePtr x509(
+      PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+  if (x509 == nullptr) {
+    throw std::runtime_error("PEM_read_bio_X509: " + getErrors());
+  }
+
+  if (SSL_CTX_use_certificate(ctx_, x509.get()) == 0) {
+    throw std::runtime_error("SSL_CTX_use_certificate: " + getErrors());
   }
 }
 
 void SSLContext::loadPrivateKey(const char* path, const char* format) {
   if (path == nullptr || format == nullptr) {
     throw std::invalid_argument(
-         "loadPrivateKey: either <path> or <format> is nullptr");
+        "loadPrivateKey: either <path> or <format> is nullptr");
   }
   if (strcmp(format, "PEM") == 0) {
     if (SSL_CTX_use_PrivateKey_file(ctx_, path, SSL_FILETYPE_PEM) == 0) {
       throw std::runtime_error("SSL_CTX_use_PrivateKey_file: " + getErrors());
     }
   } else {
-    throw std::runtime_error("Unsupported private key format: " + std::string(format));
+    throw std::runtime_error(
+        "Unsupported private key format: " + std::string(format));
+  }
+}
+
+void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
+  if (pkey.data() == nullptr) {
+    throw std::invalid_argument("loadPrivateKey: <pkey> is nullptr");
   }
+
+  ssl::BioUniquePtr bio(BIO_new(BIO_s_mem()));
+  if (bio == nullptr) {
+    throw std::runtime_error("BIO_new: " + getErrors());
+  }
+
+  int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
+  if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
+    throw std::runtime_error("BIO_write: " + getErrors());
+  }
+
+  ssl::EvpPkeyUniquePtr key(
+      PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
+  if (key == nullptr) {
+    throw std::runtime_error("PEM_read_bio_PrivateKey: " + getErrors());
+  }
+
+  if (SSL_CTX_use_PrivateKey(ctx_, key.get()) == 0) {
+    throw std::runtime_error("SSL_CTX_use_PrivateKey: " + getErrors());
+  }
+}
+
+void SSLContext::loadCertKeyPairFromBufferPEM(
+    folly::StringPiece cert,
+    folly::StringPiece pkey) {
+  loadCertificateFromBufferPEM(cert);
+  loadPrivateKeyFromBufferPEM(pkey);
+}
+
+void SSLContext::loadCertKeyPairFromFiles(
+    const char* certPath,
+    const char* keyPath,
+    const char* certFormat,
+    const char* keyFormat) {
+  loadCertificate(certPath, certFormat);
+  loadPrivateKey(keyPath, keyFormat);
+}
+
+bool SSLContext::isCertKeyPairValid() const {
+  return SSL_CTX_check_private_key(ctx_) == 1;
 }
 
 void SSLContext::loadTrustedCertificates(const char* path) {
   if (path == nullptr) {
-    throw std::invalid_argument(
-         "loadTrustedCertificates: <path> is nullptr");
+    throw std::invalid_argument("loadTrustedCertificates: <path> is nullptr");
   }
   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
   }
+  ERR_clear_error();
 }
 
 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
@@ -226,11 +330,8 @@ void SSLContext::loadClientCAList(const char* path) {
   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
 }
 
-void SSLContext::randomize() {
-  RAND_poll();
-}
-
-void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector) {
+void SSLContext::passwordCollector(
+    std::shared_ptr<PasswordCollector> collector) {
   if (collector == nullptr) {
     LOG(ERROR) << "passwordCollector: ignore invalid password collector";
     return;
@@ -240,7 +341,7 @@ void SSLContext::passwordCollector(std::shared_ptr<PasswordCollector> collector)
   SSL_CTX_set_default_passwd_cb_userdata(ctx_, this);
 }
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
 
 void SSLContext::setServerNameCallback(const ServerNameCallback& cb) {
   serverNameCb_ = cb;
@@ -286,39 +387,45 @@ int SSLContext::baseServerNameOpenSSLCallback(SSL* ssl, int* al, void* data) {
 
   return SSL_TLSEXT_ERR_NOACK;
 }
+#endif // FOLLY_OPENSSL_HAS_SNI
 
-void SSLContext::switchCiphersIfTLS11(
-    SSL* ssl,
-    const std::string& tls11CipherString) {
-
-  CHECK(!tls11CipherString.empty()) << "Shouldn't call if empty alt ciphers";
-
-  if (TLS1_get_client_version(ssl) <= TLS1_VERSION) {
-    // We only do this for TLS v 1.1 and later
-    return;
-  }
-
-  // Prefer AES for TLS versions 1.1 and later since these are not
-  // vulnerable to BEAST attacks on AES.  Note that we're setting the
-  // cipher list on the SSL object, not the SSL_CTX object, so it will
-  // only last for this request.
-  int rc = SSL_set_cipher_list(ssl, tls11CipherString.c_str());
-  if ((rc == 0) || ERR_peek_error() != 0) {
-    // This shouldn't happen since we checked for this when proxygen
-    // started up.
-    LOG(WARNING) << "ssl_cipher: No specified ciphers supported for switch";
-    SSL_set_cipher_list(ssl, providedCiphersString_.c_str());
+#if FOLLY_OPENSSL_HAS_ALPN
+int SSLContext::alpnSelectCallback(SSL* /* ssl */,
+                                   const unsigned char** out,
+                                   unsigned char* outlen,
+                                   const unsigned char* in,
+                                   unsigned int inlen,
+                                   void* data) {
+  SSLContext* context = (SSLContext*)data;
+  CHECK(context);
+  if (context->advertisedNextProtocols_.empty()) {
+    *out = nullptr;
+    *outlen = 0;
+  } else {
+    auto i = context->pickNextProtocols();
+    const auto& item = context->advertisedNextProtocols_[i];
+    if (SSL_select_next_proto((unsigned char**)out,
+                              outlen,
+                              item.protocols,
+                              item.length,
+                              in,
+                              inlen) != OPENSSL_NPN_NEGOTIATED) {
+      return SSL_TLSEXT_ERR_NOACK;
+    }
   }
+  return SSL_TLSEXT_ERR_OK;
 }
-#endif
+#endif // FOLLY_OPENSSL_HAS_ALPN
 
 #ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
-  return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+    const std::list<std::string>& protocols, NextProtocolType protocolType) {
+  return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
 }
 
 bool SSLContext::setRandomizedAdvertisedNextProtocols(
-    const std::list<NextProtocolsItem>& items) {
+    const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
   unsetNextProtocols();
   if (items.size() == 0) {
     return false;
@@ -332,12 +439,12 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
     advertised_item.length = 0;
     for (const auto& proto : item.protocols) {
       ++advertised_item.length;
-      unsigned protoLength = proto.length();
+      auto protoLength = proto.length();
       if (protoLength >= 256) {
         deleteNextProtocolsStrings();
         return false;
       }
-      advertised_item.length += protoLength;
+      advertised_item.length += unsigned(protoLength);
     }
     advertised_item.protocols = new unsigned char[advertised_item.length];
     if (!advertised_item.protocols) {
@@ -345,26 +452,36 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
     }
     unsigned char* dst = advertised_item.protocols;
     for (auto& proto : item.protocols) {
-      unsigned protoLength = proto.length();
+      uint8_t protoLength = uint8_t(proto.length());
       *dst++ = (unsigned char)protoLength;
       memcpy(dst, proto.data(), protoLength);
       dst += protoLength;
     }
     total_weight += item.weight;
-    advertised_item.probability = item.weight;
     advertisedNextProtocols_.push_back(advertised_item);
+    advertisedNextProtocolWeights_.push_back(item.weight);
   }
   if (total_weight == 0) {
     deleteNextProtocolsStrings();
     return false;
   }
-  for (auto &advertised_item : advertisedNextProtocols_) {
-    advertised_item.probability /= total_weight;
+  nextProtocolDistribution_ =
+      std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
+                                   advertisedNextProtocolWeights_.end());
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+    SSL_CTX_set_next_protos_advertised_cb(
+        ctx_, advertisedNextProtocolCallback, this);
+    SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+  }
+#if FOLLY_OPENSSL_HAS_ALPN
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+    SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+    // Client cannot really use randomized alpn
+    SSL_CTX_set_alpn_protos(ctx_,
+                            advertisedNextProtocols_[0].protocols,
+                            advertisedNextProtocols_[0].length);
   }
-  SSL_CTX_set_next_protos_advertised_cb(
-    ctx_, advertisedNextProtocolCallback, this);
-  SSL_CTX_set_next_proto_select_cb(
-    ctx_, selectNextProtocolCallback, this);
+#endif
   return true;
 }
 
@@ -373,16 +490,30 @@ void SSLContext::deleteNextProtocolsStrings() {
     delete[] protocols.protocols;
   }
   advertisedNextProtocols_.clear();
+  advertisedNextProtocolWeights_.clear();
 }
 
 void SSLContext::unsetNextProtocols() {
   deleteNextProtocolsStrings();
   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if FOLLY_OPENSSL_HAS_ALPN
+  SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+  SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+  CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
+  auto rng = ThreadLocalPRNG();
+  return size_t(nextProtocolDistribution_(rng));
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       const unsigned char** out, unsigned int* outlen, void* data) {
+  static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
+      0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+
   SSLContext* context = (SSLContext*)data;
   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
     *out = nullptr;
@@ -391,52 +522,52 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
     *out = context->advertisedNextProtocols_[0].protocols;
     *outlen = context->advertisedNextProtocols_[0].length;
   } else {
-    uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
-          sNextProtocolsExDataIndex_));
+    uintptr_t selected_index = reinterpret_cast<uintptr_t>(
+        SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
     if (selected_index) {
       --selected_index;
       *out = context->advertisedNextProtocols_[selected_index].protocols;
       *outlen = context->advertisedNextProtocols_[selected_index].length;
     } else {
-      unsigned char random_byte;
-      RAND_bytes(&random_byte, 1);
-      double random_value = random_byte / 255.0;
-      double sum = 0;
-      for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
-        sum += context->advertisedNextProtocols_[i].probability;
-        if (sum < random_value &&
-            i + 1 < context->advertisedNextProtocols_.size()) {
-          continue;
-        }
-        uintptr_t selected = i + 1;
-        SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
-        *out = context->advertisedNextProtocols_[i].protocols;
-        *outlen = context->advertisedNextProtocols_[i].length;
-        break;
-      }
+      auto i = context->pickNextProtocols();
+      uintptr_t selected = i + 1;
+      SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
+      *out = context->advertisedNextProtocols_[i].protocols;
+      *outlen = context->advertisedNextProtocols_[i].length;
     }
   }
   return SSL_TLSEXT_ERR_OK;
 }
 
-int SSLContext::selectNextProtocolCallback(
-  SSL* ssl, unsigned char **out, unsigned char *outlen,
-  const unsigned char *server, unsigned int server_len, void *data) {
-
+int SSLContext::selectNextProtocolCallback(SSL* ssl,
+                                           unsigned char** out,
+                                           unsigned char* outlen,
+                                           const unsigned char* server,
+                                           unsigned int server_len,
+                                           void* data) {
+  (void)ssl; // Make -Wunused-parameters happy
   SSLContext* ctx = (SSLContext*)data;
   if (ctx->advertisedNextProtocols_.size() > 1) {
     VLOG(3) << "SSLContext::selectNextProcolCallback() "
             << "client should be deterministic in selecting protocols.";
   }
 
-  unsigned char *client;
-  int client_len;
-  if (ctx->advertisedNextProtocols_.empty()) {
-    client = (unsigned char *) "";
-    client_len = 0;
-  } else {
-    client = ctx->advertisedNextProtocols_[0].protocols;
-    client_len = ctx->advertisedNextProtocols_[0].length;
+  unsigned char* client = nullptr;
+  unsigned int client_len = 0;
+  bool filtered = false;
+  auto cpf = ctx->getClientProtocolFilterCallback();
+  if (cpf) {
+    filtered = (*cpf)(&client, &client_len, server, server_len);
+  }
+
+  if (!filtered) {
+    if (ctx->advertisedNextProtocols_.empty()) {
+      client = (unsigned char *) "";
+      client_len = 0;
+    } else {
+      client = ctx->advertisedNextProtocols_[0].protocols;
+      client_len = ctx->advertisedNextProtocols_[0].length;
+    }
   }
 
   int retval = SSL_select_next_proto(out, outlen, server, server_len,
@@ -457,6 +588,14 @@ SSL* SSLContext::createSSL() const {
   return ssl;
 }
 
+void SSLContext::setSessionCacheContext(const std::string& context) {
+  SSL_CTX_set_session_id_context(
+      ctx_,
+      reinterpret_cast<const unsigned char*>(context.data()),
+      std::min<unsigned int>(
+          static_cast<unsigned int>(context.length()), SSL_MAX_SID_CTX_LENGTH));
+}
+
 /**
  * Match a name with a pattern. The pattern may include wildcard. A single
  * wildcard "*" can match up to one component in the domain name.
@@ -501,157 +640,19 @@ int SSLContext::passwordCallback(char* password,
   std::string userPassword;
   // call user defined password collector to get password
   context->passwordCollector()->getPassword(userPassword, size);
-  int length = userPassword.size();
-  if (length > size) {
-    length = size;
-  }
-  strncpy(password, userPassword.c_str(), length);
-  return length;
-}
-
-struct SSLLock {
-  explicit SSLLock(
-    SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
-      lockType(inLockType) {
-  }
-
-  void lock() {
-    if (lockType == SSLContext::LOCK_MUTEX) {
-      mutex.lock();
-    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
-      spinLock.lock();
-    }
-    // lockType == LOCK_NONE, no-op
-  }
-
-  void unlock() {
-    if (lockType == SSLContext::LOCK_MUTEX) {
-      mutex.unlock();
-    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
-      spinLock.unlock();
-    }
-    // lockType == LOCK_NONE, no-op
-  }
-
-  SSLContext::SSLLockType lockType;
-  folly::io::PortableSpinLock spinLock{};
-  std::mutex mutex;
-};
-
-// Statics are unsafe in environments that call exit().
-// If one thread calls exit() while another thread is
-// references a member of SSLContext, bad things can happen.
-// SSLContext runs in such environments.
-// Instead of declaring a static member we "new" the static
-// member so that it won't be destructed on exit().
-static std::map<int, SSLContext::SSLLockType>* lockTypesInst =
-  new std::map<int, SSLContext::SSLLockType>();
-
-static std::unique_ptr<SSLLock[]>* locksInst =
-  new std::unique_ptr<SSLLock[]>();
-
-static std::unique_ptr<SSLLock[]>& locks() {
-  return *locksInst;
+  auto const length = std::min(userPassword.size(), size_t(size));
+  std::memcpy(password, userPassword.data(), length);
+  return int(length);
 }
 
-static std::map<int, SSLContext::SSLLockType>& lockTypes() {
-  return *lockTypesInst;
-}
-
-static void callbackLocking(int mode, int n, const char*, int) {
-  if (mode & CRYPTO_LOCK) {
-    locks()[n].lock();
-  } else {
-    locks()[n].unlock();
-  }
+#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
+void SSLContext::enableFalseStart() {
+  SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
 }
-
-static unsigned long callbackThreadID() {
-  return static_cast<unsigned long>(
-#ifdef __APPLE__
-    pthread_mach_thread_np(pthread_self())
-#else
-    pthread_self()
 #endif
-  );
-}
-
-static CRYPTO_dynlock_value* dyn_create(const char*, int) {
-  return new CRYPTO_dynlock_value;
-}
-
-static void dyn_lock(int mode,
-                     struct CRYPTO_dynlock_value* lock,
-                     const char*, int) {
-  if (lock != nullptr) {
-    if (mode & CRYPTO_LOCK) {
-      lock->mutex.lock();
-    } else {
-      lock->mutex.unlock();
-    }
-  }
-}
-
-static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
-  delete lock;
-}
-
-void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
-  lockTypes() = inLockTypes;
-}
 
 void SSLContext::initializeOpenSSL() {
-  std::lock_guard<std::mutex> g(mutex_);
-  initializeOpenSSLLocked();
-}
-
-void SSLContext::initializeOpenSSLLocked() {
-  if (initialized_) {
-    return;
-  }
-  SSL_library_init();
-  SSL_load_error_strings();
-  ERR_load_crypto_strings();
-  // static locking
-  locks().reset(new SSLLock[::CRYPTO_num_locks()]);
-  for (auto it: lockTypes()) {
-    locks()[it.first].lockType = it.second;
-  }
-  CRYPTO_set_id_callback(callbackThreadID);
-  CRYPTO_set_locking_callback(callbackLocking);
-  // dynamic locking
-  CRYPTO_set_dynlock_create_callback(dyn_create);
-  CRYPTO_set_dynlock_lock_callback(dyn_lock);
-  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
-  randomize();
-#ifdef OPENSSL_NPN_NEGOTIATED
-  sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
-      (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
-#endif
-  initialized_ = true;
-}
-
-void SSLContext::cleanupOpenSSL() {
-  std::lock_guard<std::mutex> g(mutex_);
-  cleanupOpenSSLLocked();
-}
-
-void SSLContext::cleanupOpenSSLLocked() {
-  if (!initialized_) {
-    return;
-  }
-
-  CRYPTO_set_id_callback(nullptr);
-  CRYPTO_set_locking_callback(nullptr);
-  CRYPTO_set_dynlock_create_callback(nullptr);
-  CRYPTO_set_dynlock_lock_callback(nullptr);
-  CRYPTO_set_dynlock_destroy_callback(nullptr);
-  CRYPTO_cleanup_all_ex_data();
-  ERR_free_strings();
-  EVP_cleanup();
-  ERR_remove_state(0);
-  locks().reset();
-  initialized_ = false;
+  folly::ssl::init();
 }
 
 void SSLContext::setOptions(long options) {
@@ -690,4 +691,4 @@ operator<<(std::ostream& os, const PasswordCollector& collector) {
   return os;
 }
 
-} // folly
+} // namespace folly