Add support for ALPN
authorAlan Frindell <afrind@fb.com>
Wed, 9 Dec 2015 19:55:40 +0000 (11:55 -0800)
committerfacebook-github-bot-4 <folly-bot@fb.com>
Wed, 9 Dec 2015 20:20:24 +0000 (12:20 -0800)
Summary: With openssl-1.0.2 and later add support for ALPN.  Clients can request NPN only, but the default is to support either (client will send ALPN list, server will send NPN advertisement if ALPN is not negotiated).

Reviewed By: siyengar

Differential Revision: D2710441

fb-gh-sync-id: a8efe69e1869bbecb4ed9e0a513448fcfdb21ca6

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/SSLContext.cpp
folly/io/async/SSLContext.h
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSSLSocketTest.h
folly/io/async/test/MockAsyncSSLSocket.h

index 166a403950845bbfde57f9b20c2ac208ce6ef969..e72894cd9ac9bae46f0ad59e65713da6f8085e22 100644 (file)
@@ -55,6 +55,7 @@ using folly::AsyncSocket;
 using folly::AsyncSocketException;
 using folly::AsyncSSLSocket;
 using folly::Optional;
+using folly::SSLContext;
 
 // We have one single dummy SSL context so that we can implement attach
 // and detach methods in a thread safe fashion without modifying opnessl.
@@ -765,21 +766,36 @@ void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
   }
 }
 
-void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
-    unsigned* protoLen) const {
-  if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
+void AsyncSSLSocket::getSelectedNextProtocol(
+    const unsigned char** protoName,
+    unsigned* protoLen,
+    SSLContext::NextProtocolType* protoType) const {
+  if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
                               "NPN not supported");
   }
 }
 
 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
-  const unsigned char** protoName,
-  unsigned* protoLen) const {
+    const unsigned char** protoName,
+    unsigned* protoLen,
+    SSLContext::NextProtocolType* protoType) const {
   *protoName = nullptr;
   *protoLen = 0;
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_get0_alpn_selected(ssl_, protoName, protoLen);
+  if (*protoLen > 0) {
+    if (protoType) {
+      *protoType = SSLContext::NextProtocolType::ALPN;
+    }
+    return true;
+  }
+#endif
 #ifdef OPENSSL_NPN_NEGOTIATED
   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
+  if (protoType) {
+    *protoType = SSLContext::NextProtocolType::NPN;
+  }
   return true;
 #else
   return false;
index 5d8a0e6b0e49bc39402f51874e339fd438249683..ba7485ae00cff2b18255b4c7cdf5f6abf82d6b8e 100644 (file)
@@ -376,7 +376,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * Throw an exception if openssl does not support NPN
    *
@@ -386,13 +387,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    */
-  virtual void getSelectedNextProtocol(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual void getSelectedNextProtocol(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * @param protoName      Name of the protocol (not guaranteed to be
    *                       null terminated); will be set to nullptr if
@@ -400,10 +405,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    * @return false if openssl does not support NPN
    */
-  virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual bool getSelectedNextProtocolNoThrow(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Determine if the session specified during setSSLSession was reused
index 7426e237bdb63cbb0765d1b5495dbbb2df3dad50..def95ee41cda6520b1d249cb7d72c8235154f002 100644 (file)
@@ -305,13 +305,43 @@ void SSLContext::switchCiphersIfTLS11(
 }
 #endif
 
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+int SSLContext::alpnSelectCallback(SSL* ssl,
+                                   const unsigned char** out,
+                                   unsigned char* outlen,
+                                   const unsigned char* in,
+                                   unsigned int inlen,
+                                   void* data) {
+  SSLContext* context = (SSLContext*)data;
+  CHECK(context);
+  if (context->advertisedNextProtocols_.empty()) {
+    *out = nullptr;
+    *outlen = 0;
+  } else {
+    auto i = context->pickNextProtocols();
+    const auto& item = context->advertisedNextProtocols_[i];
+    if (SSL_select_next_proto((unsigned char**)out,
+                              outlen,
+                              item.protocols,
+                              item.length,
+                              in,
+                              inlen) != OPENSSL_NPN_NEGOTIATED) {
+      return SSL_TLSEXT_ERR_NOACK;
+    }
+  }
+  return SSL_TLSEXT_ERR_OK;
+}
+#endif
+
 #ifdef OPENSSL_NPN_NEGOTIATED
-bool SSLContext::setAdvertisedNextProtocols(const std::list<std::string>& protocols) {
-  return setRandomizedAdvertisedNextProtocols({{1, protocols}});
+
+bool SSLContext::setAdvertisedNextProtocols(
+    const std::list<std::string>& protocols, NextProtocolType protocolType) {
+  return setRandomizedAdvertisedNextProtocols({{1, protocols}}, protocolType);
 }
 
 bool SSLContext::setRandomizedAdvertisedNextProtocols(
-    const std::list<NextProtocolsItem>& items) {
+    const std::list<NextProtocolsItem>& items, NextProtocolType protocolType) {
   unsetNextProtocols();
   if (items.size() == 0) {
     return false;
@@ -354,10 +384,20 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
   for (auto &advertised_item : advertisedNextProtocols_) {
     advertised_item.probability /= total_weight;
   }
-  SSL_CTX_set_next_protos_advertised_cb(
-    ctx_, advertisedNextProtocolCallback, this);
-  SSL_CTX_set_next_proto_select_cb(
-    ctx_, selectNextProtocolCallback, this);
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
+    SSL_CTX_set_next_protos_advertised_cb(
+        ctx_, advertisedNextProtocolCallback, this);
+    SSL_CTX_set_next_proto_select_cb(ctx_, selectNextProtocolCallback, this);
+  }
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  if ((uint8_t)protocolType & (uint8_t)NextProtocolType::ALPN) {
+    SSL_CTX_set_alpn_select_cb(ctx_, alpnSelectCallback, this);
+    // Client cannot really use randomized alpn
+    SSL_CTX_set_alpn_protos(ctx_,
+                            advertisedNextProtocols_[0].protocols,
+                            advertisedNextProtocols_[0].length);
+  }
+#endif
   return true;
 }
 
@@ -372,6 +412,25 @@ void SSLContext::unsetNextProtocols() {
   deleteNextProtocolsStrings();
   SSL_CTX_set_next_protos_advertised_cb(ctx_, nullptr, nullptr);
   SSL_CTX_set_next_proto_select_cb(ctx_, nullptr, nullptr);
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  SSL_CTX_set_alpn_select_cb(ctx_, nullptr, nullptr);
+  SSL_CTX_set_alpn_protos(ctx_, nullptr, 0);
+#endif
+}
+
+size_t SSLContext::pickNextProtocols() {
+  unsigned char random_byte;
+  RAND_bytes(&random_byte, 1);
+  double random_value = random_byte / 255.0;
+  double sum = 0;
+  for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
+    sum += advertisedNextProtocols_[i].probability;
+    if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
+      continue;
+    }
+    return i;
+  }
+  CHECK(false) << "Failed to pickNextProtocols";
 }
 
 int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
@@ -391,22 +450,11 @@ int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
       *out = context->advertisedNextProtocols_[selected_index].protocols;
       *outlen = context->advertisedNextProtocols_[selected_index].length;
     } else {
-      unsigned char random_byte;
-      RAND_bytes(&random_byte, 1);
-      double random_value = random_byte / 255.0;
-      double sum = 0;
-      for (size_t i = 0; i < context->advertisedNextProtocols_.size(); ++i) {
-        sum += context->advertisedNextProtocols_[i].probability;
-        if (sum < random_value &&
-            i + 1 < context->advertisedNextProtocols_.size()) {
-          continue;
-        }
-        uintptr_t selected = i + 1;
-        SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void *)selected);
-        *out = context->advertisedNextProtocols_[i].protocols;
-        *outlen = context->advertisedNextProtocols_[i].length;
-        break;
-      }
+      auto i = context->pickNextProtocols();
+      uintptr_t selected = i + 1;
+      SSL_set_ex_data(ssl, sNextProtocolsExDataIndex_, (void*)selected);
+      *out = context->advertisedNextProtocols_[i].protocols;
+      *outlen = context->advertisedNextProtocols_[i].length;
     }
   }
   return SSL_TLSEXT_ERR_OK;
index 90c96ca681e63e0f9fd2518728373b766ea47bbe..a4b44b2dbea619c7be12a3ac7484c2444269ce56 100644 (file)
@@ -290,33 +290,42 @@ class SSLContext {
    */
   void setOptions(long options);
 
+  enum class NextProtocolType : uint8_t {
+    NPN = 0x1,
+    ALPN = 0x2,
+    ANY = NPN | ALPN
+  };
+
 #ifdef OPENSSL_NPN_NEGOTIATED
   /**
    * Set the list of protocols that this SSL context supports. In server
    * mode, this is the list of protocols that will be advertised for Next
-   * Protocol Negotiation (NPN). In client mode, the first protocol
-   * advertised by the server that is also on this list is
-   * chosen. Invoking this function with a list of length zero causes NPN
-   * to be disabled.
+   * Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN). In client mode, the first protocol advertised by the server
+   * that is also on this list is chosen. Invoking this function with a list
+   * of length zero causes NPN to be disabled.
    *
    * @param protocols   List of protocol names. This method makes a copy,
    *                    so the caller needn't keep the list in scope after
    *                    the call completes. The list must have at least
    *                    one element to enable NPN. Each element must have
    *                    a string length < 256.
-   * @return true if NPN has been activated. False if NPN is disabled.
+   * @param protocolType  What type of protocol negotiation to support.
+   * @return true if NPN/ALPN has been activated. False if NPN/ALPN is disabled.
    */
-  bool setAdvertisedNextProtocols(const std::list<std::string>& protocols);
+  bool setAdvertisedNextProtocols(
+      const std::list<std::string>& protocols,
+      NextProtocolType protocolType = NextProtocolType::ANY);
   /**
    * Set weighted list of lists of protocols that this SSL context supports.
    * In server mode, each element of the list contains a list of protocols that
-   * could be advertised for Next Protocol Negotiation (NPN). The list of
-   * protocols that will be advertised to a client is selected randomly, based
-   * on weights of elements. Client mode doesn't support randomized NPN, so
-   * this list should contain only 1 element. The first protocol advertised
-   * by the server that is also on the list of protocols of this element is
-   * chosen. Invoking this function with a list of length zero causes NPN
-   * to be disabled.
+   * could be advertised for Next Protocol Negotiation (NPN) or Application
+   * Layer Protocol Negotiation (ALPN). The list of protocols that will be
+   * advertised to a client is selected randomly, based on weights of elements.
+   * Client mode doesn't support randomized NPN/ALPN, so this list should
+   * contain only 1 element. The first protocol advertised by the server that
+   * is also on the list of protocols of this element is chosen. Invoking this
+   * function with a list of length zero causes NPN/ALPN to be disabled.
    *
    * @param items  List of NextProtocolsItems, Each item contains a list of
    *               protocol names and weight. After the call of this fucntion
@@ -326,10 +335,12 @@ class SSLContext {
    *               completes. The list must have at least one element with
    *               non-zero weight and non-empty protocols list to enable NPN.
    *               Each name of the protocol must have a string length < 256.
-   * @return true if NPN has been activated. False if NPN is disabled.
+   * @param protocolType  What type of protocol negotiation to support.
+   * @return true if NPN/ALPN has been activated. False if NPN/ALPN is disabled.
    */
   bool setRandomizedAdvertisedNextProtocols(
-      const std::list<NextProtocolsItem>& items);
+      const std::list<NextProtocolsItem>& items,
+      NextProtocolType protocolType = NextProtocolType::ANY);
 
   void setClientProtocolFilterCallback(ClientProtocolFilterCallback cb) {
     clientProtoFilter_ = cb;
@@ -459,6 +470,16 @@ class SSLContext {
     SSL* ssl, unsigned char **out, unsigned char *outlen,
     const unsigned char *server, unsigned int server_len, void *args);
 
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+  static int alpnSelectCallback(SSL* ssl,
+                                const unsigned char** out,
+                                unsigned char* outlen,
+                                const unsigned char* in,
+                                unsigned int inlen,
+                                void* data);
+#endif
+  size_t pickNextProtocols();
+
 #if defined(SSL_MODE_HANDSHAKE_CUTTHROUGH) && \
   FOLLY_SSLCONTEXT_USE_TLS_FALSE_START
   // This class contains all allowed ciphers for SSL false start. Call its
index a3e148868d1242aca0f67f523ffca8a22b779d22..350a6b57e40eb9a7039e0518202a8a2d8abf3b5c 100644 (file)
@@ -315,191 +315,214 @@ TEST(AsyncSSLSocketTest, SocketWithDelay) {
   cerr << "SocketWithDelay test completed" << endl;
 }
 
-TEST(AsyncSSLSocketTest, NpnTestOverlap) {
-  EventBase eventBase;
-  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
-  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
-  int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+using NextProtocolTypePair =
+    std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
 
-  clientCtx->setAdvertisedNextProtocols({"blub","baz"});
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
+  // For matching protos
+ public:
+  void SetUp() override { getctx(clientCtx, serverCtx); }
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+  void connect(bool unset = false) {
+    getfds(fds);
 
-  eventBase.loop();
+    if (unset) {
+      // unsetting NPN for any of [client, server] is enough to make NPN not
+      // work
+      clientCtx->unsetNextProtocols();
+    }
 
-  EXPECT_TRUE(client.nextProtoLength != 0);
-  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                           server.nextProtoLength), 0);
-  string selected((const char*)client.nextProto, client.nextProtoLength);
-  EXPECT_EQ(selected.compare("baz"), 0);
-}
+    AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+    AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+    client = folly::make_unique<NpnClient>(std::move(clientSock));
+    server = folly::make_unique<NpnServer>(std::move(serverSock));
+
+    eventBase.loop();
+  }
+
+  void expectProtocol(const std::string& proto) {
+    EXPECT_NE(client->nextProtoLength, 0);
+    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
+    EXPECT_EQ(
+        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
+        0);
+    string selected((const char*)client->nextProto, client->nextProtoLength);
+    EXPECT_EQ(proto, selected);
+  }
+
+  void expectNoProtocol() {
+    EXPECT_EQ(client->nextProtoLength, 0);
+    EXPECT_EQ(server->nextProtoLength, 0);
+    EXPECT_EQ(client->nextProto, nullptr);
+    EXPECT_EQ(server->nextProto, nullptr);
+  }
+
+  void expectProtocolType() {
+    if (GetParam().first == SSLContext::NextProtocolType::ANY &&
+        GetParam().second == SSLContext::NextProtocolType::ANY) {
+      EXPECT_EQ(client->protocolType, server->protocolType);
+    } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
+               GetParam().second == SSLContext::NextProtocolType::ANY) {
+      // Well not much we can say
+    } else {
+      expectProtocolType(GetParam());
+    }
+  }
+
+  void expectProtocolType(NextProtocolTypePair expected) {
+    EXPECT_EQ(client->protocolType, expected.first);
+    EXPECT_EQ(server->protocolType, expected.second);
+  }
 
-TEST(AsyncSSLSocketTest, NpnTestUnset) {
-  // Identical to above test, except that we want unset NPN before
-  // looping.
   EventBase eventBase;
-  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
-  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
+  std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
+  std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
   int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
-
-  clientCtx->setAdvertisedNextProtocols({"blub","baz"});
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+  std::unique_ptr<NpnClient> client;
+  std::unique_ptr<NpnServer> server;
+};
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
+class NextProtocolNPNOnlyTest : public NextProtocolTest {
+  // For mismatching protos
+};
 
-  // unsetting NPN for any of [client, server] is enought to make NPN not
-  // work
-  clientCtx->unsetNextProtocols();
+class NextProtocolMismatchTest : public NextProtocolTest {
+  // For mismatching protos
+};
 
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+TEST_P(NextProtocolTest, NpnTestOverlap) {
+  clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  eventBase.loop();
+  connect();
 
-  EXPECT_TRUE(client.nextProtoLength == 0);
-  EXPECT_TRUE(server.nextProtoLength == 0);
-  EXPECT_TRUE(client.nextProto == nullptr);
-  EXPECT_TRUE(server.nextProto == nullptr);
+  expectProtocol("baz");
+  expectProtocolType();
 }
 
-TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
-  EventBase eventBase;
-  std::shared_ptr<SSLContext> clientCtx(new SSLContext);
-  std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
-  int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+TEST_P(NextProtocolTest, NpnTestUnset) {
+  // Identical to above test, except that we want unset NPN before
+  // looping.
+  clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  clientCtx->setAdvertisedNextProtocols({"blub"});
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+  connect(true /* unset */);
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+  // if alpn negotiation fails, type will appear as npn
+  expectNoProtocol();
+  EXPECT_EQ(client->protocolType, server->protocolType);
+}
 
-  eventBase.loop();
+TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
+  clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  EXPECT_TRUE(client.nextProtoLength != 0);
-  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                           server.nextProtoLength), 0);
-  string selected((const char*)client.nextProto, client.nextProtoLength);
-  EXPECT_EQ(selected.compare("blub"), 0);
+  connect();
+
+  expectNoProtocol();
+  expectProtocolType(
+      {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
 }
 
-TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterHit) {
-  EventBase eventBase;
-  auto clientCtx = std::make_shared<SSLContext>();
-  auto serverCtx = std::make_shared<SSLContext>();
-  int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
+  clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  clientCtx->setAdvertisedNextProtocols({"blub"});
-  clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
+  connect();
 
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+  expectProtocol("blub");
+  expectProtocolType();
+}
 
-  eventBase.loop();
+TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
+  clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
+  clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  EXPECT_TRUE(client.nextProtoLength != 0);
-  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                           server.nextProtoLength), 0);
-  string selected((const char*)client.nextProto, client.nextProtoLength);
-  EXPECT_EQ(selected.compare("ponies"), 0);
-}
+  connect();
 
-TEST(AsyncSSLSocketTest, NpnTestClientProtoFilterMiss) {
-  EventBase eventBase;
-  auto clientCtx = std::make_shared<SSLContext>();
-  auto serverCtx = std::make_shared<SSLContext>();
-  int fds[2];
-  getfds(fds);
-  getctx(clientCtx, serverCtx);
+  expectProtocol("ponies");
+  expectProtocolType();
+}
 
-  clientCtx->setAdvertisedNextProtocols({"blub"});
+TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
+  clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
-  serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
-
-  AsyncSSLSocket::UniquePtr clientSock(
-    new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-  AsyncSSLSocket::UniquePtr serverSock(
-    new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-  NpnClient client(std::move(clientSock));
-  NpnServer server(std::move(serverSock));
+  serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().second);
 
-  eventBase.loop();
+  connect();
 
-  EXPECT_TRUE(client.nextProtoLength != 0);
-  EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-  EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                           server.nextProtoLength), 0);
-  string selected((const char*)client.nextProto, client.nextProtoLength);
-  EXPECT_EQ(selected.compare("blub"), 0);
+  expectProtocol("blub");
+  expectProtocolType();
 }
 
-TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
+TEST_P(NextProtocolTest, RandomizedNpnTest) {
   // Probability that this test will fail is 2^-64, which could be considered
   // as negligible.
   const int kTries = 64;
 
+  clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
+                                        GetParam().first);
+  serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
+                                                  GetParam().second);
+
   std::set<string> selectedProtocols;
   for (int i = 0; i < kTries; ++i) {
-    EventBase eventBase;
-    std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
-    std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
-    int fds[2];
-    getfds(fds);
-    getctx(clientCtx, serverCtx);
-
-    clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
-    serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
-        {1, {"bar"}}});
-
-
-    AsyncSSLSocket::UniquePtr clientSock(
-      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
-    AsyncSSLSocket::UniquePtr serverSock(
-      new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
-    NpnClient client(std::move(clientSock));
-    NpnServer server(std::move(serverSock));
-
-    eventBase.loop();
-
-    EXPECT_TRUE(client.nextProtoLength != 0);
-    EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
-    EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
-                             server.nextProtoLength), 0);
-    string selected((const char*)client.nextProto, client.nextProtoLength);
+    connect();
+
+    EXPECT_NE(client->nextProtoLength, 0);
+    EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
+    EXPECT_EQ(
+        memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
+        0);
+    string selected((const char*)client->nextProto, client->nextProtoLength);
     selectedProtocols.insert(selected);
+    expectProtocolType();
   }
   EXPECT_EQ(selectedProtocols.size(), 2);
 }
 
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolTest,
+    ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+                                           SSLContext::NextProtocolType::NPN),
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+                      NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
+                                           SSLContext::NextProtocolType::ALPN),
+#endif
+                      NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+                                           SSLContext::NextProtocolType::ANY),
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+                      NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
+                                           SSLContext::NextProtocolType::ANY),
+#endif
+                      NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
+                                           SSLContext::NextProtocolType::ANY)));
+
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolNPNOnlyTest,
+    ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+                                           SSLContext::NextProtocolType::NPN)));
+
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+INSTANTIATE_TEST_CASE_P(
+    AsyncSSLSocketTest,
+    NextProtocolMismatchTest,
+    ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
+                                           SSLContext::NextProtocolType::ALPN),
+                      NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
+                                           SSLContext::NextProtocolType::NPN)));
+#endif
 
 #ifndef OPENSSL_NO_TLSEXT
 /**
@@ -657,8 +680,7 @@ TEST(AsyncSSLSocketTest, SSLClientTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             1));
+  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -684,8 +706,8 @@ TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             10));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -710,8 +732,8 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             1, 10));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
   client->connect(true /* write before connect completes */);
   EventBaseAborter eba(&eventBase, 3000);
   eventBase.loop();
@@ -741,8 +763,8 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             10, 500));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -798,8 +820,7 @@ TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             2));
+  auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
@@ -828,8 +849,8 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
 
   // Set up SSL client
   EventBase eventBase;
-  std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
-                                             2, 100));
+  auto client =
+      std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
 
   client->connect();
   EventBaseAborter eba(&eventBase, 3000);
index 63349bc909cd5874e993eaaaeed6ac6097ee74c1..683b2a6507acfa25a1a9c73cb8b5ad3cf013e490 100644 (file)
@@ -804,10 +804,12 @@ class NpnClient :
 
   const unsigned char* nextProto;
   unsigned nextProtoLength;
+  SSLContext::NextProtocolType protocolType;
+
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
-    socket_->getSelectedNextProtocol(&nextProto,
-                                     &nextProtoLength);
+    socket_->getSelectedNextProtocol(
+        &nextProto, &nextProtoLength, &protocolType);
   }
   void handshakeErr(
     AsyncSSLSocket*,
@@ -838,10 +840,12 @@ class NpnServer :
 
   const unsigned char* nextProto;
   unsigned nextProtoLength;
+  SSLContext::NextProtocolType protocolType;
+
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
-    socket_->getSelectedNextProtocol(&nextProto,
-                                     &nextProtoLength);
+    socket_->getSelectedNextProtocol(
+        &nextProto, &nextProtoLength, &protocolType);
   }
   void handshakeErr(
     AsyncSSLSocket*,
index ff9456c65be1fa50ec9ebe7d9eaa3f0531e2e36f..6c930d79514883766c62734371d365f4a965ced5 100644 (file)
@@ -42,12 +42,14 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
   MOCK_CONST_METHOD0(good, bool());
   MOCK_CONST_METHOD0(readable, bool());
   MOCK_CONST_METHOD0(hangup, bool());
-  MOCK_CONST_METHOD2(
-   getSelectedNextProtocol,
-   void(const unsigned char**, unsigned*));
-  MOCK_CONST_METHOD2(
-   getSelectedNextProtocolNoThrow,
-   bool(const unsigned char**, unsigned*));
+  MOCK_CONST_METHOD3(getSelectedNextProtocol,
+                     void(const unsigned char**,
+                          unsigned*,
+                          SSLContext::NextProtocolType*));
+  MOCK_CONST_METHOD3(getSelectedNextProtocolNoThrow,
+                     bool(const unsigned char**,
+                          unsigned*,
+                          SSLContext::NextProtocolType*));
   MOCK_METHOD1(setPeek, void(bool));
   MOCK_METHOD1(setReadCB, void(ReadCallback*));