Add a buffer callback to AsyncSocket
authorYang Chi <yangchi@fb.com>
Wed, 11 Nov 2015 21:32:04 +0000 (13:32 -0800)
committerfacebook-github-bot-9 <folly-bot@fb.com>
Wed, 11 Nov 2015 22:20:22 +0000 (14:20 -0800)
Summary: This is probably easier than D2612490. The idea is just to add a callback to write, writev and writeChain in AsyncSocket, so upper layer can know when data starts to buffer up

Reviewed By: mzlee

Differential Revision: D2623385

fb-gh-sync-id: 98d32ca83871aaa4f6c75a769b5f1bf0b5d62c3e

folly/io/async/AsyncPipe.cpp
folly/io/async/AsyncPipe.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/AsyncTransport.h
folly/io/async/test/AsyncSocketTest.h
folly/io/async/test/AsyncSocketTest2.cpp
folly/io/async/test/MockAsyncTransport.h

index 206f455765fab3d2e366901196ee0d1e9d1dccf8..b4263346917849e7ec7d8b9b3dc57a190849543c 100644 (file)
@@ -148,7 +148,8 @@ void AsyncPipeWriter::write(unique_ptr<folly::IOBuf> buf,
 
 void AsyncPipeWriter::writeChain(folly::AsyncWriter::WriteCallback* callback,
                                  std::unique_ptr<folly::IOBuf>&& buf,
-                                 WriteFlags) {
+                                 WriteFlags,
+                                 BufferCallback*) {
   write(std::move(buf), callback);
 }
 
index 40d5021a19aa92286ffa9784f751215a4c29453d..efa659b24687a524536ed9ca30925a896d579950 100644 (file)
@@ -148,16 +148,19 @@ class AsyncPipeWriter : public EventHandler,
 
   // AsyncWriter methods
   void write(folly::AsyncWriter::WriteCallback* callback, const void* buf,
-             size_t bytes, WriteFlags flags = WriteFlags::NONE) override {
-    writeChain(callback, IOBuf::wrapBuffer(buf, bytes), flags);
+             size_t bytes, WriteFlags flags = WriteFlags::NONE,
+             BufferCallback* bufCallback = nullptr) override {
+    writeChain(callback, IOBuf::wrapBuffer(buf, bytes), flags, bufCallback);
   }
   void writev(folly::AsyncWriter::WriteCallback*, const iovec*,
-              size_t, WriteFlags = WriteFlags::NONE) override {
+              size_t, WriteFlags = WriteFlags::NONE,
+              BufferCallback* = nullptr) override {
     throw std::runtime_error("writev is not supported. Please use writeChain.");
   }
   void writeChain(folly::AsyncWriter::WriteCallback* callback,
                   std::unique_ptr<folly::IOBuf>&& buf,
-                  WriteFlags flags = WriteFlags::NONE) override;
+                  WriteFlags flags = WriteFlags::NONE,
+                  BufferCallback* bufCallback = nullptr) override;
 
  private:
   void handlerReady(uint16_t events) noexcept override;
index 86eba2ff39d1153a7322b4ea90c7c2d8dc52d218..2866824535b1b2f07dfd8d151e007bb74a2f18b6 100644 (file)
@@ -63,14 +63,16 @@ const AsyncSocketException socketShutdownForWritesEx(
  */
 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
  public:
-  static BytesWriteRequest* newRequest(AsyncSocket* socket,
-                                       WriteCallback* callback,
-                                       const iovec* ops,
-                                       uint32_t opCount,
-                                       uint32_t partialWritten,
-                                       uint32_t bytesWritten,
-                                       unique_ptr<IOBuf>&& ioBuf,
-                                       WriteFlags flags) {
+  static BytesWriteRequest* newRequest(
+      AsyncSocket* socket,
+      WriteCallback* callback,
+      const iovec* ops,
+      uint32_t opCount,
+      uint32_t partialWritten,
+      uint32_t bytesWritten,
+      unique_ptr<IOBuf>&& ioBuf,
+      WriteFlags flags,
+      BufferCallback* bufferCallback = nullptr) {
     assert(opCount > 0);
     // Since we put a variable size iovec array at the end
     // of each BytesWriteRequest, we have to manually allocate the memory.
@@ -82,7 +84,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
 
     return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
                                       partialWritten, bytesWritten,
-                                      std::move(ioBuf), flags);
+                                      std::move(ioBuf), flags, bufferCallback);
   }
 
   void destroy() override {
@@ -136,8 +138,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
                     uint32_t partialBytes,
                     uint32_t bytesWritten,
                     unique_ptr<IOBuf>&& ioBuf,
-                    WriteFlags flags)
-    : AsyncSocket::WriteRequest(socket, callback)
+                    WriteFlags flags,
+                    BufferCallback* bufferCallback = nullptr)
+    : AsyncSocket::WriteRequest(socket, callback, bufferCallback)
     , opCount_(opCount)
     , opIndex_(0)
     , flags_(flags)
@@ -608,43 +611,46 @@ AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
 }
 
 void AsyncSocket::write(WriteCallback* callback,
-                         const void* buf, size_t bytes, WriteFlags flags) {
+                         const void* buf, size_t bytes, WriteFlags flags,
+                         BufferCallback* bufCallback) {
   iovec op;
   op.iov_base = const_cast<void*>(buf);
   op.iov_len = bytes;
-  writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
+  writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags, bufCallback);
 }
 
 void AsyncSocket::writev(WriteCallback* callback,
                           const iovec* vec,
                           size_t count,
-                          WriteFlags flags) {
-  writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
+                          WriteFlags flags,
+                          BufferCallback* bufCallback) {
+  writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags, bufCallback);
 }
 
 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
-                              WriteFlags flags) {
+                              WriteFlags flags, BufferCallback* bufCallback) {
   constexpr size_t kSmallSizeMax = 64;
   size_t count = buf->countChainElements();
   if (count <= kSmallSizeMax) {
     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
-    writeChainImpl(callback, vec, count, std::move(buf), flags);
+    writeChainImpl(callback, vec, count, std::move(buf), flags, bufCallback);
   } else {
     iovec* vec = new iovec[count];
-    writeChainImpl(callback, vec, count, std::move(buf), flags);
+    writeChainImpl(callback, vec, count, std::move(buf), flags, bufCallback);
     delete[] vec;
   }
 }
 
 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
-    size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
+    size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags,
+    BufferCallback* bufCallback) {
   size_t veclen = buf->fillIov(vec, count);
-  writeImpl(callback, vec, veclen, std::move(buf), flags);
+  writeImpl(callback, vec, veclen, std::move(buf), flags, bufCallback);
 }
 
 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
                              size_t count, unique_ptr<IOBuf>&& buf,
-                             WriteFlags flags) {
+                             WriteFlags flags, BufferCallback* bufCallback) {
   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
           << ", callback=" << callback << ", count=" << count
           << ", state=" << state_;
@@ -688,7 +694,11 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
           callback->writeSuccess();
         }
         return;
-      } // else { continue writing the next writeReq }
+      } else { // continue writing the next writeReq
+        if (bufCallback) {
+          bufCallback->onEgressBuffered();
+        }
+      }
       mustRegister = true;
     }
   } else if (!connecting()) {
@@ -701,7 +711,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
   try {
     req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
                                         count - countWritten, partialWritten,
-                                        bytesWritten, std::move(ioBuf), flags);
+                                        bytesWritten, std::move(ioBuf), flags,
+                                        bufCallback);
   } catch (const std::exception& ex) {
     // we mainly expect to catch std::bad_alloc here
     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
@@ -1473,6 +1484,11 @@ void AsyncSocket::handleWrite() noexcept {
       }
       // We'll continue around the loop, trying to write another request
     } else {
+      // Notify BufferCallback:
+      BufferCallback* bufferCallback = writeReqHead_->getBufferCallback();
+      if (bufferCallback) {
+        bufferCallback->onEgressBuffered();
+      }
       // Partial write.
       writeReqHead_->consume();
       // Stop after a partial write; it's highly likely that a subsequent write
index b7eeafc7eca21b0d2144283c5de974876ac20078..de1f5c23bfc22b3a006fdca55d770a225887f9bc 100644 (file)
@@ -328,12 +328,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   ReadCallback* getReadCallback() const override;
 
   void write(WriteCallback* callback, const void* buf, size_t bytes,
-             WriteFlags flags = WriteFlags::NONE) override;
+             WriteFlags flags = WriteFlags::NONE,
+             BufferCallback* bufCallback = nullptr) override;
   void writev(WriteCallback* callback, const iovec* vec, size_t count,
-              WriteFlags flags = WriteFlags::NONE) override;
+              WriteFlags flags = WriteFlags::NONE,
+              BufferCallback* bufCallback = nullptr) override;
   void writeChain(WriteCallback* callback,
                   std::unique_ptr<folly::IOBuf>&& buf,
-                  WriteFlags flags = WriteFlags::NONE) override;
+                  WriteFlags flags = WriteFlags::NONE,
+                  BufferCallback* bufCallback = nullptr) override;
 
   class WriteRequest;
   virtual void writeRequest(WriteRequest* req);
@@ -507,8 +510,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    */
   class WriteRequest {
    public:
-    WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
-      socket_(socket), callback_(callback) {}
+    WriteRequest(
+        AsyncSocket* socket,
+        WriteCallback* callback,
+        BufferCallback* bufferCallback = nullptr) :
+      socket_(socket), callback_(callback), bufferCallback_(bufferCallback) {}
 
     virtual void start() {};
 
@@ -546,6 +552,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       socket_->appBytesWritten_ += count;
     }
 
+    BufferCallback* getBufferCallback() const {
+      return bufferCallback_;
+    }
+
    protected:
     // protected destructor, to ensure callers use destroy()
     virtual ~WriteRequest() {}
@@ -554,6 +564,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     WriteRequest* next_{nullptr};          ///< pointer to next WriteRequest
     WriteCallback* callback_;     ///< completion callback
     uint32_t totalBytesWritten_{0};  ///< total bytes written
+    BufferCallback* bufferCallback_{nullptr};
   };
 
  protected:
@@ -677,36 +688,39 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   /**
    * Populate an iovec array from an IOBuf and attempt to write it.
    *
-   * @param callback Write completion/error callback.
-   * @param vec      Target iovec array; caller retains ownership.
-   * @param count    Number of IOBufs to write, beginning at start of buf.
-   * @param buf      Chain of iovecs.
-   * @param flags    set of flags for the underlying write calls, like cork
+   * @param callback    Write completion/error callback.
+   * @param vec         Target iovec array; caller retains ownership.
+   * @param count       Number of IOBufs to write, beginning at start of buf.
+   * @param buf         Chain of iovecs.
+   * @param flags       set of flags for the underlying write calls, like cork
+   * @param bufCallback Callback when egress data begins to buffer
    */
   void writeChainImpl(WriteCallback* callback, iovec* vec,
                       size_t count, std::unique_ptr<folly::IOBuf>&& buf,
-                      WriteFlags flags);
+                      WriteFlags flags, BufferCallback* bufCallback = nullptr);
 
   /**
    * Write as much data as possible to the socket without blocking,
    * and queue up any leftover data to send when the socket can
    * handle writes again.
    *
-   * @param callback The callback to invoke when the write is completed.
-   * @param vec      Array of buffers to write; this method will make a
-   *                 copy of the vector (but not the buffers themselves)
-   *                 if the write has to be completed asynchronously.
-   * @param count    Number of elements in vec.
-   * @param buf      The IOBuf that manages the buffers referenced by
-   *                 vec, or a pointer to nullptr if the buffers are not
-   *                 associated with an IOBuf.  Note that ownership of
-   *                 the IOBuf is transferred here; upon completion of
-   *                 the write, the AsyncSocket deletes the IOBuf.
-   * @param flags    Set of write flags.
+   * @param callback    The callback to invoke when the write is completed.
+   * @param vec         Array of buffers to write; this method will make a
+   *                    copy of the vector (but not the buffers themselves)
+   *                    if the write has to be completed asynchronously.
+   * @param count       Number of elements in vec.
+   * @param buf         The IOBuf that manages the buffers referenced by
+   *                    vec, or a pointer to nullptr if the buffers are not
+   *                    associated with an IOBuf.  Note that ownership of
+   *                    the IOBuf is transferred here; upon completion of
+   *                    the write, the AsyncSocket deletes the IOBuf.
+   * @param flags       Set of write flags.
+   * @param bufCallback Callback when egress data buffers up
    */
   void writeImpl(WriteCallback* callback, const iovec* vec, size_t count,
                  std::unique_ptr<folly::IOBuf>&& buf,
-                 WriteFlags flags = WriteFlags::NONE);
+                 WriteFlags flags = WriteFlags::NONE,
+                 BufferCallback* bufCallback = nullptr);
 
   /**
    * Attempt to write to the socket.
index 13bc6c94625cb317fe90656fd6e6b53382ce58f1..031b88e41e733b06d4f72e1b9198c66328683da2 100644 (file)
@@ -464,6 +464,12 @@ class AsyncReader {
 
 class AsyncWriter {
  public:
+  class BufferCallback {
+   public:
+    virtual ~BufferCallback() {}
+    virtual void onEgressBuffered() = 0;
+  };
+
   class WriteCallback {
    public:
     virtual ~WriteCallback() = default;
@@ -493,12 +499,15 @@ class AsyncWriter {
 
   // Write methods that aren't part of AsyncTransport
   virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
-                     WriteFlags flags = WriteFlags::NONE) = 0;
+                     WriteFlags flags = WriteFlags::NONE,
+                     BufferCallback* bufCallback = nullptr) = 0;
   virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
-                      WriteFlags flags = WriteFlags::NONE) = 0;
+                      WriteFlags flags = WriteFlags::NONE,
+                      BufferCallback* bufCallback = nullptr) = 0;
   virtual void writeChain(WriteCallback* callback,
                           std::unique_ptr<IOBuf>&& buf,
-                          WriteFlags flags = WriteFlags::NONE) = 0;
+                          WriteFlags flags = WriteFlags::NONE,
+                          BufferCallback* bufCallback = nullptr) = 0;
 
  protected:
   virtual ~AsyncWriter() = default;
@@ -516,15 +525,19 @@ class AsyncTransportWrapper : virtual public AsyncTransport,
   // to keep compatibility.
   using ReadCallback    = AsyncReader::ReadCallback;
   using WriteCallback   = AsyncWriter::WriteCallback;
+  using BufferCallback  = AsyncWriter::BufferCallback;
   virtual void setReadCB(ReadCallback* callback) override = 0;
   virtual ReadCallback* getReadCallback() const override = 0;
   virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
-                     WriteFlags flags = WriteFlags::NONE) override = 0;
+                     WriteFlags flags = WriteFlags::NONE,
+                     BufferCallback* bufCallback = nullptr) override = 0;
   virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
-                      WriteFlags flags = WriteFlags::NONE) override = 0;
+                      WriteFlags flags = WriteFlags::NONE,
+                      BufferCallback* bufCallback = nullptr) override = 0;
   virtual void writeChain(WriteCallback* callback,
                           std::unique_ptr<IOBuf>&& buf,
-                          WriteFlags flags = WriteFlags::NONE) override = 0;
+                          WriteFlags flags = WriteFlags::NONE,
+                          BufferCallback* bufCallback = nullptr) override = 0;
   /**
    * The transport wrapper may wrap another transport. This returns the
    * transport that is wrapped. It returns nullptr if there is no wrapped
index 51230014203c038cda8fabb1872c61a1b6623ff5..5d52ad204f68faa2859d91ec4c4c8bb42b58126d 100644 (file)
@@ -60,6 +60,23 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
   VoidCallback errorCallback;
 };
 
+class BufferCallback : public AsyncTransportWrapper::BufferCallback {
+ public:
+  BufferCallback()
+    : buffered_(false) {}
+
+  void onEgressBuffered() override {
+    buffered_ = true;
+  }
+
+  bool hasBuffered() const {
+    return buffered_;
+  }
+
+ private:
+  bool buffered_{false};
+};
+
 class WriteCallback : public AsyncTransportWrapper::WriteCallback {
  public:
   WriteCallback()
index 1a5ebeb84e1d79f9d5a18f62161e90b7a6cd1495..81acc8265a1e74f07d96185ffe9a19b4130628e1 100644 (file)
@@ -2238,3 +2238,32 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
 
   eventBase.loop();
 }
+
+TEST(AsyncSocketTest, BufferTest) {
+  TestServer server;
+
+  EventBase evb;
+  AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
+  std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
+  ConnCallback ccb;
+  socket->connect(&ccb, server.getAddress(), 30, option);
+
+
+  char buf[100 * 1024];
+  memset(buf, 'c', sizeof(buf));
+  WriteCallback wcb;
+  BufferCallback bcb;
+  socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE, &bcb);
+
+  evb.loop();
+  CHECK_EQ(ccb.state, STATE_SUCCEEDED);
+  CHECK_EQ(wcb.state, STATE_SUCCEEDED);
+
+  ASSERT_TRUE(bcb.hasBuffered());
+
+  socket->close();
+  server.verifyConnection(buf, sizeof(buf));
+
+  ASSERT_TRUE(socket->isClosedBySelf());
+  ASSERT_FALSE(socket->isClosedByPeer());
+}
index 84ce7b936ab851d628fc04ed44fe09f001fdd218..9202cd9ceea46836e9110cafada4634d8500b514 100644 (file)
@@ -27,23 +27,31 @@ class MockAsyncTransport: public AsyncTransportWrapper {
   MOCK_METHOD1(setReadCB, void(ReadCallback*));
   MOCK_CONST_METHOD0(getReadCallback, ReadCallback*());
   MOCK_CONST_METHOD0(getReadCB, ReadCallback*());
-  MOCK_METHOD4(write, void(WriteCallback*,
+  MOCK_METHOD5(write, void(WriteCallback*,
                            const void*, size_t,
-                           WriteFlags));
-  MOCK_METHOD4(writev, void(WriteCallback*,
+                           WriteFlags,
+                           BufferCallback*));
+  MOCK_METHOD5(writev, void(WriteCallback*,
                             const iovec*, size_t,
-                            WriteFlags));
-  MOCK_METHOD3(writeChain,
+                            WriteFlags,
+                            BufferCallback*));
+  MOCK_METHOD4(writeChain,
                void(WriteCallback*,
                     std::shared_ptr<folly::IOBuf>,
-                    WriteFlags));
+                    WriteFlags,
+                    BufferCallback*));
 
 
   void writeChain(WriteCallback* callback,
                   std::unique_ptr<folly::IOBuf>&& iob,
                   WriteFlags flags =
-                  WriteFlags::NONE) override {
-    writeChain(callback, std::shared_ptr<folly::IOBuf>(iob.release()), flags);
+                  WriteFlags::NONE,
+                  BufferCallback* bufCB = nullptr) override {
+    writeChain(
+        callback,
+        std::shared_ptr<folly::IOBuf>(iob.release()),
+        flags,
+        bufCB);
   }
 
   MOCK_METHOD0(close, void());