D2741855 broke my wangle. Reverting
[folly.git] / folly / io / async / SSLContext.cpp
index 34dbc91775f16b353b38259364a7699614f385cb..7ab01c301f8be6d56ed9b35a95a0202330b2773b 100644 (file)
@@ -305,13 +305,43 @@ void SSLContext::switchCiphersIfTLS11(
 }
 #endif
 
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+int SSLContext::alpnSelectCallback(SSL* ssl,
+                                   const unsigned char** out,
+                                   unsigned char* outlen,
+                                   const unsigned char* in,
+                                   unsigned int inlen,
+                                   void* data) {
+  SSLContext* context = (SSLContext*)data;
+  CHECK(context);
+  if (context->advertisedNextProtocols_.empty()) {
+    *out = nullptr;
+    *outlen = 0;
+  } else {
+    auto i = context->pickNextProtocols();
+    const auto& item = context->advertisedNextProtocols_[i];
+    if (SSL_select_next_proto((unsigned char**)out,
+                              outlen,
+                              item.protocols,
+                              item.length,
+                              in,
+                              inlen) != OPENSSL_NPN_NEGOTIATED) {
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+  }
+  return SSL_TLSEXT_ERR_OK;
+}
+#endif
+
 #ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
-  return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+    const std::list<std::string>& protocols, NextProtocolType protocolType) {
+  return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
 }
 
 bool SSLContext::setRandomizedAdvertisedNextProtocols(
-    const std::list<NextProtocolsItem>& items) {
+    const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
   unsetNextProtocols();
   if (items.size() == 0) {
     return false;
@@ -351,13 +381,23 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
     deleteNextProtocolsStrings();
     return false;
   }
-  for (auto &advertised_item : advertisedNextProtocols_) {
+  for (autoadvertised_item : advertisedNextProtocols_) {
     advertised_item.probability /= total_weight;
   }
-  SSL_CTX_set_next_protos_advertised_cb(
-    ctx_, advertisedNextProtocolCallback, this);
-  SSL_CTX_set_next_proto_select_cb(
-    ctx_, selectNextProtocolCallback, this);
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+    SSL_CTX_set_next_protos_advertised_cb(
+        ctx_, advertisedNextProtocolCallback, this);
+    SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+  }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+    SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+    // Client cannot really use randomized alpn
+    SSL_CTX_set_alpn_protos(ctx_,
+                            advertisedNextProtocols_[0].protocols,
+                            advertisedNextProtocols_[0].length);
+  }
+#endif
   return true;
 }
 
@@ -372,6 +412,25 @@ void SSLContext::unsetNextProtocols() {
   deleteNextProtocolsStrings();
   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+  SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+  unsigned char random_byte;
+  RAND_bytes(&random_byte, 1);
+  double random_value = random_byte / 255.0;
+  double sum = 0;
+  for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
+    sum += advertisedNextProtocols_[i].probability;
+    if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
+      continue;
+    }
+    return i;
+  }
+  CHECK(false) << "Failed to pickNextProtocols";
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
@@ -391,22 +450,11 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       *out = context->advertisedNextProtocols_[selected_index].protocols;
       *outlen = context->advertisedNextProtocols_[selected_index].length;
     } else {
-      unsigned char random_byte;
-      RAND_bytes(&random_byte, 1);
-      double random_value = random_byte / 255.0;
-      double sum = 0;
-      for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
-        sum += context->advertisedNextProtocols_[i].probability;
-        if (sum < random_value &&
-            i + 1 < context->advertisedNextProtocols_.size()) {
-          continue;
-        }
-        uintptr_t selected = i + 1;
-        SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
-        *out = context->advertisedNextProtocols_[i].protocols;
-        *outlen = context->advertisedNextProtocols_[i].length;
-        break;
-      }
+      auto i = context->pickNextProtocols();
+      uintptr_t selected = i + 1;
+      SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+      *out = context->advertisedNextProtocols_[i].protocols;
+      *outlen = context->advertisedNextProtocols_[i].length;
     }
   }
   return SSL_TLSEXT_ERR_OK;
@@ -415,46 +463,21 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
 SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
-  /**
-   * The list was generated as follows:
-   * grep "_CK_" openssl-1.0.1e/ssl/s3_lib.c -A 4 |
-   * while read A && read B && read C && read D && read E && read F; do
-   * echo $A $B $C $D $E; done |
-   * grep "\(SSL_kDHr\|SSL_kDHd\|SSL_kEDH\|SSL_kECDHr\|
-   *         SSL_kECDHe\|SSL_kEECDH\)" | grep -v SSL_aNULL | grep SSL_AES |
-   * awk -F, '{ print $1"," }'
-   */
   ciphers_{
-    TLS1_CK_DH_DSS_WITH_AES_128_SHA,
-    TLS1_CK_DH_RSA_WITH_AES_128_SHA,
     TLS1_CK_DHE_DSS_WITH_AES_128_SHA,
     TLS1_CK_DHE_RSA_WITH_AES_128_SHA,
-    TLS1_CK_DH_DSS_WITH_AES_256_SHA,
-    TLS1_CK_DH_RSA_WITH_AES_256_SHA,
     TLS1_CK_DHE_DSS_WITH_AES_256_SHA,
     TLS1_CK_DHE_RSA_WITH_AES_256_SHA,
-    TLS1_CK_DH_DSS_WITH_AES_128_SHA256,
-    TLS1_CK_DH_RSA_WITH_AES_128_SHA256,
     TLS1_CK_DHE_DSS_WITH_AES_128_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_128_SHA256,
-    TLS1_CK_DH_DSS_WITH_AES_256_SHA256,
-    TLS1_CK_DH_RSA_WITH_AES_256_SHA256,
     TLS1_CK_DHE_DSS_WITH_AES_256_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_256_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_128_GCM_SHA256,
     TLS1_CK_DHE_RSA_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_DH_RSA_WITH_AES_128_GCM_SHA256,
-    TLS1_CK_DH_RSA_WITH_AES_256_GCM_SHA384,
     TLS1_CK_DHE_DSS_WITH_AES_128_GCM_SHA256,
     TLS1_CK_DHE_DSS_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_DH_DSS_WITH_AES_128_GCM_SHA256,
-    TLS1_CK_DH_DSS_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
-    TLS1_CK_ECDH_RSA_WITH_AES_128_CBC_SHA,
-    TLS1_CK_ECDH_RSA_WITH_AES_256_CBC_SHA,
     TLS1_CK_ECDHE_RSA_WITH_AES_128_CBC_SHA,
     TLS1_CK_ECDHE_RSA_WITH_AES_256_CBC_SHA,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_SHA256,
@@ -463,15 +486,10 @@ SSLContext::SSLFalseStartChecker::SSLFalseStartChecker() :
     TLS1_CK_ECDH_ECDSA_WITH_AES_256_SHA384,
     TLS1_CK_ECDHE_RSA_WITH_AES_128_SHA256,
     TLS1_CK_ECDHE_RSA_WITH_AES_256_SHA384,
-    TLS1_CK_ECDH_RSA_WITH_AES_128_SHA256,
-    TLS1_CK_ECDH_RSA_WITH_AES_256_SHA384,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
     TLS1_CK_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_128_GCM_SHA256,
-    TLS1_CK_ECDH_ECDSA_WITH_AES_256_GCM_SHA384,
     TLS1_CK_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
     TLS1_CK_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
-    TLS1_CK_ECDH_RSA_WITH_AES_128_GCM_SHA256,
   } {
   length_ = sizeof(ciphers_)/sizeof(ciphers_[0]);
   width_ = sizeof(ciphers_[0]);