D2741855 broke my wangle. Reverting
[folly.git] / folly / io / async / SSLContext.cpp
index c50e257164924cc7106c605267d8a8b08c441ad5..7ab01c301f8be6d56ed9b35a95a0202330b2773b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2014 Facebook, Inc.
+ * Copyright 2015 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,9 +21,8 @@
 #include <openssl/ssl.h>
 #include <openssl/x509v3.h>
 
-#include <folly/SmallLocks.h>
 #include <folly/Format.h>
-#include <folly/io/PortableSpinLock.h>
+#include <folly/SpinLock.h>
 
 // ---------------------------------------------------------------------
 // SSLContext implementation
@@ -35,24 +34,26 @@ struct CRYPTO_dynlock_value {
 
 namespace folly {
 
-uint64_t SSLContext::count_ = 0;
-std::mutex    SSLContext::mutex_;
+bool SSLContext::initialized_ = false;
+
+namespace {
+
+std::mutex& initMutex() {
+  static std::mutex m;
+  return m;
+}
+
+}  // anonymous namespace
+
 #ifdef OPENSSL_NPN_NEGOTIATED
 int SSLContext::sNextProtocolsExDataIndex_ = -1;
-
 #endif
+
 // SSLContext implementation
 SSLContext::SSLContext(SSLVersion version) {
   {
-    std::lock_guard<std::mutex> g(mutex_);
-    if (!count_++) {
-      initializeOpenSSL();
-      randomize();
-#ifdef OPENSSL_NPN_NEGOTIATED
-      sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
-          (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
-#endif
-    }
+    std::lock_guard<std::mutex> g(initMutex());
+    initializeOpenSSLLocked();
   }
 
   ctx_ = SSL_CTX_new(SSLv23_method());
@@ -94,11 +95,6 @@ SSLContext::~SSLContext() {
 #ifdef OPENSSL_NPN_NEGOTIATED
   deleteNextProtocolsStrings();
 #endif
-
-  std::lock_guard<std::mutex> g(mutex_);
-  if (!--count_) {
-    cleanupOpenSSL();
-  }
 }
 
 void SSLContext::ciphers(const std::string& ciphers) {
@@ -309,13 +305,43 @@ void SSLContext::switchCiphersIfTLS11(
 }
 #endif
 
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+int SSLContext::alpnSelectCallback(SSL* ssl,
+                                   const unsigned char** out,
+                                   unsigned char* outlen,
+                                   const unsigned char* in,
+                                   unsigned int inlen,
+                                   void* data) {
+  SSLContext* context = (SSLContext*)data;
+  CHECK(context);
+  if (context->advertisedNextProtocols_.empty()) {
+    *out = nullptr;
+    *outlen = 0;
+  } else {
+    auto i = context->pickNextProtocols();
+    const auto& item = context->advertisedNextProtocols_[i];
+    if (SSL_select_next_proto((unsigned char**)out,
+                              outlen,
+                              item.protocols,
+                              item.length,
+                              in,
+                              inlen) != OPENSSL_NPN_NEGOTIATED) {
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+  }
+  return SSL_TLSEXT_ERR_OK;
+}
+#endif
+
 #ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
-  return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+    const std::list<std::string>& protocols, NextProtocolType protocolType) {
+  return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
 }
 
 bool SSLContext::setRandomizedAdvertisedNextProtocols(
-    const std::list<NextProtocolsItem>& items) {
+    const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
   unsetNextProtocols();
   if (items.size() == 0) {
     return false;
@@ -355,13 +381,23 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
     deleteNextProtocolsStrings();
     return false;
   }
-  for (auto &advertised_item : advertisedNextProtocols_) {
+  for (autoadvertised_item : advertisedNextProtocols_) {
     advertised_item.probability /= total_weight;
   }
-  SSL_CTX_set_next_protos_advertised_cb(
-    ctx_, advertisedNextProtocolCallback, this);
-  SSL_CTX_set_next_proto_select_cb(
-    ctx_, selectNextProtocolCallback, this);
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+    SSL_CTX_set_next_protos_advertised_cb(
+        ctx_, advertisedNextProtocolCallback, this);
+    SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+  }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+    SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+    // Client cannot really use randomized alpn
+    SSL_CTX_set_alpn_protos(ctx_,
+                            advertisedNextProtocols_[0].protocols,
+                            advertisedNextProtocols_[0].length);
+  }
+#endif
   return true;
 }
 
@@ -376,6 +412,25 @@ void SSLContext::unsetNextProtocols() {
   deleteNextProtocolsStrings();
   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+  SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+  unsigned char random_byte;
+  RAND_bytes(&random_byte, 1);
+  double random_value = random_byte / 255.0;
+  double sum = 0;
+  for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
+    sum += advertisedNextProtocols_[i].probability;
+    if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
+      continue;
+    }
+    return i;
+  }
+  CHECK(false) << "Failed to pickNextProtocols";
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
@@ -395,27 +450,76 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       *out = context->advertisedNextProtocols_[selected_index].protocols;
       *outlen = context->advertisedNextProtocols_[selected_index].length;
     } else {
-      unsigned char random_byte;
-      RAND_bytes(&random_byte, 1);
-      double random_value = random_byte / 255.0;
-      double sum = 0;
-      for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
-        sum += context->advertisedNextProtocols_[i].probability;
-        if (sum < random_value &&
-            i + 1 < context->advertisedNextProtocols_.size()) {
-          continue;
-        }
-        uintptr_t selected = i + 1;
-        SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
-        *out = context->advertisedNextProtocols_[i].protocols;
-        *outlen = context->advertisedNextProtocols_[i].length;
-        break;
-      }
+      auto i = context->pickNextProtocols();
+      uintptr_t selected = i + 1;
+      SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+      *out = context->advertisedNextProtocols_[i].protocols;
+      *outlen = context->advertisedNextProtocols_[i].length;
     }
   }
   return SSL_TLSEXT_ERR_OK;
 }
 
+#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
+  FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
+SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
+  ciphers_{
+    TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
+    TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
+    TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
+    TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
+    TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
+    TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
+    TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
+    TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
+    TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
+    TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
+    TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
+    TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
+    TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+    TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+    TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+    TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+    TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
+    TLS1_CK_ECDHE_ECDSA_WITH_AES_256_SHA384,
+    TLS1_CK_ECDH_ECDSA_WITH_AES_128_SHA256,
+    TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
+    TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
+    TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
+    TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+    TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+    TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+    TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+  } {
+  length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
+  width_ = sizeof(ciphers_[0]);
+  qsort(ciphers_, length_, width_, compare_ulong);
+}
+
+bool SSLContext::SSLFalseStartChecker::canUseFalseStartWithCipher(
+  const SSL_CIPHER *cipher) {
+  unsigned long cid = cipher->id;
+  unsigned long *r =
+    (unsigned long*)bsearch(&cid, ciphers_, length_, width_, compare_ulong);
+  return r != nullptr;
+}
+
+int
+SSLContext::SSLFalseStartChecker::compare_ulong(const void *x, const void *y) {
+  if (*(unsigned long *)x < *(unsigned long *)y) {
+    return -1;
+  }
+  if (*(unsigned long *)x > *(unsigned long *)y) {
+    return 1;
+  }
+  return 0;
+};
+
+bool SSLContext::canUseFalseStartWithCipher(const SSL_CIPHER *cipher) {
+  return falseStartChecker_.canUseFalseStartWithCipher(cipher);
+}
+#endif
+
 int SSLContext::selectNextProtocolCallback(
   SSL* ssl, unsigned char **out, unsigned char *outlen,
   const unsigned char *server, unsigned int server_len, void *data) {
@@ -427,13 +531,21 @@ int SSLContext::selectNextProtocolCallback(
   }
 
   unsigned char *client;
-  int client_len;
-  if (ctx->advertisedNextProtocols_.empty()) {
-    client = (unsigned char *) "";
-    client_len = 0;
-  } else {
-    client = ctx->advertisedNextProtocols_[0].protocols;
-    client_len = ctx->advertisedNextProtocols_[0].length;
+  unsigned int client_len;
+  bool filtered = false;
+  auto cpf = ctx->getClientProtocolFilterCallback();
+  if (cpf) {
+    filtered = (*cpf)(&client, &client_len, server, server_len);
+  }
+
+  if (!filtered) {
+    if (ctx->advertisedNextProtocols_.empty()) {
+      client = (unsigned char *) "";
+      client_len = 0;
+    } else {
+      client = ctx->advertisedNextProtocols_[0].protocols;
+      client_len = ctx->advertisedNextProtocols_[0].length;
+    }
   }
 
   int retval = SSL_select_next_proto(out, outlen, server, server_len,
@@ -441,6 +553,14 @@ int SSLContext::selectNextProtocolCallback(
   if (retval != OPENSSL_NPN_NEGOTIATED) {
     VLOG(3) << "SSLContext::selectNextProcolCallback() "
             << "unable to pick a next protocol.";
+#if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
+  FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
+  } else {
+    const SSL_CIPHER *cipher = ssl->s3->tmp.new_cipher;
+    if (cipher && ctx->canUseFalseStartWithCipher(cipher)) {
+      SSL_set_mode(ssl, SSL_MODE_HANDSHAKE_CUTTHROUGH);
+    }
+#endif
   }
   return SSL_TLSEXT_ERR_OK;
 }
@@ -531,23 +651,42 @@ struct SSLLock {
   }
 
   SSLContext::SSLLockType lockType;
-  folly::io::PortableSpinLock spinLock{};
+  folly::SpinLock spinLock{};
   std::mutex mutex;
 };
 
-static std::map<int, SSLContext::SSLLockType> lockTypes;
-static std::unique_ptr<SSLLock[]> locks;
+// Statics are unsafe in environments that call exit().
+// If one thread calls exit() while another thread is
+// references a member of SSLContext, bad things can happen.
+// SSLContext runs in such environments.
+// Instead of declaring a static member we "new" the static
+// member so that it won't be destructed on exit().
+static std::unique_ptr<SSLLock[]>& locks() {
+  static auto locksInst = new std::unique_ptr<SSLLock[]>();
+  return *locksInst;
+}
+
+static std::map<int, SSLContext::SSLLockType>& lockTypes() {
+  static auto lockTypesInst = new std::map<int, SSLContext::SSLLockType>();
+  return *lockTypesInst;
+}
 
 static void callbackLocking(int mode, int n, const char*, int) {
   if (mode & CRYPTO_LOCK) {
-    locks[n].lock();
+    locks()[n].lock();
   } else {
-    locks[n].unlock();
+    locks()[n].unlock();
   }
 }
 
 static unsigned long callbackThreadID() {
-  return static_cast<unsigned long>(pthread_self());
+  return static_cast<unsigned long>(
+#ifdef __APPLE__
+    pthread_mach_thread_np(pthread_self())
+#else
+    pthread_self()
+#endif
+  );
 }
 
 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
@@ -571,17 +710,30 @@ static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
 }
 
 void SSLContext::setSSLLockTypes(std::map<int, SSLLockType> inLockTypes) {
-  lockTypes = inLockTypes;
+  lockTypes() = inLockTypes;
+}
+
+void SSLContext::markInitialized() {
+  std::lock_guard<std::mutex> g(initMutex());
+  initialized_ = true;
 }
 
 void SSLContext::initializeOpenSSL() {
+  std::lock_guard<std::mutex> g(initMutex());
+  initializeOpenSSLLocked();
+}
+
+void SSLContext::initializeOpenSSLLocked() {
+  if (initialized_) {
+    return;
+  }
   SSL_library_init();
   SSL_load_error_strings();
   ERR_load_crypto_strings();
   // static locking
-  locks.reset(new SSLLock[::CRYPTO_num_locks()]);
-  for (auto it: lockTypes) {
-    locks[it.first].lockType = it.second;
+  locks().reset(new SSLLock[::CRYPTO_num_locks()]);
+  for (auto it: lockTypes()) {
+    locks()[it.first].lockType = it.second;
   }
   CRYPTO_set_id_callback(callbackThreadID);
   CRYPTO_set_locking_callback(callbackLocking);
@@ -589,9 +741,24 @@ void SSLContext::initializeOpenSSL() {
   CRYPTO_set_dynlock_create_callback(dyn_create);
   CRYPTO_set_dynlock_lock_callback(dyn_lock);
   CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
+  randomize();
+#ifdef OPENSSL_NPN_NEGOTIATED
+  sNextProtocolsExDataIndex_ = SSL_get_ex_new_index(0,
+      (void*)"Advertised next protocol index", nullptr, nullptr, nullptr);
+#endif
+  initialized_ = true;
 }
 
 void SSLContext::cleanupOpenSSL() {
+  std::lock_guard<std::mutex> g(initMutex());
+  cleanupOpenSSLLocked();
+}
+
+void SSLContext::cleanupOpenSSLLocked() {
+  if (!initialized_) {
+    return;
+  }
+
   CRYPTO_set_id_callback(nullptr);
   CRYPTO_set_locking_callback(nullptr);
   CRYPTO_set_dynlock_create_callback(nullptr);
@@ -601,7 +768,8 @@ void SSLContext::cleanupOpenSSL() {
   ERR_free_strings();
   EVP_cleanup();
   ERR_remove_state(0);
-  locks.reset();
+  locks().reset();
+  initialized_ = false;
 }
 
 void SSLContext::setOptions(long options) {
@@ -640,4 +808,83 @@ operator<<(std::ostream& os, const PasswordCollector& collector) {
   return os;
 }
 
+bool OpenSSLUtils::getPeerAddressFromX509StoreCtx(X509_STORE_CTX* ctx,
+                                                  sockaddr_storage* addrStorage,
+                                                  socklen_t* addrLen) {
+  // Grab the ssl idx and then the ssl object so that we can get the peer
+  // name to compare against the ips in the subjectAltName
+  auto sslIdx = SSL_get_ex_data_X509_STORE_CTX_idx();
+  auto ssl =
+    reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data(ctx, sslIdx));
+  int fd = SSL_get_fd(ssl);
+  if (fd < 0) {
+    LOG(ERROR) << "Inexplicably couldn't get fd from SSL";
+    return false;
+  }
+
+  *addrLen = sizeof(*addrStorage);
+  if (getpeername(fd, reinterpret_cast<sockaddr*>(addrStorage), addrLen) != 0) {
+    PLOG(ERROR) << "Unable to get peer name";
+    return false;
+  }
+  CHECK(*addrLen <= sizeof(*addrStorage));
+  return true;
+}
+
+bool OpenSSLUtils::validatePeerCertNames(X509* cert,
+                                         const sockaddr* addr,
+                                         socklen_t addrLen) {
+  // Try to extract the names within the SAN extension from the certificate
+  auto altNames =
+    reinterpret_cast<STACK_OF(GENERAL_NAME)*>(
+        X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr));
+  SCOPE_EXIT {
+    if (altNames != nullptr) {
+      sk_GENERAL_NAME_pop_free(altNames, GENERAL_NAME_free);
+    }
+  };
+  if (altNames == nullptr) {
+    LOG(WARNING) << "No subjectAltName provided and we only support ip auth";
+    return false;
+  }
+
+  const sockaddr_in* addr4 = nullptr;
+  const sockaddr_in6* addr6 = nullptr;
+  if (addr != nullptr) {
+    if (addr->sa_family == AF_INET) {
+      addr4 = reinterpret_cast<const sockaddr_in*>(addr);
+    } else if (addr->sa_family == AF_INET6) {
+      addr6 = reinterpret_cast<const sockaddr_in6*>(addr);
+    } else {
+      LOG(FATAL) << "Unsupported sockaddr family: " << addr->sa_family;
+    }
+  }
+
+
+  for (int i = 0; i < sk_GENERAL_NAME_num(altNames); i++) {
+    auto name = sk_GENERAL_NAME_value(altNames, i);
+    if ((addr4 != nullptr || addr6 != nullptr) && name->type == GEN_IPADD) {
+      // Extra const-ness for paranoia
+      unsigned char const * const rawIpStr = name->d.iPAddress->data;
+      int const rawIpLen = name->d.iPAddress->length;
+
+      if (rawIpLen == 4 && addr4 != nullptr) {
+        if (::memcmp(rawIpStr, &addr4->sin_addr, rawIpLen) == 0) {
+          return true;
+        }
+      } else if (rawIpLen == 16 && addr6 != nullptr) {
+        if (::memcmp(rawIpStr, &addr6->sin6_addr, rawIpLen) == 0) {
+          return true;
+        }
+      } else if (rawIpLen != 4 && rawIpLen != 16) {
+        LOG(WARNING) << "Unexpected IP length: " << rawIpLen;
+      }
+    }
+  }
+
+  LOG(WARNING) << "Unable to match client cert against alt name ip";
+  return false;
+}
+
+
 } // folly