Check readCallback before calling handleRead
authorSarang Masti <mssarang@fb.com>
Tue, 9 Jun 2015 18:17:12 +0000 (11:17 -0700)
committerSara Golemon <sgolemon@fb.com>
Mon, 15 Jun 2015 20:34:36 +0000 (13:34 -0700)
Summary: Since readCallback_ could be uninstalled in any of callbacks,
we need to ensure that readCallback_ != nullptr before calling
handleRead.

Reviewed By: @djwatson

Differential Revision: D2140054

folly/io/async/AsyncSocket.cpp
folly/io/async/test/AsyncSocketTest.h
folly/io/async/test/AsyncSocketTest2.cpp

index 3ff2b345eba1855ad30057100da9ded2f6be8e81..205261621e0bdc6569f58b94fb765a263a97c545 100644 (file)
@@ -527,6 +527,12 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
     return;
   }
 
+  /* We are removing a read callback */
+  if (callback == nullptr &&
+      immediateReadHandler_.isLoopCallbackScheduled()) {
+    immediateReadHandler_.cancelLoopCallback();
+  }
+
   if (shutdownFlags_ & SHUT_READ) {
     // Reads have already been shut down on this socket.
     //
@@ -1330,9 +1336,11 @@ void AsyncSocket::handleRead() noexcept {
       return;
     }
     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
-      // We might still have data in the socket.
-      // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
-      scheduleImmediateRead();
+      if (readCallback_ != nullptr) {
+        // We might still have data in the socket.
+        // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
+        scheduleImmediateRead();
+      }
       return;
     }
   }
index 2c25d0e5ca8e5d6ee1e45f13cebaee29c81c745e..51230014203c038cda8fabb1872c61a1b6623ff5 100644 (file)
@@ -93,10 +93,11 @@ class WriteCallback : public AsyncTransportWrapper::WriteCallback {
 
 class ReadCallback : public AsyncTransportWrapper::ReadCallback {
  public:
-  ReadCallback()
+  explicit ReadCallback(size_t _maxBufferSz = 4096)
     : state(STATE_WAITING)
     , exception(AsyncSocketException::UNKNOWN, "none")
-    , buffers() {}
+    , buffers()
+    , maxBufferSz(_maxBufferSz) {}
 
   ~ReadCallback() {
     for (std::vector<Buffer>::iterator it = buffers.begin();
@@ -109,7 +110,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
 
   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
     if (!currentBuffer.buffer) {
-      currentBuffer.allocate(4096);
+      currentBuffer.allocate(maxBufferSz);
     }
     *bufReturn = currentBuffer.buffer;
     *lenReturn = currentBuffer.length;
@@ -145,6 +146,14 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
     CHECK_EQ(offset, expectedLen);
   }
 
+  size_t dataRead() const {
+    size_t ret = 0;
+    for (const auto& buf : buffers) {
+      ret += buf.length;
+    }
+    return ret;
+  }
+
   class Buffer {
    public:
     Buffer() : buffer(nullptr), length(0) {}
@@ -173,6 +182,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
   std::vector<Buffer> buffers;
   Buffer currentBuffer;
   VoidCallback dataAvailableCallback;
+  const size_t maxBufferSz;
 };
 
 class ReadVerifier {
index bcbe5ec880a46bfe617e528ba03ea1245d32db33..f44d4fd52779806a9f47de97541f32e066529f8a 100644 (file)
@@ -1253,6 +1253,115 @@ TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
   }
 }
 
+///////////////////////////////////////////////////////////////////////////
+// ImmediateRead related tests
+///////////////////////////////////////////////////////////////////////////
+
+/* AsyncSocket use to verify immediate read works */
+class AsyncSocketImmediateRead : public folly::AsyncSocket {
+ public:
+  bool immediateReadCalled = false;
+  explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
+ protected:
+  virtual void checkForImmediateRead() noexcept override {
+    immediateReadCalled = true;
+    AsyncSocket::handleRead();
+  }
+};
+
+TEST(AsyncSocket, ConnectReadImmediateRead) {
+  TestServer server;
+
+  const size_t maxBufferSz = 100;
+  const size_t maxReadsPerEvent = 1;
+  const size_t expectedDataSz = maxBufferSz * 3;
+  char expectedData[expectedDataSz];
+  memset(expectedData, 'j', expectedDataSz);
+
+  EventBase evb;
+  ReadCallback rcb(maxBufferSz);
+  AsyncSocketImmediateRead socket(&evb);
+  socket.connect(nullptr, server.getAddress(), 30);
+
+  evb.loop(); // loop until the socket is connected
+
+  socket.setReadCB(&rcb);
+  socket.setMaxReadsPerEvent(maxReadsPerEvent);
+  socket.immediateReadCalled = false;
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback rcbServer;
+  WriteCallback wcbServer;
+  rcbServer.dataAvailableCallback = [&]() {
+    if (rcbServer.dataRead() == expectedDataSz) {
+      // write back all data read
+      rcbServer.verifyData(expectedData, expectedDataSz);
+      acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
+      acceptedSocket->close();
+    }
+  };
+  acceptedSocket->setReadCB(&rcbServer);
+
+  // write data
+  WriteCallback wcb1;
+  socket.write(&wcb1, expectedData, expectedDataSz);
+  evb.loop();
+  CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
+  rcb.verifyData(expectedData, expectedDataSz);
+  CHECK_EQ(socket.immediateReadCalled, true);
+}
+
+TEST(AsyncSocket, ConnectReadUninstallRead) {
+  TestServer server;
+
+  const size_t maxBufferSz = 100;
+  const size_t maxReadsPerEvent = 1;
+  const size_t expectedDataSz = maxBufferSz * 3;
+  char expectedData[expectedDataSz];
+  memset(expectedData, 'k', expectedDataSz);
+
+  EventBase evb;
+  ReadCallback rcb(maxBufferSz);
+  AsyncSocketImmediateRead socket(&evb);
+  socket.connect(nullptr, server.getAddress(), 30);
+
+  evb.loop(); // loop until the socket is connected
+
+  socket.setReadCB(&rcb);
+  socket.setMaxReadsPerEvent(maxReadsPerEvent);
+  socket.immediateReadCalled = false;
+
+  auto acceptedSocket = server.acceptAsync(&evb);
+
+  ReadCallback rcbServer;
+  WriteCallback wcbServer;
+  rcbServer.dataAvailableCallback = [&]() {
+    if (rcbServer.dataRead() == expectedDataSz) {
+      // write back all data read
+      rcbServer.verifyData(expectedData, expectedDataSz);
+      acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
+      acceptedSocket->close();
+    }
+  };
+  acceptedSocket->setReadCB(&rcbServer);
+
+  rcb.dataAvailableCallback = [&]() {
+    // we read data and reset readCB
+    socket.setReadCB(nullptr);
+  };
+
+  // write data
+  WriteCallback wcb;
+  socket.write(&wcb, expectedData, expectedDataSz);
+  evb.loop();
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  /* we shoud've only read maxBufferSz data since readCallback_
+   * was reset in dataAvailableCallback */
+  CHECK_EQ(rcb.dataRead(), maxBufferSz);
+  CHECK_EQ(socket.immediateReadCalled, false);
+}
 
 // TODO:
 // - Test connect() and have the connect callback set the read callback