From: Yang Chi Date: Wed, 11 Nov 2015 21:32:04 +0000 (-0800) Subject: Add a buffer callback to AsyncSocket X-Git-Tag: deprecate-dynamic-initializer~261 X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=7749a46977a772b1f8d310c055875a90bed3efa9;p=folly.git Add a buffer callback to AsyncSocket Summary: This is probably easier than D2612490. The idea is just to add a callback to write, writev and writeChain in AsyncSocket, so upper layer can know when data starts to buffer up Reviewed By: mzlee Differential Revision: D2623385 fb-gh-sync-id: 98d32ca83871aaa4f6c75a769b5f1bf0b5d62c3e --- diff --git a/folly/io/async/AsyncPipe.cpp b/folly/io/async/AsyncPipe.cpp index 206f4557..b4263346 100644 --- a/folly/io/async/AsyncPipe.cpp +++ b/folly/io/async/AsyncPipe.cpp @@ -148,7 +148,8 @@ void AsyncPipeWriter::write(unique_ptr buf, void AsyncPipeWriter::writeChain(folly::AsyncWriter::WriteCallback* callback, std::unique_ptr&& buf, - WriteFlags) { + WriteFlags, + BufferCallback*) { write(std::move(buf), callback); } diff --git a/folly/io/async/AsyncPipe.h b/folly/io/async/AsyncPipe.h index 40d5021a..efa659b2 100644 --- a/folly/io/async/AsyncPipe.h +++ b/folly/io/async/AsyncPipe.h @@ -148,16 +148,19 @@ class AsyncPipeWriter : public EventHandler, // AsyncWriter methods void write(folly::AsyncWriter::WriteCallback* callback, const void* buf, - size_t bytes, WriteFlags flags = WriteFlags::NONE) override { - writeChain(callback, IOBuf::wrapBuffer(buf, bytes), flags); + size_t bytes, WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override { + writeChain(callback, IOBuf::wrapBuffer(buf, bytes), flags, bufCallback); } void writev(folly::AsyncWriter::WriteCallback*, const iovec*, - size_t, WriteFlags = WriteFlags::NONE) override { + size_t, WriteFlags = WriteFlags::NONE, + BufferCallback* = nullptr) override { throw std::runtime_error("writev is not supported. Please use writeChain."); } void writeChain(folly::AsyncWriter::WriteCallback* callback, std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE) override; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override; private: void handlerReady(uint16_t events) noexcept override; diff --git a/folly/io/async/AsyncSocket.cpp b/folly/io/async/AsyncSocket.cpp index 86eba2ff..28668245 100644 --- a/folly/io/async/AsyncSocket.cpp +++ b/folly/io/async/AsyncSocket.cpp @@ -63,14 +63,16 @@ const AsyncSocketException socketShutdownForWritesEx( */ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { public: - static BytesWriteRequest* newRequest(AsyncSocket* socket, - WriteCallback* callback, - const iovec* ops, - uint32_t opCount, - uint32_t partialWritten, - uint32_t bytesWritten, - unique_ptr&& ioBuf, - WriteFlags flags) { + static BytesWriteRequest* newRequest( + AsyncSocket* socket, + WriteCallback* callback, + const iovec* ops, + uint32_t opCount, + uint32_t partialWritten, + uint32_t bytesWritten, + unique_ptr&& ioBuf, + WriteFlags flags, + BufferCallback* bufferCallback = nullptr) { assert(opCount > 0); // Since we put a variable size iovec array at the end // of each BytesWriteRequest, we have to manually allocate the memory. @@ -82,7 +84,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { return new(buf) BytesWriteRequest(socket, callback, ops, opCount, partialWritten, bytesWritten, - std::move(ioBuf), flags); + std::move(ioBuf), flags, bufferCallback); } void destroy() override { @@ -136,8 +138,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { uint32_t partialBytes, uint32_t bytesWritten, unique_ptr&& ioBuf, - WriteFlags flags) - : AsyncSocket::WriteRequest(socket, callback) + WriteFlags flags, + BufferCallback* bufferCallback = nullptr) + : AsyncSocket::WriteRequest(socket, callback, bufferCallback) , opCount_(opCount) , opIndex_(0) , flags_(flags) @@ -608,43 +611,46 @@ AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const { } void AsyncSocket::write(WriteCallback* callback, - const void* buf, size_t bytes, WriteFlags flags) { + const void* buf, size_t bytes, WriteFlags flags, + BufferCallback* bufCallback) { iovec op; op.iov_base = const_cast(buf); op.iov_len = bytes; - writeImpl(callback, &op, 1, unique_ptr(), flags); + writeImpl(callback, &op, 1, unique_ptr(), flags, bufCallback); } void AsyncSocket::writev(WriteCallback* callback, const iovec* vec, size_t count, - WriteFlags flags) { - writeImpl(callback, vec, count, unique_ptr(), flags); + WriteFlags flags, + BufferCallback* bufCallback) { + writeImpl(callback, vec, count, unique_ptr(), flags, bufCallback); } void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr&& buf, - WriteFlags flags) { + WriteFlags flags, BufferCallback* bufCallback) { constexpr size_t kSmallSizeMax = 64; size_t count = buf->countChainElements(); if (count <= kSmallSizeMax) { iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)]; - writeChainImpl(callback, vec, count, std::move(buf), flags); + writeChainImpl(callback, vec, count, std::move(buf), flags, bufCallback); } else { iovec* vec = new iovec[count]; - writeChainImpl(callback, vec, count, std::move(buf), flags); + writeChainImpl(callback, vec, count, std::move(buf), flags, bufCallback); delete[] vec; } } void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec, - size_t count, unique_ptr&& buf, WriteFlags flags) { + size_t count, unique_ptr&& buf, WriteFlags flags, + BufferCallback* bufCallback) { size_t veclen = buf->fillIov(vec, count); - writeImpl(callback, vec, veclen, std::move(buf), flags); + writeImpl(callback, vec, veclen, std::move(buf), flags, bufCallback); } void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, size_t count, unique_ptr&& buf, - WriteFlags flags) { + WriteFlags flags, BufferCallback* bufCallback) { VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_ << ", callback=" << callback << ", count=" << count << ", state=" << state_; @@ -688,7 +694,11 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, callback->writeSuccess(); } return; - } // else { continue writing the next writeReq } + } else { // continue writing the next writeReq + if (bufCallback) { + bufCallback->onEgressBuffered(); + } + } mustRegister = true; } } else if (!connecting()) { @@ -701,7 +711,8 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, try { req = BytesWriteRequest::newRequest(this, callback, vec + countWritten, count - countWritten, partialWritten, - bytesWritten, std::move(ioBuf), flags); + bytesWritten, std::move(ioBuf), flags, + bufCallback); } catch (const std::exception& ex) { // we mainly expect to catch std::bad_alloc here AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR, @@ -1473,6 +1484,11 @@ void AsyncSocket::handleWrite() noexcept { } // We'll continue around the loop, trying to write another request } else { + // Notify BufferCallback: + BufferCallback* bufferCallback = writeReqHead_->getBufferCallback(); + if (bufferCallback) { + bufferCallback->onEgressBuffered(); + } // Partial write. writeReqHead_->consume(); // Stop after a partial write; it's highly likely that a subsequent write diff --git a/folly/io/async/AsyncSocket.h b/folly/io/async/AsyncSocket.h index b7eeafc7..de1f5c23 100644 --- a/folly/io/async/AsyncSocket.h +++ b/folly/io/async/AsyncSocket.h @@ -328,12 +328,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ReadCallback* getReadCallback() const override; void write(WriteCallback* callback, const void* buf, size_t bytes, - WriteFlags flags = WriteFlags::NONE) override; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override; void writev(WriteCallback* callback, const iovec* vec, size_t count, - WriteFlags flags = WriteFlags::NONE) override; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override; void writeChain(WriteCallback* callback, std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE) override; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override; class WriteRequest; virtual void writeRequest(WriteRequest* req); @@ -507,8 +510,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper { */ class WriteRequest { public: - WriteRequest(AsyncSocket* socket, WriteCallback* callback) : - socket_(socket), callback_(callback) {} + WriteRequest( + AsyncSocket* socket, + WriteCallback* callback, + BufferCallback* bufferCallback = nullptr) : + socket_(socket), callback_(callback), bufferCallback_(bufferCallback) {} virtual void start() {}; @@ -546,6 +552,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper { socket_->appBytesWritten_ += count; } + BufferCallback* getBufferCallback() const { + return bufferCallback_; + } + protected: // protected destructor, to ensure callers use destroy() virtual ~WriteRequest() {} @@ -554,6 +564,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { WriteRequest* next_{nullptr}; ///< pointer to next WriteRequest WriteCallback* callback_; ///< completion callback uint32_t totalBytesWritten_{0}; ///< total bytes written + BufferCallback* bufferCallback_{nullptr}; }; protected: @@ -677,36 +688,39 @@ class AsyncSocket : virtual public AsyncTransportWrapper { /** * Populate an iovec array from an IOBuf and attempt to write it. * - * @param callback Write completion/error callback. - * @param vec Target iovec array; caller retains ownership. - * @param count Number of IOBufs to write, beginning at start of buf. - * @param buf Chain of iovecs. - * @param flags set of flags for the underlying write calls, like cork + * @param callback Write completion/error callback. + * @param vec Target iovec array; caller retains ownership. + * @param count Number of IOBufs to write, beginning at start of buf. + * @param buf Chain of iovecs. + * @param flags set of flags for the underlying write calls, like cork + * @param bufCallback Callback when egress data begins to buffer */ void writeChainImpl(WriteCallback* callback, iovec* vec, size_t count, std::unique_ptr&& buf, - WriteFlags flags); + WriteFlags flags, BufferCallback* bufCallback = nullptr); /** * Write as much data as possible to the socket without blocking, * and queue up any leftover data to send when the socket can * handle writes again. * - * @param callback The callback to invoke when the write is completed. - * @param vec Array of buffers to write; this method will make a - * copy of the vector (but not the buffers themselves) - * if the write has to be completed asynchronously. - * @param count Number of elements in vec. - * @param buf The IOBuf that manages the buffers referenced by - * vec, or a pointer to nullptr if the buffers are not - * associated with an IOBuf. Note that ownership of - * the IOBuf is transferred here; upon completion of - * the write, the AsyncSocket deletes the IOBuf. - * @param flags Set of write flags. + * @param callback The callback to invoke when the write is completed. + * @param vec Array of buffers to write; this method will make a + * copy of the vector (but not the buffers themselves) + * if the write has to be completed asynchronously. + * @param count Number of elements in vec. + * @param buf The IOBuf that manages the buffers referenced by + * vec, or a pointer to nullptr if the buffers are not + * associated with an IOBuf. Note that ownership of + * the IOBuf is transferred here; upon completion of + * the write, the AsyncSocket deletes the IOBuf. + * @param flags Set of write flags. + * @param bufCallback Callback when egress data buffers up */ void writeImpl(WriteCallback* callback, const iovec* vec, size_t count, std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE); + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr); /** * Attempt to write to the socket. diff --git a/folly/io/async/AsyncTransport.h b/folly/io/async/AsyncTransport.h index 13bc6c94..031b88e4 100644 --- a/folly/io/async/AsyncTransport.h +++ b/folly/io/async/AsyncTransport.h @@ -464,6 +464,12 @@ class AsyncReader { class AsyncWriter { public: + class BufferCallback { + public: + virtual ~BufferCallback() {} + virtual void onEgressBuffered() = 0; + }; + class WriteCallback { public: virtual ~WriteCallback() = default; @@ -493,12 +499,15 @@ class AsyncWriter { // Write methods that aren't part of AsyncTransport virtual void write(WriteCallback* callback, const void* buf, size_t bytes, - WriteFlags flags = WriteFlags::NONE) = 0; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) = 0; virtual void writev(WriteCallback* callback, const iovec* vec, size_t count, - WriteFlags flags = WriteFlags::NONE) = 0; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) = 0; virtual void writeChain(WriteCallback* callback, std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE) = 0; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) = 0; protected: virtual ~AsyncWriter() = default; @@ -516,15 +525,19 @@ class AsyncTransportWrapper : virtual public AsyncTransport, // to keep compatibility. using ReadCallback = AsyncReader::ReadCallback; using WriteCallback = AsyncWriter::WriteCallback; + using BufferCallback = AsyncWriter::BufferCallback; virtual void setReadCB(ReadCallback* callback) override = 0; virtual ReadCallback* getReadCallback() const override = 0; virtual void write(WriteCallback* callback, const void* buf, size_t bytes, - WriteFlags flags = WriteFlags::NONE) override = 0; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override = 0; virtual void writev(WriteCallback* callback, const iovec* vec, size_t count, - WriteFlags flags = WriteFlags::NONE) override = 0; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override = 0; virtual void writeChain(WriteCallback* callback, std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE) override = 0; + WriteFlags flags = WriteFlags::NONE, + BufferCallback* bufCallback = nullptr) override = 0; /** * The transport wrapper may wrap another transport. This returns the * transport that is wrapped. It returns nullptr if there is no wrapped diff --git a/folly/io/async/test/AsyncSocketTest.h b/folly/io/async/test/AsyncSocketTest.h index 51230014..5d52ad20 100644 --- a/folly/io/async/test/AsyncSocketTest.h +++ b/folly/io/async/test/AsyncSocketTest.h @@ -60,6 +60,23 @@ class ConnCallback : public AsyncSocket::ConnectCallback { VoidCallback errorCallback; }; +class BufferCallback : public AsyncTransportWrapper::BufferCallback { + public: + BufferCallback() + : buffered_(false) {} + + void onEgressBuffered() override { + buffered_ = true; + } + + bool hasBuffered() const { + return buffered_; + } + + private: + bool buffered_{false}; +}; + class WriteCallback : public AsyncTransportWrapper::WriteCallback { public: WriteCallback() diff --git a/folly/io/async/test/AsyncSocketTest2.cpp b/folly/io/async/test/AsyncSocketTest2.cpp index 1a5ebeb8..81acc826 100644 --- a/folly/io/async/test/AsyncSocketTest2.cpp +++ b/folly/io/async/test/AsyncSocketTest2.cpp @@ -2238,3 +2238,32 @@ TEST(AsyncSocketTest, NumPendingMessagesInQueue) { eventBase.loop(); } + +TEST(AsyncSocketTest, BufferTest) { + TestServer server; + + EventBase evb; + AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}}; + std::shared_ptr socket = AsyncSocket::newSocket(&evb); + ConnCallback ccb; + socket->connect(&ccb, server.getAddress(), 30, option); + + + char buf[100 * 1024]; + memset(buf, 'c', sizeof(buf)); + WriteCallback wcb; + BufferCallback bcb; + socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE, &bcb); + + evb.loop(); + CHECK_EQ(ccb.state, STATE_SUCCEEDED); + CHECK_EQ(wcb.state, STATE_SUCCEEDED); + + ASSERT_TRUE(bcb.hasBuffered()); + + socket->close(); + server.verifyConnection(buf, sizeof(buf)); + + ASSERT_TRUE(socket->isClosedBySelf()); + ASSERT_FALSE(socket->isClosedByPeer()); +} diff --git a/folly/io/async/test/MockAsyncTransport.h b/folly/io/async/test/MockAsyncTransport.h index 84ce7b93..9202cd9c 100644 --- a/folly/io/async/test/MockAsyncTransport.h +++ b/folly/io/async/test/MockAsyncTransport.h @@ -27,23 +27,31 @@ class MockAsyncTransport: public AsyncTransportWrapper { MOCK_METHOD1(setReadCB, void(ReadCallback*)); MOCK_CONST_METHOD0(getReadCallback, ReadCallback*()); MOCK_CONST_METHOD0(getReadCB, ReadCallback*()); - MOCK_METHOD4(write, void(WriteCallback*, + MOCK_METHOD5(write, void(WriteCallback*, const void*, size_t, - WriteFlags)); - MOCK_METHOD4(writev, void(WriteCallback*, + WriteFlags, + BufferCallback*)); + MOCK_METHOD5(writev, void(WriteCallback*, const iovec*, size_t, - WriteFlags)); - MOCK_METHOD3(writeChain, + WriteFlags, + BufferCallback*)); + MOCK_METHOD4(writeChain, void(WriteCallback*, std::shared_ptr, - WriteFlags)); + WriteFlags, + BufferCallback*)); void writeChain(WriteCallback* callback, std::unique_ptr&& iob, WriteFlags flags = - WriteFlags::NONE) override { - writeChain(callback, std::shared_ptr(iob.release()), flags); + WriteFlags::NONE, + BufferCallback* bufCB = nullptr) override { + writeChain( + callback, + std::shared_ptr(iob.release()), + flags, + bufCB); } MOCK_METHOD0(close, void());