Add pre received data API to AsyncSSLSocket.
authorKyle Nekritz <knekritz@fb.com>
Mon, 9 Jan 2017 19:51:34 +0000 (11:51 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 9 Jan 2017 20:03:13 +0000 (12:03 -0800)
Summary: This allows something else (ie fizz) to read data from a socket, and then later decide to to accept an SSL connection with OpenSSL by inserting the data it read in front of future reads on the socket.

Reviewed By: anirudhvr

Differential Revision: D4325634

fbshipit-source-id: 05076d2d911fda681b9c4e5d9d3375559293ea35

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp

index 70a4640e27e72339e8c1b804f559c54de6e4d1aa..7910bd02aca6adc11e6d38a8949304ec3c726908 100644 (file)
@@ -452,6 +452,10 @@ void AsyncSSLSocket::sslAccept(
 
   /* register for a read operation (waiting for CLIENT HELLO) */
   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
+
+  if (preReceivedData_) {
+    handleRead();
+  }
 }
 
 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
@@ -1610,12 +1614,31 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
   if (!out) {
     return 0;
   }
-  auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
   BIO_clear_retry_flags(b);
-  if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
-    BIO_set_retry_read(b);
+
+  auto appData = OpenSSLUtils::getBioAppData(b);
+  CHECK(appData);
+  auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+
+  if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
+    VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
+            << ", reading pre-received data";
+
+    Cursor cursor(sslSock->preReceivedData_.get());
+    auto len = cursor.pullAtMost(out, outl);
+
+    IOBufQueue queue;
+    queue.append(std::move(sslSock->preReceivedData_));
+    queue.trimStart(len);
+    sslSock->preReceivedData_ = queue.move();
+    return len;
+  } else {
+    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
+      BIO_set_retry_read(b);
+    }
+    return result;
   }
-  return result;
 }
 
 int AsyncSSLSocket::sslVerifyCallback(
@@ -1632,6 +1655,12 @@ int AsyncSSLSocket::sslVerifyCallback(
     preverifyOk;
 }
 
+void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
+  CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
+  CHECK(!preReceivedData_);
+  preReceivedData_ = std::move(data);
+}
+
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
index 0d106214a2587c4cb16a78f5ff4214b4d3a5ec2a..01ff7bd2098ea0abe67994fa90d38b362815c40e 100644 (file)
@@ -278,6 +278,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   virtual size_t getRawBytesReceived() const override;
   void enableClientHelloParsing();
 
+  void setPreReceivedData(std::unique_ptr<IOBuf> data);
+
   /**
    * Accept an SSL connection on the socket.
    *
@@ -818,6 +820,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   std::chrono::steady_clock::time_point handshakeEndTime_;
   std::chrono::milliseconds handshakeConnectTimeout_{0};
   bool sessionResumptionAttempted_{false};
+
+  std::unique_ptr<IOBuf> preReceivedData_;
 };
 
 } // namespace
index 1508d0e9174304180d079efa1cfa033e0ef99433..87c1453ce1cd492fc8a0feb9e652075b115dea1f 100644 (file)
@@ -1961,6 +1961,41 @@ TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
 }
 
+TEST(AsyncSSLSocketTest, TestPreReceivedData) {
+  EventBase clientEventBase;
+  EventBase serverEventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+  std::array<int, 2> fds;
+  getfds(fds.data());
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSockPtr(
+      new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSockPtr(
+      new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
+  auto clientSock = clientSockPtr.get();
+  auto serverSock = serverSockPtr.get();
+  SSLHandshakeClient client(std::move(clientSockPtr), true, true);
+
+  // Steal some data from the server.
+  clientEventBase.loopOnce();
+  std::array<uint8_t, 10> buf;
+  recv(fds[1], buf.data(), buf.size(), 0);
+
+  serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
+  SSLHandshakeServer server(std::move(serverSockPtr), true, true);
+  while (!client.handshakeSuccess_ && !client.handshakeError_) {
+    serverEventBase.loopOnce();
+    clientEventBase.loopOnce();
+  }
+
+  EXPECT_TRUE(client.handshakeSuccess_);
+  EXPECT_TRUE(server.handshakeSuccess_);
+  EXPECT_EQ(
+      serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
+}
+
 #endif
 
 } // namespace