Move OpenSSL locking code out of SSLContext
[folly.git] / folly / io / async / SSLContext.cpp
index 1ccd7308ade0f2839f50006e8d2e9ffbe7313c54..95ae99a12edff49cdaa921f2cf243f769b0503b8 100644 (file)
 #include <folly/SharedMutex.h>
 #include <folly/SpinLock.h>
 #include <folly/ThreadId.h>
+#include <folly/ssl/Init.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
 // ---------------------------------------------------------------------
-
-struct CRYPTO_dynlock_value {
-  std::mutex mutex;
-};
-
 namespace folly {
 //
 // For OpenSSL portability API
 using namespace folly::ssl;
 
-bool SSLContext::initialized_ = false;
-
-namespace {
-
-std::mutex& initMutex() {
-  static std::mutex m;
-  return m;
-}
-
-} // anonymous namespace
-
-#ifdef OPENSSL_NPN_NEGOTIATED
-int SSLContext::sNextProtocolsExDataIndex_ = -1;
-#endif
-
 // SSLContext implementation
 SSLContext::SSLContext(SSLVersion version) {
-  {
-    std::lock_guard<std::mutex> g(initMutex());
-    initializeOpenSSLLocked();
-  }
+  folly::ssl::init();
 
   ctx_ = SSL_CTX_new(SSLv23_method());
   if (ctx_ == nullptr) {
@@ -352,10 +330,6 @@ void SSLContext::loadClientCAList(const char* path) {
   SSL_CTX_set_client_CA_list(ctx_, clientCAs);
 }
 
-void SSLContext::randomize() {
-  RAND_poll();
-}
-
 void SSLContext::passwordCollector(
     std::shared_ptr<PasswordCollector> collector) {
   if (collector == nullptr) {
@@ -586,6 +560,9 @@ size_t SSLContext::pickNextProtocols() {
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       const unsigned char** out, unsigned int* outlen, void* data) {
+  static int nextProtocolsExDataIndex = SSL_get_ex_new_index(
+      0, (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+
   SSLContext* context = (SSLContext*)data;
   if (context == nullptr || context->advertisedNextProtocols_.empty()) {
     *out = nullptr;
@@ -594,8 +571,8 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
     *out = context->advertisedNextProtocols_[0].protocols;
     *outlen = context->advertisedNextProtocols_[0].length;
   } else {
-    uintptr_t selected_index = reinterpret_cast<uintptr_t>(SSL_get_ex_data(ssl,
-          sNextProtocolsExDataIndex_));
+    uintptr_t selected_index = reinterpret_cast<uintptr_t>(
+        SSL_get_ex_data(ssl, nextProtocolsExDataIndex));
     if (selected_index) {
       --selected_index;
       *out = context->advertisedNextProtocols_[selected_index].protocols;
@@ -603,7 +580,7 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
     } else {
       auto i = context->pickNextProtocols();
       uintptr_t selected = i + 1;
-      SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+      SSL_set_ex_data(ssl, nextProtocolsExDataIndex, (void*)selected);
       *out = context->advertisedNextProtocols_[i].protocols;
       *outlen = context->advertisedNextProtocols_[i].length;
     }
@@ -721,126 +698,8 @@ int SSLContext::passwordCallback(char* password,
   return length;
 }
 
-struct SSLLock {
-  explicit SSLLock(
-    SSLContext::SSLLockType inLockType = SSLContext::LOCK_MUTEX) :
-      lockType(inLockType) {
-  }
-
-  void lock(bool read) {
-    if (lockType == SSLContext::LOCK_MUTEX) {
-      mutex.lock();
-    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
-      spinLock.lock();
-    } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
-      if (read) {
-        sharedMutex.lock_shared();
-      } else {
-        sharedMutex.lock();
-      }
-    }
-    // lockType == LOCK_NONE, no-op
-  }
-
-  void unlock(bool read) {
-    if (lockType == SSLContext::LOCK_MUTEX) {
-      mutex.unlock();
-    } else if (lockType == SSLContext::LOCK_SPINLOCK) {
-      spinLock.unlock();
-    } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
-      if (read) {
-        sharedMutex.unlock_shared();
-      } else {
-        sharedMutex.unlock();
-      }
-    }
-    // lockType == LOCK_NONE, no-op
-  }
-
-  SSLContext::SSLLockType lockType;
-  folly::SpinLock spinLock{};
-  std::mutex mutex;
-  SharedMutex sharedMutex;
-};
-
-// Statics are unsafe in environments that call exit().
-// If one thread calls exit() while another thread is
-// references a member of SSLContext, bad things can happen.
-// SSLContext runs in such environments.
-// Instead of declaring a static member we "new" the static
-// member so that it won't be destructed on exit().
-static std::unique_ptr<SSLLock[]>& locks() {
-  static auto locksInst = new std::unique_ptr<SSLLock[]>();
-  return *locksInst;
-}
-
-static std::map<int, SSLContext::SSLLockType>& lockTypes() {
-  static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
-  return *lockTypesInst;
-}
-
-static void callbackLocking(int mode, int n, const char*, int) {
-  if (mode & CRYPTO_LOCK) {
-    locks()[size_t(n)].lock(mode & CRYPTO_READ);
-  } else {
-    locks()[size_t(n)].unlock(mode & CRYPTO_READ);
-  }
-}
-
-static unsigned long callbackThreadID() {
-  return static_cast<unsigned long>(folly::getCurrentThreadID());
-}
-
-static CRYPTO_dynlock_value* dyn_create(const char*, int) {
-  return new CRYPTO_dynlock_value;
-}
-
-static void dyn_lock(int mode,
-                     struct CRYPTO_dynlock_value* lock,
-                     const char*, int) {
-  if (lock != nullptr) {
-    if (mode & CRYPTO_LOCK) {
-      lock->mutex.lock();
-    } else {
-      lock->mutex.unlock();
-    }
-  }
-}
-
-static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
-  delete lock;
-}
-
-void SSLContext::setSSLLockTypesLocked(std::map<int, SSLLockType> inLockTypes) {
-  lockTypes() = inLockTypes;
-}
-
-void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
-  std::lock_guard<std::mutex> g(initMutex());
-  if (initialized_) {
-    // We set the locks on initialization, so if we are already initialized
-    // this would have no affect.
-    LOG(INFO) << "Ignoring setSSLLockTypes after initialization";
-    return;
-  }
-  setSSLLockTypesLocked(std::move(inLockTypes));
-}
-
-void SSLContext::setSSLLockTypesAndInitOpenSSL(
-    std::map<int, SSLLockType> inLockTypes) {
-  std::lock_guard<std::mutex> g(initMutex());
-  CHECK(!initialized_) << "OpenSSL is already initialized";
-  setSSLLockTypesLocked(std::move(inLockTypes));
-  initializeOpenSSLLocked();
-}
-
-bool SSLContext::isSSLLockDisabled(int lockId) {
-  std::lock_guard<std::mutex> g(initMutex());
-  CHECK(initialized_) << "OpenSSL is not initialized yet";
-  const auto& sslLocks = lockTypes();
-  const auto it = sslLocks.find(lockId);
-  return it != sslLocks.end() &&
-      it->second == SSLContext::SSLLockType::LOCK_NONE;
+void SSLContext::setSSLLockTypes(std::map<int, LockType> inLockTypes) {
+  folly::ssl::setLockTypes(inLockTypes);
 }
 
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
@@ -849,63 +708,8 @@ void SSLContext::enableFalseStart() {
 }
 #endif
 
-void SSLContext::markInitialized() {
-  std::lock_guard<std::mutex> g(initMutex());
-  initialized_ = true;
-}
-
 void SSLContext::initializeOpenSSL() {
-  std::lock_guard<std::mutex> g(initMutex());
-  initializeOpenSSLLocked();
-}
-
-void SSLContext::initializeOpenSSLLocked() {
-  if (initialized_) {
-    return;
-  }
-  SSL_library_init();
-  SSL_load_error_strings();
-  ERR_load_crypto_strings();
-  // static locking
-  locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
-  for (auto it: lockTypes()) {
-    locks()[size_t(it.first)].lockType = it.second;
-  }
-  CRYPTO_set_id_callback(callbackThreadID);
-  CRYPTO_set_locking_callback(callbackLocking);
-  // dynamic locking
-  CRYPTO_set_dynlock_create_callback(dyn_create);
-  CRYPTO_set_dynlock_lock_callback(dyn_lock);
-  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
-  randomize();
-#ifdef OPENSSL_NPN_NEGOTIATED
-  sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
-      (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
-#endif
-  initialized_ = true;
-}
-
-void SSLContext::cleanupOpenSSL() {
-  std::lock_guard<std::mutex> g(initMutex());
-  cleanupOpenSSLLocked();
-}
-
-void SSLContext::cleanupOpenSSLLocked() {
-  if (!initialized_) {
-    return;
-  }
-
-  CRYPTO_set_id_callback(nullptr);
-  CRYPTO_set_locking_callback(nullptr);
-  CRYPTO_set_dynlock_create_callback(nullptr);
-  CRYPTO_set_dynlock_lock_callback(nullptr);
-  CRYPTO_set_dynlock_destroy_callback(nullptr);
-  CRYPTO_cleanup_all_ex_data();
-  ERR_free_strings();
-  EVP_cleanup();
-  ERR_clear_error();
-  locks().reset();
-  initialized_ = false;
+  folly::ssl::init();
 }
 
 void SSLContext::setOptions(long options) {