Update SSLContext to use discrete_distribution
authorNeel Goyal <ngoyal@fb.com>
Thu, 10 Dec 2015 00:21:19 +0000 (16:21 -0800)
committerfacebook-github-bot-4 <folly-bot@fb.com>
Thu, 10 Dec 2015 01:20:24 +0000 (17:20 -0800)
Summary: Update the protocol pick logic to use discrete_distribution

Reviewed By: siyengar

Differential Revision: D2741855

fb-gh-sync-id: 244bd087124a7a9584a1108fe8f8150093275878

folly/io/async/SSLContext.cpp
folly/io/async/SSLContext.h

index def95ee41cda6520b1d249cb7d72c8235154f002..4e8ea69f5a4b5d82a1d73b416fa33fbcf7785c51 100644 (file)
@@ -84,6 +84,10 @@ SSLContext::SSLContext(SSLVersion version) {
   SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
   SSL_CTX_set_tlsext_servername_arg(ctx_, this);
 #endif
+
+#ifdef OPENSSL_NPN_NEGOTIATED
+  Random::seed(nextProtocolPicker_);
+#endif
 }
 
 SSLContext::~SSLContext() {
@@ -374,16 +378,16 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
       dst += protoLength;
     }
     total_weight += item.weight;
-    advertised_item.probability = item.weight;
     advertisedNextProtocols_.push_back(advertised_item);
+    advertisedNextProtocolWeights_.push_back(item.weight);
   }
   if (total_weight == 0) {
     deleteNextProtocolsStrings();
     return false;
   }
-  for (auto &advertised_item : advertisedNextProtocols_) {
-    advertised_item.probability /= total_weight;
-  }
+  nextProtocolDistribution_ =
+      std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(),
+                                   advertisedNextProtocolWeights_.end());
   if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
     SSL_CTX_set_next_protos_advertised_cb(
         ctx_, advertisedNextProtocolCallback, this);
@@ -406,6 +410,7 @@ void SSLContext::deleteNextProtocolsStrings() {
     delete[] protocols.protocols;
   }
   advertisedNextProtocols_.clear();
+  advertisedNextProtocolWeights_.clear();
 }
 
 void SSLContext::unsetNextProtocols() {
@@ -419,18 +424,8 @@ void SSLContext::unsetNextProtocols() {
 }
 
 size_t SSLContext::pickNextProtocols() {
-  unsigned char random_byte;
-  RAND_bytes(&random_byte, 1);
-  double random_value = random_byte / 255.0;
-  double sum = 0;
-  for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
-    sum += advertisedNextProtocols_[i].probability;
-    if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
-      continue;
-    }
-    return i;
-  }
-  CHECK(false) << "Failed to pickNextProtocols";
+  CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols";
+  return nextProtocolDistribution_(nextProtocolPicker_);
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
index a4b44b2dbea619c7be12a3ac7484c2444269ce56..e20b093b4007839c95722e47e931f35ce49adac7 100644 (file)
@@ -22,6 +22,7 @@
 #include <vector>
 #include <memory>
 #include <string>
+#include <random>
 
 #include <openssl/ssl.h>
 #include <openssl/tls1.h>
@@ -35,6 +36,8 @@
 #include <folly/folly-config.h>
 #endif
 
+#include <folly/Random.h>
+
 namespace folly {
 
 /**
@@ -87,12 +90,6 @@ class SSLContext {
     std::list<std::string> protocols;
   };
 
-  struct AdvertisedNextProtocolsItem {
-    unsigned char *protocols;
-    unsigned length;
-    double probability;
-  };
-
   // Function that selects a client protocol given the server's list
   using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
                                         const unsigned char*, unsigned int);
@@ -458,10 +455,20 @@ class SSLContext {
   static bool initialized_;
 
 #ifdef OPENSSL_NPN_NEGOTIATED
+
+  struct AdvertisedNextProtocolsItem {
+    unsigned char* protocols;
+    unsigned length;
+  };
+
   /**
    * Wire-format list of advertised protocols for use in NPN.
    */
   std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
+  std::vector<int> advertisedNextProtocolWeights_;
+  std::discrete_distribution<int> nextProtocolDistribution_;
+  Random::DefaultGenerator nextProtocolPicker_;
+
   static int sNextProtocolsExDataIndex_;
 
   static int advertisedNextProtocolCallback(SSL* ssl,