From: Ranjeeth Dasineni Date: Wed, 30 Sep 2015 00:57:06 +0000 (-0700) Subject: add callback to specify a client next protocol filter X-Git-Tag: deprecate-dynamic-initializer~370 X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=commitdiff_plain;h=14a19db224da08445efa8f4cba32b86e004689df add callback to specify a client next protocol filter Summary: From the client perspective, we set the list in order of preference once and call into openssl to do the selection. This adds a little more flexibility in that client optionally can customize the selection for each negotiation. added tests for the no-op case and the customized case. Feel free to suggest improvements. Reviewed By: @afrind Differential Revision: D2489142 --- diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index cfd25c7f..34dbc917 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -513,13 +513,21 @@ int SSLContext::selectNextProtocolCallback( } unsigned char *client; - int client_len; - if (ctx->advertisedNextProtocols_.empty()) { - client = (unsigned char *) ""; - client_len = 0; - } else { - client = ctx->advertisedNextProtocols_[0].protocols; - client_len = ctx->advertisedNextProtocols_[0].length; + unsigned int client_len; + bool filtered = false; + auto cpf = ctx->getClientProtocolFilterCallback(); + if (cpf) { + filtered = (*cpf)(&client, &client_len, server, server_len); + } + + if (!filtered) { + if (ctx->advertisedNextProtocols_.empty()) { + client = (unsigned char *) ""; + client_len = 0; + } else { + client = ctx->advertisedNextProtocols_[0].protocols; + client_len = ctx->advertisedNextProtocols_[0].length; + } } int retval = SSL_select_next_proto(out, outlen, server, server_len, diff --git a/folly/io/async/SSLContext.h b/folly/io/async/SSLContext.h index 7585c168..90c96ca6 100644 --- a/folly/io/async/SSLContext.h +++ b/folly/io/async/SSLContext.h @@ -93,6 +93,10 @@ class SSLContext { double probability; }; + // Function that selects a client protocol given the server's list + using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*, + const unsigned char*, unsigned int); + /** * Convenience function to call getErrors() with the current errno value. * @@ -327,6 +331,13 @@ class SSLContext { bool setRandomizedAdvertisedNextProtocols( const std::list& items); + void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) { + clientProtoFilter_ = cb; + } + + ClientProtocolFilterCallback getClientProtocolFilterCallback() { + return clientProtoFilter_; + } /** * Disables NPN on this SSL context. */ @@ -431,6 +442,8 @@ class SSLContext { std::vector clientHelloCbs_; #endif + ClientProtocolFilterCallback clientProtoFilter_{nullptr}; + static bool initialized_; #ifdef OPENSSL_NPN_NEGOTIATED diff --git a/folly/io/async/test/AsyncSSLSocketTest.cpp b/folly/io/async/test/AsyncSSLSocketTest.cpp index 4eb2fef2..f9624f05 100644 --- a/folly/io/async/test/AsyncSSLSocketTest.cpp +++ b/folly/io/async/test/AsyncSSLSocketTest.cpp @@ -127,6 +127,21 @@ void sslsocketpair( // (*serverSock)->setSendTimeout(100); } +// client protocol filters +bool clientProtoFilterPickPony(unsigned char** client, + unsigned int* client_len, const unsigned char*, unsigned int ) { + //the protocol string in length prefixed byte string. the + //length byte is not included in the length + static unsigned char p[7] = {6,'p','o','n','i','e','s'}; + *client = p; + *client_len = 7; + return true; +} + +bool clientProtoFilterPickNone(unsigned char**, unsigned int*, + const unsigned char*, unsigned int) { + return false; +} /** * Test connecting to, writing to, reading from, and closing the @@ -387,6 +402,64 @@ TEST(AsyncSSLSocketTest, NpnTestNoOverlap) { EXPECT_EQ(selected.compare("blub"), 0); } +TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + + clientCtx->setAdvertisedNextProtocols({"blub"}); + clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony); + serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + NpnClient client(std::move(clientSock)); + NpnServer server(std::move(serverSock)); + + eventBase.loop(); + + EXPECT_TRUE(client.nextProtoLength != 0); + EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); + EXPECT_EQ(memcmp(client.nextProto, server.nextProto, + server.nextProtoLength), 0); + string selected((const char*)client.nextProto, client.nextProtoLength); + EXPECT_EQ(selected.compare("ponies"), 0); +} + +TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) { + EventBase eventBase; + auto clientCtx = std::make_shared(); + auto serverCtx = std::make_shared(); + int fds[2]; + getfds(fds); + getctx(clientCtx, serverCtx); + + clientCtx->setAdvertisedNextProtocols({"blub"}); + clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone); + serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"}); + + AsyncSSLSocket::UniquePtr clientSock( + new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false)); + AsyncSSLSocket::UniquePtr serverSock( + new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); + NpnClient client(std::move(clientSock)); + NpnServer server(std::move(serverSock)); + + eventBase.loop(); + + EXPECT_TRUE(client.nextProtoLength != 0); + EXPECT_EQ(client.nextProtoLength, server.nextProtoLength); + EXPECT_EQ(memcmp(client.nextProto, server.nextProto, + server.nextProtoLength), 0); + string selected((const char*)client.nextProto, client.nextProtoLength); + EXPECT_EQ(selected.compare("blub"), 0); +} + TEST(AsyncSSLSocketTest, RandomizedNpnTest) { // Probability that this test will fail is 2^-64, which could be considered // as negligible.