X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FSSLContext.cpp;h=7ab01c301f8be6d56ed9b35a95a0202330b2773b;hp=4e8ea69f5a4b5d82a1d73b416fa33fbcf7785c51;hb=17d04308e64ee7a11ad68f4b4b4c03498c3c8844;hpb=97c7b417342e8c941aedfaf811fab0332718cd01 diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 4e8ea69f..7ab01c30 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -84,10 +84,6 @@ SSLContext::SSLContext(SSLVersion version) { SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_arg(ctx_, this); #endif - -#ifdef OPENSSL_NPN_NEGOTIATED - Random::seed(nextProtocolPicker_); -#endif } SSLContext::~SSLContext() { @@ -378,16 +374,16 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( dst += protoLength; } total_weight += item.weight; + advertised_item.probability = item.weight; advertisedNextProtocols_.push_back(advertised_item); - advertisedNextProtocolWeights_.push_back(item.weight); } if (total_weight == 0) { deleteNextProtocolsStrings(); return false; } - nextProtocolDistribution_ = - std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(), - advertisedNextProtocolWeights_.end()); + for (auto& advertised_item : advertisedNextProtocols_) { + advertised_item.probability /= total_weight; + } if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { SSL_CTX_set_next_protos_advertised_cb( ctx_, advertisedNextProtocolCallback, this); @@ -410,7 +406,6 @@ void SSLContext::deleteNextProtocolsStrings() { delete[] protocols.protocols; } advertisedNextProtocols_.clear(); - advertisedNextProtocolWeights_.clear(); } void SSLContext::unsetNextProtocols() { @@ -424,8 +419,18 @@ void SSLContext::unsetNextProtocols() { } size_t SSLContext::pickNextProtocols() { - CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols"; - return nextProtocolDistribution_(nextProtocolPicker_); + unsigned char random_byte; + RAND_bytes(&random_byte, 1); + double random_value = random_byte / 255.0; + double sum = 0; + for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) { + sum += advertisedNextProtocols_[i].probability; + if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) { + continue; + } + return i; + } + CHECK(false) << "Failed to pickNextProtocols"; } int SSLContext::advertisedNextProtocolCallback(SSL* ssl,