add callback to specify a client next protocol filter
authorRanjeeth Dasineni <ranjeeth@fb.com>
Wed, 30 Sep 2015 00:57:06 +0000 (17:57 -0700)
committerfacebook-github-bot-9 <folly-bot@fb.com>
Wed, 30 Sep 2015 01:20:19 +0000 (18:20 -0700)
Summary: From the client perspective, we set the list in order of
preference once and call into openssl to do the selection. This adds
a little more flexibility in that client optionally can customize the
selection for each negotiation. added tests for the no-op case and the
customized case. Feel free to suggest improvements.

Reviewed By: @afrind

Differential Revision: D2489142

folly/io/async/SSLContext.cpp
folly/io/async/SSLContext.h
folly/io/async/test/AsyncSSLSocketTest.cpp

index cfd25c7f7c55341b944ec7344afdcf4b7f2a8801..34dbc91775f16b353b38259364a7699614f385cb 100644 (file)
@@ -513,13 +513,21 @@ int SSLContext::selectNextProtocolCallback(
   }
 
   unsigned char *client;
-  int client_len;
-  if (ctx->advertisedNextProtocols_.empty()) {
-    client = (unsigned char *) "";
-    client_len = 0;
-  } else {
-    client = ctx->advertisedNextProtocols_[0].protocols;
-    client_len = ctx->advertisedNextProtocols_[0].length;
+  unsigned int client_len;
+  bool filtered = false;
+  auto cpf = ctx->getClientProtocolFilterCallback();
+  if (cpf) {
+    filtered = (*cpf)(&client, &client_len, server, server_len);
+  }
+
+  if (!filtered) {
+    if (ctx->advertisedNextProtocols_.empty()) {
+      client = (unsigned char *) "";
+      client_len = 0;
+    } else {
+      client = ctx->advertisedNextProtocols_[0].protocols;
+      client_len = ctx->advertisedNextProtocols_[0].length;
+    }
   }
 
   int retval = SSL_select_next_proto(out, outlen, server, server_len,
index 7585c168b8a285cda5cc79b22e8ca31cc3492794..90c96ca681e63e0f9fd2518728373b766ea47bbe 100644 (file)
@@ -93,6 +93,10 @@ class SSLContext {
     double probability;
   };
 
+  // Function that selects a client protocol given the server's list
+  using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
+                                        const unsigned char*, unsigned int);
+
   /**
    * Convenience function to call getErrors() with the current errno value.
    *
@@ -327,6 +331,13 @@ class SSLContext {
   bool setRandomizedAdvertisedNextProtocols(
       const std::list<NextProtocolsItem>& items);
 
+  void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) {
+    clientProtoFilter_ = cb;
+  }
+
+  ClientProtocolFilterCallback getClientProtocolFilterCallback() {
+    return clientProtoFilter_;
+  }
   /**
    * Disables NPN on this SSL context.
    */
@@ -431,6 +442,8 @@ class SSLContext {
   std::vector<ClientHelloCallback> clientHelloCbs_;
 #endif
 
+  ClientProtocolFilterCallback clientProtoFilter_{nullptr};
+
   static bool initialized_;
 
 #ifdef OPENSSL_NPN_NEGOTIATED
index 4eb2fef222583e1c65b8160d11825a6351e3fe95..f9624f058189456593539faee614c32046ce5473 100644 (file)
@@ -127,6 +127,21 @@ void sslsocketpair(
   // (*serverSock)->setSendTimeout(100);
 }
 
+// client protocol filters
+bool clientProtoFilterPickPony(unsigned char** client,
+  unsigned int* client_len, const unsigned char*, unsigned int ) {
+  //the protocol string in length prefixed byte string. the
+  //length byte is not included in the length
+  static unsigned char p[7] = {6,'p','o','n','i','e','s'};
+  *client = p;
+  *client_len = 7;
+  return true;
+}
+
+bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
+  const unsigned char*, unsigned int) {
+  return false;
+}
 
 /**
  * Test connecting to, writing to, reading from, and closing the
@@ -387,6 +402,64 @@ TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
   EXPECT_EQ(selected.compare("blub"), 0);
 }
 
+TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto serverCtx = std::make_shared<SSLContext>();
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+
+  clientCtx->setAdvertisedNextProtocols({"blub"});
+  clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
+  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+  NpnClient client(std::move(clientSock));
+  NpnServer server(std::move(serverSock));
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.nextProtoLength != 0);
+  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+                           server.nextProtoLength), 0);
+  string selected((const char*)client.nextProto, client.nextProtoLength);
+  EXPECT_EQ(selected.compare("ponies"), 0);
+}
+
+TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto serverCtx = std::make_shared<SSLContext>();
+  int fds[2];
+  getfds(fds);
+  getctx(clientCtx, serverCtx);
+
+  clientCtx->setAdvertisedNextProtocols({"blub"});
+  clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
+  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+
+  AsyncSSLSocket::UniquePtr clientSock(
+    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+  NpnClient client(std::move(clientSock));
+  NpnServer server(std::move(serverSock));
+
+  eventBase.loop();
+
+  EXPECT_TRUE(client.nextProtoLength != 0);
+  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
+  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
+                           server.nextProtoLength), 0);
+  string selected((const char*)client.nextProto, client.nextProtoLength);
+  EXPECT_EQ(selected.compare("blub"), 0);
+}
+
 TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
   // Probability that this test will fail is 2^-64, which could be considered
   // as negligible.