Simplify TFO write path
authorSubodh Iyengar <subodh@fb.com>
Wed, 17 Aug 2016 04:52:08 +0000 (21:52 -0700)
committerFacebook Github Bot 2 <facebook-github-bot-2-bot@fb.com>
Wed, 17 Aug 2016 04:53:42 +0000 (21:53 -0700)
Summary:
We currently call handleInitialReadWrite.
The reason for this is that if the read callback
was set before TFO was done with connecting, then
we need to call handleinitialreadwrite to setup the
read callback similar to how connect invokes handleInitialReadWrite
after it's done.

However handleinitalreadwrite may also call handleWrite
if writeReqHead_ is non null.
Practically this will not happen since TFO will happen on
the first write only where writeReqHead_ will be null.

The current code path though is a little bit complicated.
This simplfies the code so that we dont need to potentially
call handleWrite within a write call.

We schedule the initial readwrite call asynchrously.
The reason for this is that handleReadWrite can actually fail if updating
events fails. This might cause weird state issues once it returns and we
have no mechanism of processing it.

Reviewed By: djwatson

Differential Revision: D3695925

fbshipit-source-id: 72e19a9e1802caa14e872e05a5cd9bf4e34c5e7d

folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSocketTest2.cpp

index ab988041a915dc8f1f47e10d1f1af15f7b52005d..18ef98752e732f578001fe860fd4d6311b5d80c9 100644 (file)
@@ -174,19 +174,19 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
 };
 
 AsyncSocket::AsyncSocket()
-  : eventBase_(nullptr)
-  , writeTimeout_(this, nullptr)
-  , ioHandler_(this, nullptr)
-  , immediateReadHandler_(this) {
+    : eventBase_(nullptr),
+      writeTimeout_(this, nullptr),
+      ioHandler_(this, nullptr),
+      immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket()";
   init();
 }
 
 AsyncSocket::AsyncSocket(EventBase* evb)
-  : eventBase_(evb)
-  , writeTimeout_(this, evb)
-  , ioHandler_(this, evb)
-  , immediateReadHandler_(this) {
+    : eventBase_(evb),
+      writeTimeout_(this, evb),
+      ioHandler_(this, evb),
+      immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
   init();
 }
@@ -207,10 +207,10 @@ AsyncSocket::AsyncSocket(EventBase* evb,
 }
 
 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
-  : eventBase_(evb)
-  , writeTimeout_(this, evb)
-  , ioHandler_(this, evb, fd)
-  , immediateReadHandler_(this) {
+    : eventBase_(evb),
+      writeTimeout_(this, evb),
+      ioHandler_(this, evb, fd),
+      immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
           << fd << ")";
   init();
@@ -1615,7 +1615,6 @@ void AsyncSocket::handleInitialReadWrite() noexcept {
   // one here just to make sure, in case one of our calling code paths ever
   // changes.
   DestructorGuard dg(this);
-
   // If we have a readCallback_, make sure we enable read events.  We
   // may already be registered for reads if connectSuccess() set
   // the read calback.
@@ -1772,7 +1771,9 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
     if (totalWritten >= 0) {
       tfoFinished_ = true;
       state_ = StateEnum::ESTABLISHED;
-      handleInitialReadWrite();
+      // We schedule this asynchrously so that we don't end up
+      // invoking initial read or write while a write is in progress.
+      scheduleInitialReadWrite();
     } else if (errno == EINPROGRESS) {
       VLOG(4) << "TFO falling back to connecting";
       // A normal sendmsg doesn't return EINPROGRESS, however
@@ -1798,7 +1799,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
           // connect succeeded immediately
           // Treat this like no data was written.
           state_ = StateEnum::ESTABLISHED;
-          handleInitialReadWrite();
+          scheduleInitialReadWrite();
         }
         // If there was no exception during connections,
         // we would return that no bytes were written.
index 36949725c3558639de74e124f423430ee951e788..6e0fb77b0e0283238b394f3e99f4f969c606da8a 100644 (file)
@@ -735,6 +735,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     }
   }
 
+  /**
+   * Schedule handleInitalReadWrite to run in the next iteration.
+   */
+  void scheduleInitialReadWrite() noexcept {
+    if (good()) {
+      DestructorGuard dg(this);
+      eventBase_->runInLoop([this, dg] {
+        if (good()) {
+          handleInitialReadWrite();
+        }
+      });
+    }
+  }
+
   // event notification methods
   void ioReady(uint16_t events) noexcept;
   virtual void checkForImmediateRead() noexcept;
index 7ba0a4f89490898d93d61f86933501a767261d0a..35958c32fdcff3b672b034a4a738ff888594f721 100644 (file)
@@ -2410,6 +2410,55 @@ TEST(AsyncSocketTest, ConnectTFO) {
   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
 }
 
+TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
+  // Start listening on a local port
+  TestServer server(true);
+
+  // Connect using a AsyncSocket
+  EventBase evb;
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  socket->enableTFO();
+  ConnCallback cb;
+  socket->connect(&cb, server.getAddress(), 30);
+  ReadCallback rcb;
+  socket->setReadCB(&rcb);
+
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+
+  std::array<uint8_t, 3> readBuf;
+  auto sendBuf = IOBuf::copyBuffer("hey");
+
+  std::thread t([&] {
+    auto acceptedSocket = server.accept();
+    acceptedSocket->write(buf.data(), buf.size());
+    acceptedSocket->flush();
+    acceptedSocket->readAll(readBuf.data(), readBuf.size());
+    acceptedSocket->close();
+  });
+
+  evb.loop();
+
+  CHECK_EQ(cb.state, STATE_SUCCEEDED);
+  EXPECT_LE(0, socket->getConnectTime().count());
+  EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
+  EXPECT_TRUE(socket->getTFOAttempted());
+
+  // Should trigger the connect
+  WriteCallback write;
+  socket->writeChain(&write, sendBuf->clone());
+  evb.loop();
+
+  t.join();
+
+  EXPECT_EQ(STATE_SUCCEEDED, write.state);
+  EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
+  EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
+  ASSERT_EQ(1, rcb.buffers.size());
+  ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
+  EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
+}
+
 /**
  * Test connecting to a server that isn't listening
  */