Update providedCiphersStr_ in one place.
[folly.git] / folly / io / async / SSLContext.cpp
index 7a0c9993fdc81dfe35085d1c9570fc2770e85c06..5ef22353efd3cde0e8950278bf7268fa23f1d7ec 100644 (file)
 
 #include "SSLContext.h"
 
-#include <openssl/err.h>
-#include <openssl/rand.h>
-#include <openssl/ssl.h>
-#include <openssl/x509v3.h>
-
 #include <folly/Format.h>
 #include <folly/Memory.h>
 #include <folly/Random.h>
+#include <folly/SharedMutex.h>
 #include <folly/SpinLock.h>
+#include <folly/ThreadId.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
@@ -105,7 +102,6 @@ SSLContext::~SSLContext() {
 }
 
 void SSLContext::ciphers(const std::string& ciphers) {
-  providedCiphersString_ = ciphers;
   setCiphersOrThrow(ciphers);
 }
 
@@ -149,16 +145,7 @@ void SSLContext::setClientECCurvesList(
 }
 
 void SSLContext::setServerECCurve(const std::string& curveName) {
-  bool validCall = false;
-#if OPENSSL_VERSION_NUMBER >= 0x0090800fL
-#ifndef OPENSSL_NO_ECDH
-  validCall = true;
-#endif
-#endif
-  if (!validCall) {
-    throw std::runtime_error("Elliptic curve encryption not allowed");
-  }
-
+#if OPENSSL_VERSION_NUMBER >= 0x0090800fL && !defined(OPENSSL_NO_ECDH)
   EC_KEY* ecdh = nullptr;
   int nid;
 
@@ -180,6 +167,9 @@ void SSLContext::setServerECCurve(const std::string& curveName) {
 
   SSL_CTX_set_tmp_ecdh(ctx_, ecdh);
   EC_KEY_free(ecdh);
+#else
+  throw std::runtime_error("Elliptic curve encryption not allowed");
+#endif
 }
 
 void SSLContext::setX509VerifyParam(
@@ -197,6 +187,7 @@ void SSLContext::setCiphersOrThrow(const std::string& ciphers) {
   if (rc == 0) {
     throw std::runtime_error("SSL_CTX_set_cipher_list: " + getErrors());
   }
+  providedCiphersString_ = ciphers;
 }
 
 void SSLContext::setVerificationOption(const SSLContext::SSLVerifyPeerEnum&
@@ -342,6 +333,7 @@ void SSLContext::loadTrustedCertificates(const char* path) {
   if (SSL_CTX_load_verify_locations(ctx_, path, nullptr) == 0) {
     throw std::runtime_error("SSL_CTX_load_verify_locations: " + getErrors());
   }
+  ERR_clear_error();
 }
 
 void SSLContext::loadTrustedCertificates(X509_STORE* store) {
@@ -450,7 +442,7 @@ void SSLContext::switchCiphersIfTLS11(
                  << ", but tls11AltCipherlist is of length "
                  << tls11AltCipherlist.size();
     } else {
-      ciphers = &tls11AltCipherlist[index].first;
+      ciphers = &tls11AltCipherlist[size_t(index)].first;
     }
   }
 
@@ -585,7 +577,7 @@ void SSLContext::unsetNextProtocols() {
 size_t SSLContext::pickNextProtocols() {
   CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
   auto rng = ThreadLocalPRNG();
-  return nextProtocolDistribution_(rng);
+  return size_t(nextProtocolDistribution_(rng));
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
@@ -668,8 +660,9 @@ void SSLContext::setSessionCacheContext(const std::string& context) {
   SSL_CTX_set_session_id_context(
       ctx_,
       reinterpret_cast<const unsigned char*>(context.data()),
-      std::min(
-          static_cast<int>(context.length()), SSL_MAX_SSL_SESSION_ID_LENGTH));
+      std::min<unsigned int>(
+          static_cast<unsigned int>(context.length()),
+          SSL_MAX_SSL_SESSION_ID_LENGTH));
 }
 
 /**
@@ -720,7 +713,7 @@ int SSLContext::passwordCallback(char* password,
   if (length > size) {
     length = size;
   }
-  strncpy(password, userPassword.c_str(), length);
+  strncpy(password, userPassword.c_str(), size_t(length));
   return length;
 }
 
@@ -730,20 +723,32 @@ struct SSLLock {
       lockType(inLockType) {
   }
 
-  void lock() {
+  void lock(bool read) {
     if (lockType == SSLContext::LOCK_MUTEX) {
       mutex.lock();
     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
       spinLock.lock();
+    } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
+      if (read) {
+        sharedMutex.lock_shared();
+      } else {
+        sharedMutex.lock();
+      }
     }
     // lockType == LOCK_NONE, no-op
   }
 
-  void unlock() {
+  void unlock(bool read) {
     if (lockType == SSLContext::LOCK_MUTEX) {
       mutex.unlock();
     } else if (lockType == SSLContext::LOCK_SPINLOCK) {
       spinLock.unlock();
+    } else if (lockType == SSLContext::LOCK_SHAREDMUTEX) {
+      if (read) {
+        sharedMutex.unlock_shared();
+      } else {
+        sharedMutex.unlock();
+      }
     }
     // lockType == LOCK_NONE, no-op
   }
@@ -751,6 +756,7 @@ struct SSLLock {
   SSLContext::SSLLockType lockType;
   folly::SpinLock spinLock{};
   std::mutex mutex;
+  SharedMutex sharedMutex;
 };
 
 // Statics are unsafe in environments that call exit().
@@ -771,22 +777,14 @@ static std::map<int, SSLContext::SSLLockType>& lockTypes() {
 
 static void callbackLocking(int mode, int n, const char*, int) {
   if (mode & CRYPTO_LOCK) {
-    locks()[n].lock();
+    locks()[size_t(n)].lock(mode & CRYPTO_READ);
   } else {
-    locks()[n].unlock();
+    locks()[size_t(n)].unlock(mode & CRYPTO_READ);
   }
 }
 
 static unsigned long callbackThreadID() {
-  return static_cast<unsigned long>(
-#ifdef __APPLE__
-    pthread_mach_thread_np(pthread_self())
-#elif _MSC_VER
-    pthread_getw32threadid_np(pthread_self())
-#else
-    pthread_self()
-#endif
-  );
+  return static_cast<unsigned long>(folly::getCurrentThreadID());
 }
 
 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
@@ -809,10 +807,38 @@ static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
   delete lock;
 }
 
-void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
+void SSLContext::setSSLLockTypesLocked(std::map<int, SSLLockType> inLockTypes) {
   lockTypes() = inLockTypes;
 }
 
+void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
+  std::lock_guard<std::mutex> g(initMutex());
+  if (initialized_) {
+    // We set the locks on initialization, so if we are already initialized
+    // this would have no affect.
+    LOG(INFO) << "Ignoring setSSLLockTypes after initialization";
+    return;
+  }
+  setSSLLockTypesLocked(std::move(inLockTypes));
+}
+
+void SSLContext::setSSLLockTypesAndInitOpenSSL(
+    std::map<int, SSLLockType> inLockTypes) {
+  std::lock_guard<std::mutex> g(initMutex());
+  CHECK(!initialized_) << "OpenSSL is already initialized";
+  setSSLLockTypesLocked(std::move(inLockTypes));
+  initializeOpenSSLLocked();
+}
+
+bool SSLContext::isSSLLockDisabled(int lockId) {
+  std::lock_guard<std::mutex> g(initMutex());
+  CHECK(initialized_) << "OpenSSL is not initialized yet";
+  const auto& sslLocks = lockTypes();
+  const auto it = sslLocks.find(lockId);
+  return it != sslLocks.end() &&
+      it->second == SSLContext::SSLLockType::LOCK_NONE;
+}
+
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH)
 void SSLContext::enableFalseStart() {
   SSL_CTX_set_mode(ctx_, SSL_MODE_HANDSHAKE_CUTTHROUGH);
@@ -837,9 +863,9 @@ void SSLContext::initializeOpenSSLLocked() {
   SSL_load_error_strings();
   ERR_load_crypto_strings();
   // static locking
-  locks().reset(new SSLLock[::CRYPTO_num_locks()]);
+  locks().reset(new SSLLock[size_t(CRYPTO_num_locks())]);
   for (auto it: lockTypes()) {
-    locks()[it.first].lockType = it.second;
+    locks()[size_t(it.first)].lockType = it.second;
   }
   CRYPTO_set_id_callback(callbackThreadID);
   CRYPTO_set_locking_callback(callbackLocking);
@@ -873,7 +899,7 @@ void SSLContext::cleanupOpenSSLLocked() {
   CRYPTO_cleanup_all_ex_data();
   ERR_free_strings();
   EVP_cleanup();
-  ERR_remove_state(0);
+  ERR_clear_error();
   locks().reset();
   initialized_ = false;
 }