From 7e6064c66cf34a5d02e9312d3d90b43e1185d735 Mon Sep 17 00:00:00 2001 From: Alan Frindell Date: Wed, 9 Dec 2015 11:55:40 -0800 Subject: [PATCH] Add support for ALPN Summary: With openssl-1.0.2 and later add support for ALPN. Clients can request NPN only, but the default is to support either (client will send ALPN list, server will send NPN advertisement if ALPN is not negotiated). Reviewed By: siyengar Differential Revision: D2710441 fb-gh-sync-id: a8efe69e1869bbecb4ed9e0a513448fcfdb21ca6 --- folly/io/async/AsyncSSLSocket.cpp | 26 +- folly/io/async/AsyncSSLSocket.h | 20 +- folly/io/async/SSLContext.cpp | 94 ++++-- folly/io/async/SSLContext.h | 51 +++- folly/io/async/test/AsyncSSLSocketTest.cpp | 331 +++++++++++---------- folly/io/async/test/AsyncSSLSocketTest.h | 12 +- folly/io/async/test/MockAsyncSSLSocket.h | 14 +- 7 files changed, 334 insertions(+), 214 deletions(-) diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 166a4039..e72894cd 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -55,6 +55,7 @@ using folly::AsyncSocket; using folly::AsyncSocketException; using folly::AsyncSSLSocket; using folly::Optional; +using folly::SSLContext; // We have one single dummy SSL context so that we can implement attach // and detach methods in a thread safe fashion without modifying opnessl. @@ -765,21 +766,36 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) { } } -void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName, - unsigned* protoLen) const { - if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) { +void AsyncSSLSocket::getSelectedNextProtocol( + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType) const { + if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) { throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED, "NPN not supported"); } } bool AsyncSSLSocket::getSelectedNextProtocolNoThrow( - const unsigned char** protoName, - unsigned* protoLen) const { + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType) const { *protoName = nullptr; *protoLen = 0; +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + SSL_get0_alpn_selected(ssl_, protoName, protoLen); + if (*protoLen > 0) { + if (protoType) { + *protoType = SSLContext::NextProtocolType::ALPN; + } + return true; + } +#endif #ifdef OPENSSL_NPN_NEGOTIATED SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen); + if (protoType) { + *protoType = SSLContext::NextProtocolType::NPN; + } return true; #else return false; diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 5d8a0e6b..ba7485ae 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -376,7 +376,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { /** * Get the name of the protocol selected by the client during - * Next Protocol Negotiation (NPN) + * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation + * (ALPN) * * Throw an exception if openssl does not support NPN * @@ -386,13 +387,17 @@ class AsyncSSLSocket : public virtual AsyncSocket { * Note: the AsyncSSLSocket retains ownership * of this string. * @param protoNameLen Length of the name. + * @param protoType Whether this was an NPN or ALPN negotiation */ - virtual void getSelectedNextProtocol(const unsigned char** protoName, - unsigned* protoLen) const; + virtual void getSelectedNextProtocol( + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType = nullptr) const; /** * Get the name of the protocol selected by the client during - * Next Protocol Negotiation (NPN) + * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation + * (ALPN) * * @param protoName Name of the protocol (not guaranteed to be * null terminated); will be set to nullptr if @@ -400,10 +405,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { * Note: the AsyncSSLSocket retains ownership * of this string. * @param protoNameLen Length of the name. + * @param protoType Whether this was an NPN or ALPN negotiation * @return false if openssl does not support NPN */ - virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName, - unsigned* protoLen) const; + virtual bool getSelectedNextProtocolNoThrow( + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType = nullptr) const; /** * Determine if the session specified during setSSLSession was reused diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index 7426e237..def95ee4 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -305,13 +305,43 @@ void SSLContext::switchCiphersIfTLS11( } #endif +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +int SSLContext::alpnSelectCallback(SSL* ssl, + const unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen, + void* data) { + SSLContext* context = (SSLContext*)data; + CHECK(context); + if (context->advertisedNextProtocols_.empty()) { + *out = nullptr; + *outlen = 0; + } else { + auto i = context->pickNextProtocols(); + const auto& item = context->advertisedNextProtocols_[i]; + if (SSL_select_next_proto((unsigned char**)out, + outlen, + item.protocols, + item.length, + in, + inlen) != OPENSSL_NPN_NEGOTIATED) { + return SSL_TLSEXT_ERR_NOACK; + } + } + return SSL_TLSEXT_ERR_OK; +} +#endif + #ifdef OPENSSL_NPN_NEGOTIATED -bool SSLContext::setAdvertisedNextProtocols(const std::list& protocols) { - return setRandomizedAdvertisedNextProtocols({{1, protocols}}); + +bool SSLContext::setAdvertisedNextProtocols( + const std::list& protocols, NextProtocolType protocolType) { + return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType); } bool SSLContext::setRandomizedAdvertisedNextProtocols( - const std::list& items) { + const std::list& items, NextProtocolType protocolType) { unsetNextProtocols(); if (items.size() == 0) { return false; @@ -354,10 +384,20 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( for (auto &advertised_item : advertisedNextProtocols_) { advertised_item.probability /= total_weight; } - SSL_CTX_set_next_protos_advertised_cb( - ctx_, advertisedNextProtocolCallback, this); - SSL_CTX_set_next_proto_select_cb( - ctx_, selectNextProtocolCallback, this); + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { + SSL_CTX_set_next_protos_advertised_cb( + ctx_, advertisedNextProtocolCallback, this); + SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this); + } +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) { + SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this); + // Client cannot really use randomized alpn + SSL_CTX_set_alpn_protos(ctx_, + advertisedNextProtocols_[0].protocols, + advertisedNextProtocols_[0].length); + } +#endif return true; } @@ -372,6 +412,25 @@ void SSLContext::unsetNextProtocols() { deleteNextProtocolsStrings(); SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr); SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr); +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr); + SSL_CTX_set_alpn_protos(ctx_, nullptr, 0); +#endif +} + +size_t SSLContext::pickNextProtocols() { + unsigned char random_byte; + RAND_bytes(&random_byte, 1); + double random_value = random_byte / 255.0; + double sum = 0; + for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) { + sum += advertisedNextProtocols_[i].probability; + if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) { + continue; + } + return i; + } + CHECK(false) << "Failed to pickNextProtocols"; } int SSLContext::advertisedNextProtocolCallback(SSL* ssl, @@ -391,22 +450,11 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl, *out = context->advertisedNextProtocols_[selected_index].protocols; *outlen = context->advertisedNextProtocols_[selected_index].length; } else { - unsigned char random_byte; - RAND_bytes(&random_byte, 1); - double random_value = random_byte / 255.0; - double sum = 0; - for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) { - sum += context->advertisedNextProtocols_[i].probability; - if (sum < random_value && - i + 1 < context->advertisedNextProtocols_.size()) { - continue; - } - uintptr_t selected = i + 1; - SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected); - *out = context->advertisedNextProtocols_[i].protocols; - *outlen = context->advertisedNextProtocols_[i].length; - break; - } + auto i = context->pickNextProtocols(); + uintptr_t selected = i + 1; + SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected); + *out = context->advertisedNextProtocols_[i].protocols; + *outlen = context->advertisedNextProtocols_[i].length; } } return SSL_TLSEXT_ERR_OK; diff --git a/folly/io/async/SSLContext.h b/folly/io/async/SSLContext.h index 90c96ca6..a4b44b2d 100644 --- a/folly/io/async/SSLContext.h +++ b/folly/io/async/SSLContext.h @@ -290,33 +290,42 @@ class SSLContext { */ void setOptions(long options); + enum class NextProtocolType : uint8_t { + NPN = 0x1, + ALPN = 0x2, + ANY = NPN | ALPN + }; + #ifdef OPENSSL_NPN_NEGOTIATED /** * Set the list of protocols that this SSL context supports. In server * mode, this is the list of protocols that will be advertised for Next - * Protocol Negotiation (NPN). In client mode, the first protocol - * advertised by the server that is also on this list is - * chosen. Invoking this function with a list of length zero causes NPN - * to be disabled. + * Protocol Negotiation (NPN) or Application Layer Protocol Negotiation + * (ALPN). In client mode, the first protocol advertised by the server + * that is also on this list is chosen. Invoking this function with a list + * of length zero causes NPN to be disabled. * * @param protocols List of protocol names. This method makes a copy, * so the caller needn't keep the list in scope after * the call completes. The list must have at least * one element to enable NPN. Each element must have * a string length < 256. - * @return true if NPN has been activated. False if NPN is disabled. + * @param protocolType What type of protocol negotiation to support. + * @return true if NPN/ALPN has been activated. False if NPN/ALPN is disabled. */ - bool setAdvertisedNextProtocols(const std::list& protocols); + bool setAdvertisedNextProtocols( + const std::list& protocols, + NextProtocolType protocolType = NextProtocolType::ANY); /** * Set weighted list of lists of protocols that this SSL context supports. * In server mode, each element of the list contains a list of protocols that - * could be advertised for Next Protocol Negotiation (NPN). The list of - * protocols that will be advertised to a client is selected randomly, based - * on weights of elements. Client mode doesn't support randomized NPN, so - * this list should contain only 1 element. The first protocol advertised - * by the server that is also on the list of protocols of this element is - * chosen. Invoking this function with a list of length zero causes NPN - * to be disabled. + * could be advertised for Next Protocol Negotiation (NPN) or Application + * Layer Protocol Negotiation (ALPN). The list of protocols that will be + * advertised to a client is selected randomly, based on weights of elements. + * Client mode doesn't support randomized NPN/ALPN, so this list should + * contain only 1 element. The first protocol advertised by the server that + * is also on the list of protocols of this element is chosen. Invoking this + * function with a list of length zero causes NPN/ALPN to be disabled. * * @param items List of NextProtocolsItems, Each item contains a list of * protocol names and weight. After the call of this fucntion @@ -326,10 +335,12 @@ class SSLContext { * completes. The list must have at least one element with * non-zero weight and non-empty protocols list to enable NPN. * Each name of the protocol must have a string length < 256. - * @return true if NPN has been activated. False if NPN is disabled. + * @param protocolType What type of protocol negotiation to support. + * @return true if NPN/ALPN has been activated. False if NPN/ALPN is disabled. */ bool setRandomizedAdvertisedNextProtocols( - const std::list& items); + const std::list& items, + NextProtocolType protocolType = NextProtocolType::ANY); void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) { clientProtoFilter_ = cb; @@ -459,6 +470,16 @@ class SSLContext { SSL* ssl, unsigned char **out, unsigned char *outlen, const unsigned char *server, unsigned int server_len, void *args); +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + static int alpnSelectCallback(SSL* ssl, + const unsigned char** out, + unsigned char* outlen, + const unsigned char* in, + unsigned int inlen, + void* data); +#endif + size_t pickNextProtocols(); + #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \ FOLLY_SSLCONTEXT_USE_TLS_FALSE_START // This class contains all allowed ciphers for SSL false start. Call its diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index a3e14886..350a6b57 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -315,191 +315,214 @@ TEST(AsyncSSLSocketTest, SocketWithDelay) { cerr << "SocketWithDelay test completed" << endl; } -TEST(AsyncSSLSocketTest, NpnTestOverlap) { - EventBase eventBase; - std::shared_ptr clientCtx(new SSLContext); - std::shared_ptr serverCtx(new SSLContext);; - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); +using NextProtocolTypePair = + std::pair; - clientCtx->setAdvertisedNextProtocols({"blub","baz"}); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); +class NextProtocolTest : public testing::TestWithParam { + // For matching protos + public: + void SetUp() override { getctx(clientCtx, serverCtx); } - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + void connect(bool unset = false) { + getfds(fds); - eventBase.loop(); + if (unset) { + // unsetting NPN for any of [client, server] is enough to make NPN not + // work + clientCtx->unsetNextProtocols(); + } - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("baz"), 0); -} + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + client = folly::make_unique(std::move(clientSock)); + server = folly::make_unique(std::move(serverSock)); + + eventBase.loop(); + } + + void expectProtocol(const std::string& proto) { + EXPECT_NE(client->nextProtoLength, 0); + EXPECT_EQ(client->nextProtoLength, server->nextProtoLength); + EXPECT_EQ( + memcmp(client->nextProto, server->nextProto, server->nextProtoLength), + 0); + string selected((const char*)client->nextProto, client->nextProtoLength); + EXPECT_EQ(proto, selected); + } + + void expectNoProtocol() { + EXPECT_EQ(client->nextProtoLength, 0); + EXPECT_EQ(server->nextProtoLength, 0); + EXPECT_EQ(client->nextProto, nullptr); + EXPECT_EQ(server->nextProto, nullptr); + } + + void expectProtocolType() { + if (GetParam().first == SSLContext::NextProtocolType::ANY && + GetParam().second == SSLContext::NextProtocolType::ANY) { + EXPECT_EQ(client->protocolType, server->protocolType); + } else if (GetParam().first == SSLContext::NextProtocolType::ANY || + GetParam().second == SSLContext::NextProtocolType::ANY) { + // Well not much we can say + } else { + expectProtocolType(GetParam()); + } + } + + void expectProtocolType(NextProtocolTypePair expected) { + EXPECT_EQ(client->protocolType, expected.first); + EXPECT_EQ(server->protocolType, expected.second); + } -TEST(AsyncSSLSocketTest, NpnTestUnset) { - // Identical to above test, except that we want unset NPN before - // looping. EventBase eventBase; - std::shared_ptr clientCtx(new SSLContext); - std::shared_ptr serverCtx(new SSLContext);; + std::shared_ptr clientCtx{std::make_shared()}; + std::shared_ptr serverCtx{std::make_shared()}; int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); - - clientCtx->setAdvertisedNextProtocols({"blub","baz"}); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + std::unique_ptr client; + std::unique_ptr server; +}; - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); +class NextProtocolNPNOnlyTest : public NextProtocolTest { + // For mismatching protos +}; - // unsetting NPN for any of [client, server] is enought to make NPN not - // work - clientCtx->unsetNextProtocols(); +class NextProtocolMismatchTest : public NextProtocolTest { + // For mismatching protos +}; - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); +TEST_P(NextProtocolTest, NpnTestOverlap) { + clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - eventBase.loop(); + connect(); - EXPECT_TRUE(client.nextProtoLength == 0); - EXPECT_TRUE(server.nextProtoLength == 0); - EXPECT_TRUE(client.nextProto == nullptr); - EXPECT_TRUE(server.nextProto == nullptr); + expectProtocol("baz"); + expectProtocolType(); } -TEST(AsyncSSLSocketTest, NpnTestNoOverlap) { - EventBase eventBase; - std::shared_ptr clientCtx(new SSLContext); - std::shared_ptr serverCtx(new SSLContext);; - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); +TEST_P(NextProtocolTest, NpnTestUnset) { + // Identical to above test, except that we want unset NPN before + // looping. + clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - clientCtx->setAdvertisedNextProtocols({"blub"}); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + connect(true /* unset */); - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + // if alpn negotiation fails, type will appear as npn + expectNoProtocol(); + EXPECT_EQ(client->protocolType, server->protocolType); +} - eventBase.loop(); +TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) { + clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("blub"), 0); + connect(); + + expectNoProtocol(); + expectProtocolType( + {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN}); } -TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) { - EventBase eventBase; - auto clientCtx = std::make_shared(); - auto serverCtx = std::make_shared(); - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); +TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) { + clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - clientCtx->setAdvertisedNextProtocols({"blub"}); - clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + connect(); - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + expectProtocol("blub"); + expectProtocolType(); +} - eventBase.loop(); +TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) { + clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); + clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("ponies"), 0); -} + connect(); -TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) { - EventBase eventBase; - auto clientCtx = std::make_shared(); - auto serverCtx = std::make_shared(); - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); + expectProtocol("ponies"); + expectProtocolType(); +} - clientCtx->setAdvertisedNextProtocols({"blub"}); +TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) { + clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first); clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone); - serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); - - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); + serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().second); - eventBase.loop(); + connect(); - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); - EXPECT_EQ(selected.compare("blub"), 0); + expectProtocol("blub"); + expectProtocolType(); } -TEST(AsyncSSLSocketTest, RandomizedNpnTest) { +TEST_P(NextProtocolTest, RandomizedNpnTest) { // Probability that this test will fail is 2^-64, which could be considered // as negligible. const int kTries = 64; + clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}, + GetParam().first); + serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}}, + GetParam().second); + std::set selectedProtocols; for (int i = 0; i < kTries; ++i) { - EventBase eventBase; - std::shared_ptr clientCtx = std::make_shared(); - std::shared_ptr serverCtx = std::make_shared(); - int fds[2]; - getfds(fds); - getctx(clientCtx, serverCtx); - - clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"}); - serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, - {1, {"bar"}}}); - - - AsyncSSLSocket::UniquePtr clientSock( - new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); - AsyncSSLSocket::UniquePtr serverSock( - new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); - NpnClient client(std::move(clientSock)); - NpnServer server(std::move(serverSock)); - - eventBase.loop(); - - EXPECT_TRUE(client.nextProtoLength != 0); - EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); - EXPECT_EQ(memcmp(client.nextProto, server.nextProto, - server.nextProtoLength), 0); - string selected((const char*)client.nextProto, client.nextProtoLength); + connect(); + + EXPECT_NE(client->nextProtoLength, 0); + EXPECT_EQ(client->nextProtoLength, server->nextProtoLength); + EXPECT_EQ( + memcmp(client->nextProto, server->nextProto, server->nextProtoLength), + 0); + string selected((const char*)client->nextProto, client->nextProtoLength); selectedProtocols.insert(selected); + expectProtocolType(); } EXPECT_EQ(selectedProtocols.size(), 2); } +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolTest, + ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::NPN), +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + NextProtocolTypePair(SSLContext::NextProtocolType::ALPN, + SSLContext::NextProtocolType::ALPN), +#endif + NextProtocolTypePair(SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::ANY), +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) + NextProtocolTypePair(SSLContext::NextProtocolType::ALPN, + SSLContext::NextProtocolType::ANY), +#endif + NextProtocolTypePair(SSLContext::NextProtocolType::ANY, + SSLContext::NextProtocolType::ANY))); + +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolNPNOnlyTest, + ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::NPN))); + +#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT) +INSTANTIATE_TEST_CASE_P( + AsyncSSLSocketTest, + NextProtocolMismatchTest, + ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN, + SSLContext::NextProtocolType::ALPN), + NextProtocolTypePair(SSLContext::NextProtocolType::ALPN, + SSLContext::NextProtocolType::NPN))); +#endif #ifndef OPENSSL_NO_TLSEXT /** @@ -657,8 +680,7 @@ TEST(AsyncSSLSocketTest, SSLClientTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 1)); + auto client = std::make_shared(&eventBase, server.getAddress(), 1); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -684,8 +706,8 @@ TEST(AsyncSSLSocketTest, SSLClientTestReuse) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 10)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 10); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -710,8 +732,8 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 1, 10)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 1, 10); client->connect(true /* write before connect completes */); EventBaseAborter eba(&eventBase, 3000); eventBase.loop(); @@ -741,8 +763,8 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 10, 500)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 10, 500); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -798,8 +820,7 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 2)); + auto client = std::make_shared(&eventBase, server.getAddress(), 2); client->connect(); EventBaseAborter eba(&eventBase, 3000); @@ -828,8 +849,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) { // Set up SSL client EventBase eventBase; - std::shared_ptr client(new SSLClient(&eventBase, server.getAddress(), - 2, 100)); + auto client = + std::make_shared(&eventBase, server.getAddress(), 2, 100); client->connect(); EventBaseAborter eba(&eventBase, 3000); diff --git a/folly/io/async/test/AsyncSSLSocketTest.h b/folly/io/async/test/AsyncSSLSocketTest.h index 63349bc9..683b2a65 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.h +++ b/folly/io/async/test/AsyncSSLSocketTest.h @@ -804,10 +804,12 @@ class NpnClient : const unsigned char* nextProto; unsigned nextProtoLength; + SSLContext::NextProtocolType protocolType; + private: void handshakeSuc(AsyncSSLSocket*) noexcept override { - socket_->getSelectedNextProtocol(&nextProto, - &nextProtoLength); + socket_->getSelectedNextProtocol( + &nextProto, &nextProtoLength, &protocolType); } void handshakeErr( AsyncSSLSocket*, @@ -838,10 +840,12 @@ class NpnServer : const unsigned char* nextProto; unsigned nextProtoLength; + SSLContext::NextProtocolType protocolType; + private: void handshakeSuc(AsyncSSLSocket*) noexcept override { - socket_->getSelectedNextProtocol(&nextProto, - &nextProtoLength); + socket_->getSelectedNextProtocol( + &nextProto, &nextProtoLength, &protocolType); } void handshakeErr( AsyncSSLSocket*, diff --git a/folly/io/async/test/MockAsyncSSLSocket.h b/folly/io/async/test/MockAsyncSSLSocket.h index ff9456c6..6c930d79 100644 --- a/folly/io/async/test/MockAsyncSSLSocket.h +++ b/folly/io/async/test/MockAsyncSSLSocket.h @@ -42,12 +42,14 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { MOCK_CONST_METHOD0(good, bool()); MOCK_CONST_METHOD0(readable, bool()); MOCK_CONST_METHOD0(hangup, bool()); - MOCK_CONST_METHOD2( - getSelectedNextProtocol, - void(const unsigned char**, unsigned*)); - MOCK_CONST_METHOD2( - getSelectedNextProtocolNoThrow, - bool(const unsigned char**, unsigned*)); + MOCK_CONST_METHOD3(getSelectedNextProtocol, + void(const unsigned char**, + unsigned*, + SSLContext::NextProtocolType*)); + MOCK_CONST_METHOD3(getSelectedNextProtocolNoThrow, + bool(const unsigned char**, + unsigned*, + SSLContext::NextProtocolType*)); MOCK_METHOD1(setPeek, void(bool)); MOCK_METHOD1(setReadCB, void(ReadCallback*)); -- 2.34.1