Replace MSG_PEEK with a pre-received data interface.
authorKyle Nekritz <knekritz@fb.com>
Thu, 9 Mar 2017 16:25:27 +0000 (08:25 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Mar 2017 16:35:27 +0000 (08:35 -0800)
Summary: MSG_PEEK was difficult if not impossible to use well since we do not provide a way wait for more data to arrive. If you are using setPeek on AsyncSocket, and you do not receive the amount of data you want, you must either abandon your peek attempt, or spin around the event base waiting for more data. This diff replaces the peek interface on AsyncSocket with a pre-received data interface, allowing users to insert data back onto the front of connections after reading some data in another layer.

Reviewed By: djwatson

Differential Revision: D4626315

fbshipit-source-id: c552e64f5b3ac9e40ea3358d65b4b9db848f5d74

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSocketTest2.cpp
folly/io/async/test/MockAsyncSSLSocket.h
folly/io/async/test/MockAsyncSocket.h

index a25269d2c55b8536d8ade0e2cc6176d294ae6711..3b99a4c4a05121cbfdddd39336877d57b3ed5c8f 100644 (file)
@@ -218,14 +218,38 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
 /**
  * Create a server/client AsyncSSLSocket
  */
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
-                               EventBase* evb, int fd, bool server,
-                               bool deferSecurityNegotiation) :
-    AsyncSocket(evb, fd),
-    server_(server),
-    ctx_(ctx),
-    handshakeTimeout_(this, evb),
-    connectionTimeout_(this, evb) {
+AsyncSSLSocket::AsyncSSLSocket(
+    const shared_ptr<SSLContext>& ctx,
+    EventBase* evb,
+    int fd,
+    bool server,
+    bool deferSecurityNegotiation)
+    : AsyncSocket(evb, fd),
+      server_(server),
+      ctx_(ctx),
+      handshakeTimeout_(this, evb),
+      connectionTimeout_(this, evb) {
+  noTransparentTls_ = true;
+  init();
+  if (server) {
+    SSL_CTX_set_info_callback(
+        ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
+  }
+  if (deferSecurityNegotiation) {
+    sslState_ = STATE_UNENCRYPTED;
+  }
+}
+
+AsyncSSLSocket::AsyncSSLSocket(
+    const shared_ptr<SSLContext>& ctx,
+    AsyncSocket::UniquePtr oldAsyncSocket,
+    bool server,
+    bool deferSecurityNegotiation)
+    : AsyncSocket(std::move(oldAsyncSocket)),
+      server_(server),
+      ctx_(ctx),
+      handshakeTimeout_(this, oldAsyncSocket->getEventBase()),
+      connectionTimeout_(this, oldAsyncSocket->getEventBase()) {
   noTransparentTls_ = true;
   init();
   if (server) {
@@ -254,11 +278,13 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
  * Create a client AsyncSSLSocket from an already connected fd
  * and allow tlsext_hostname to be sent in Client Hello.
  */
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
-                                 EventBase* evb, int fd,
-                               const std::string& serverName,
-                               bool deferSecurityNegotiation) :
-    AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
+AsyncSSLSocket::AsyncSSLSocket(
+    const shared_ptr<SSLContext>& ctx,
+    EventBase* evb,
+    int fd,
+    const std::string& serverName,
+    bool deferSecurityNegotiation)
+    : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
   tlsextHostname_ = serverName;
 }
 #endif // FOLLY_OPENSSL_HAS_SNI
@@ -451,9 +477,7 @@ void AsyncSSLSocket::sslAccept(
   /* register for a read operation (waiting for CLIENT HELLO) */
   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
 
-  if (preReceivedData_) {
-    handleRead();
-  }
+  checkForImmediateRead();
 }
 
 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
@@ -985,6 +1009,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept {
   // the socket to become readable again.
   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
     AsyncSocket::handleRead();
+  } else {
+    AsyncSocket::checkForImmediateRead();
   }
 }
 
@@ -1684,12 +1710,6 @@ int AsyncSSLSocket::sslVerifyCallback(
     preverifyOk;
 }
 
-void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
-  CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
-  CHECK(!preReceivedData_);
-  preReceivedData_ = std::move(data);
-}
-
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
index a8bb1e123aa2a22089b069ec59cb490c9b1c8820..2121c2ff89a857b7e9db571921afbb3726b1283a 100644 (file)
@@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * @param deferSecurityNegotiation
    *          unencrypted data can be sent before sslConn/Accept
    */
-  AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
-                 EventBase* evb, int fd,
-                 bool server = true, bool deferSecurityNegotiation = false);
+  AsyncSSLSocket(
+      const std::shared_ptr<folly::SSLContext>& ctx,
+      EventBase* evb,
+      int fd,
+      bool server = true,
+      bool deferSecurityNegotiation = false);
 
+  /**
+   * Create a server/client AsyncSSLSocket from an already connected
+   * AsyncSocket.
+   */
+  AsyncSSLSocket(
+      const std::shared_ptr<folly::SSLContext>& ctx,
+      AsyncSocket::UniquePtr oldAsyncSocket,
+      bool server = true,
+      bool deferSecurityNegotiation = false);
 
   /**
    * Helper function to create a server/client shared_ptr<AsyncSSLSocket>.
@@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * @param fd   File descriptor to take over (should be a connected socket).
    * @param serverName tlsext_hostname that will be sent in ClientHello.
    */
-  AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
-                  EventBase* evb,
-                  int fd,
-                 const std::string& serverName,
-                bool deferSecurityNegotiation = false);
+  AsyncSSLSocket(
+      const std::shared_ptr<folly::SSLContext>& ctx,
+      EventBase* evb,
+      int fd,
+      const std::string& serverName,
+      bool deferSecurityNegotiation = false);
 
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
@@ -276,8 +289,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   virtual size_t getRawBytesReceived() const override;
   void enableClientHelloParsing();
 
-  void setPreReceivedData(std::unique_ptr<IOBuf> data);
-
   /**
    * Accept an SSL connection on the socket.
    *
@@ -864,7 +875,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   bool sessionResumptionAttempted_{false};
   std::chrono::milliseconds totalConnectTimeout_{0};
 
-  std::unique_ptr<IOBuf> preReceivedData_;
   std::string sslVerificationAlert_;
 };
 
index 721686c517d20a4ad7b932c4fa672bcfd4285971..827ee8e9de9443a47852b61625ffc6d9fbe1dda2 100644 (file)
 #include <folly/io/async/AsyncSocket.h>
 
 #include <folly/ExceptionWrapper.h>
+#include <folly/Portability.h>
 #include <folly/SocketAddress.h>
+#include <folly/io/Cursor.h>
 #include <folly/io/IOBuf.h>
-#include <folly/Portability.h>
+#include <folly/io/IOBufQueue.h>
 #include <folly/portability/Fcntl.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/SysUio.h>
@@ -229,6 +231,11 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd)
   state_ = StateEnum::ESTABLISHED;
 }
 
+AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
+    : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
+  preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
+}
+
 // init() method, since constructor forwarding isn't supported in most
 // compilers yet.
 void AsyncSocket::init() {
@@ -1406,12 +1413,23 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
   VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
           << ", buflen=" << *buflen;
 
-  int recvFlags = 0;
-  if (peek_) {
-    recvFlags |= MSG_PEEK;
+  if (preReceivedData_ && !preReceivedData_->empty()) {
+    VLOG(5) << "AsyncSocket::performRead() this=" << this
+            << ", reading pre-received data";
+
+    io::Cursor cursor(preReceivedData_.get());
+    auto len = cursor.pullAtMost(*buf, *buflen);
+
+    IOBufQueue queue;
+    queue.append(std::move(preReceivedData_));
+    queue.trimStart(len);
+    preReceivedData_ = queue.move();
+
+    appBytesReceived_ += len;
+    return ReadResult(len);
   }
 
-  ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
+  ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT);
   if (bytes < 0) {
     if (errno == EAGAIN || errno == EWOULDBLOCK) {
       // No more data to read right now.
@@ -1762,6 +1780,12 @@ void AsyncSocket::checkForImmediateRead() noexcept {
   // be a pessimism.  In most cases it probably wouldn't be readable, and we
   // would just waste an extra system call.  Even if it is readable, waiting to
   // find out from libevent on the next event loop doesn't seem that bad.
+  //
+  // The exception to this is if we have pre-received data. In that case there
+  // is definitely data available immediately.
+  if (preReceivedData_ && !preReceivedData_->empty()) {
+    handleRead();
+  }
 }
 
 void AsyncSocket::handleInitialReadWrite() noexcept {
index 54b37705ae9b5879107485756fea1ed7df5f02e7..3e2adbce0f71a5e839b29f7016bbe26ab22ac451 100644 (file)
@@ -189,6 +189,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    */
   AsyncSocket(EventBase* evb, int fd);
 
+  /**
+   * Create an AsyncSocket from a different, already connected AsyncSocket.
+   *
+   * Similar to AsyncSocket(evb, fd) when fd was previously owned by an
+   * AsyncSocket.
+   */
+  explicit AsyncSocket(AsyncSocket::UniquePtr);
+
   /**
    * Helper function to create a shared_ptr<AsyncSocket>.
    *
@@ -264,6 +272,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * error.  The AsyncSocket may no longer be used after the file descriptor
    * has been extracted.
    *
+   * This method should be used with care as the resulting fd is not guaranteed
+   * to perfectly reflect the state of the AsyncSocket (security state,
+   * pre-received data, etc.).
+   *
    * Returns the file descriptor.  The caller assumes ownership of the
    * descriptor, and it will not be closed when the AsyncSocket is destroyed.
    */
@@ -601,8 +613,16 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return setsockopt(fd_, level, optname, optval, sizeof(T));
   }
 
-  virtual void setPeek(bool peek) {
-    peek_ = peek;
+  /**
+   * Set pre-received data, to be returned to read callback before any data
+   * from the socket.
+   */
+  virtual void setPreReceivedData(std::unique_ptr<IOBuf> data) {
+    if (preReceivedData_) {
+      preReceivedData_->prependChain(std::move(data));
+    } else {
+      preReceivedData_ = std::move(data);
+    }
   }
 
   /**
@@ -998,7 +1018,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   size_t appBytesWritten_;               ///< Num of bytes written to socket
   bool isBufferMovable_{false};
 
-  bool peek_{false}; // Peek bytes.
+  // Pre-received data, to be returned to read callback before any data from the
+  // socket.
+  std::unique_ptr<IOBuf> preReceivedData_;
 
   int8_t readErr_{READ_NO_ERROR};       ///< The read error encountered, if any.
 
index f532315153aaf2e6118aaf940204fc4544302c4e..864a40319edffc624ddf4cd04c391347573035c5 100644 (file)
@@ -2909,3 +2909,133 @@ TEST(AsyncSocketTest, ErrMessageCallback) {
   ASSERT_TRUE(errMsgCB.gotTimestamp_);
 }
 #endif // MSG_ERRQUEUE
+
+TEST(AsyncSocket, PreReceivedData) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback peekCallback(2);
+  ReadCallback readCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("he", 2);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
+    acceptedSocket->setReadCB(nullptr);
+    acceptedSocket->setReadCB(&readCallback);
+  };
+  readCallback.dataAvailableCallback = [&]() {
+    if (readCallback.dataRead() == 5) {
+      readCallback.verifyData("hello", 5);
+      acceptedSocket->setReadCB(nullptr);
+    }
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataOnly) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback peekCallback;
+  ReadCallback readCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("hello", 5);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+    acceptedSocket->setReadCB(&readCallback);
+  };
+  readCallback.dataAvailableCallback = [&]() {
+    readCallback.verifyData("hello", 5);
+    acceptedSocket->setReadCB(nullptr);
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataPartial) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback peekCallback;
+  ReadCallback smallReadCallback(3);
+  ReadCallback normalReadCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("hello", 5);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+    acceptedSocket->setReadCB(&smallReadCallback);
+  };
+  smallReadCallback.dataAvailableCallback = [&]() {
+    smallReadCallback.verifyData("hel", 3);
+    acceptedSocket->setReadCB(&normalReadCallback);
+  };
+  normalReadCallback.dataAvailableCallback = [&]() {
+    normalReadCallback.verifyData("lo", 2);
+    acceptedSocket->setReadCB(nullptr);
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
+
+TEST(AsyncSocket, PreReceivedDataTakeover) {
+  TestServer server;
+
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->connect(nullptr, server.getAddress(), 30);
+  evb.loop();
+
+  socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
+
+  auto acceptedSocket =
+      AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
+  AsyncSocket::UniquePtr takeoverSocket;
+
+  ReadCallback peekCallback(3);
+  ReadCallback readCallback;
+  peekCallback.dataAvailableCallback = [&]() {
+    peekCallback.verifyData("hel", 3);
+    acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
+    acceptedSocket->setReadCB(nullptr);
+    takeoverSocket =
+        AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
+    takeoverSocket->setReadCB(&readCallback);
+  };
+  readCallback.dataAvailableCallback = [&]() {
+    readCallback.verifyData("hello", 5);
+    takeoverSocket->setReadCB(nullptr);
+  };
+
+  acceptedSocket->setReadCB(&peekCallback);
+
+  evb.loop();
+}
index 7ab4fda3e5d835a0c5dd869dad1bd69a1df4ba53..a627ba4645788181b22e0579090da37c18206393 100644 (file)
@@ -50,7 +50,6 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
                      bool(const unsigned char**,
                           unsigned*,
                           SSLContext::NextProtocolType*));
-  MOCK_METHOD1(setPeek, void(bool));
   MOCK_METHOD1(setReadCB, void(ReadCallback*));
 
   void sslConn(
index cf55874ee52a53e3fb243d3a0b38fb052a0e90b8..d6cb3da2fd5a635c66a95aa54d2fbd11c14acca2 100644 (file)
@@ -45,8 +45,11 @@ class MockAsyncSocket : public AsyncSocket {
   MOCK_CONST_METHOD0(good, bool());
   MOCK_CONST_METHOD0(readable, bool());
   MOCK_CONST_METHOD0(hangup, bool());
-  MOCK_METHOD1(setPeek, void(bool));
   MOCK_METHOD1(setReadCB, void(ReadCallback*));
+  MOCK_METHOD1(_setPreReceivedData, void(std::unique_ptr<IOBuf>&));
+  void setPreReceivedData(std::unique_ptr<IOBuf> data) override {
+    return _setPreReceivedData(data);
+  }
 };
 
 }}