make AsyncSocket::WriteRequest an interface
authorJames Sedgwick <jsedgwick@fb.com>
Wed, 20 May 2015 15:34:26 +0000 (08:34 -0700)
committerViswanath Sivakumar <viswanath@fb.com>
Wed, 20 May 2015 17:57:12 +0000 (10:57 -0700)
Summary: This will allow a subsequent diff to implement file transfers as another type of write request

Test Plan: unit

Reviewed By: davejwatson@fb.com

Subscribers: net-systems@, folly-diffs@, yfeldblum, chalfant, fugalh, bmatheny

FB internal diff: D2080257

Signature: t1:2080257:1432044566:bcc0724d349879f46e3e58ee672aff7bf37fa5f6

folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h

index 01ed6621daec89784613477081b997f61fc2945f..f477d5da6174bb8fd0203d6e81cc5665fb60a173 100644 (file)
@@ -53,44 +53,24 @@ const AsyncSocketException socketShutdownForWritesEx(
 // the WriteRequest.
 
 /**
- * A WriteRequest object tracks information about a pending write() or writev()
- * operation.
- *
- * A new WriteRequest operation is allocated on the heap for all write
- * operations that cannot be completed immediately.
+ * A WriteRequest object tracks information about a pending write operation.
  */
 class AsyncSocket::WriteRequest {
  public:
-  static WriteRequest* newRequest(WriteCallback* callback,
-                                  const iovec* ops,
-                                  uint32_t opCount,
-                                  unique_ptr<IOBuf>&& ioBuf,
-                                  WriteFlags flags) {
-    assert(opCount > 0);
-    // Since we put a variable size iovec array at the end
-    // of each WriteRequest, we have to manually allocate the memory.
-    void* buf = malloc(sizeof(WriteRequest) +
-                       (opCount * sizeof(struct iovec)));
-    if (buf == nullptr) {
-      throw std::bad_alloc();
-    }
+  WriteRequest(AsyncSocket* socket,
+               WriteRequest* next,
+               WriteCallback* callback,
+               uint32_t totalBytesWritten) :
+    socket_(socket), next_(next), callback_(callback),
+    totalBytesWritten_(totalBytesWritten) {}
 
-    return new(buf) WriteRequest(callback, ops, opCount, std::move(ioBuf),
-                                 flags);
-  }
+  virtual void destroy() = 0;
 
-  void destroy() {
-    this->~WriteRequest();
-    free(this);
-  }
+  virtual bool performWrite() = 0;
 
-  bool cork() const {
-    return isSet(flags_, WriteFlags::CORK);
-  }
+  virtual void consume() = 0;
 
-  WriteFlags flags() const {
-    return flags_;
-  }
+  virtual bool isComplete() = 0;
 
   WriteRequest* getNext() const {
     return next_;
@@ -100,76 +80,141 @@ class AsyncSocket::WriteRequest {
     return callback_;
   }
 
-  uint32_t getBytesWritten() const {
-    return bytesWritten_;
+  uint32_t getTotalBytesWritten() const {
+    return totalBytesWritten_;
   }
 
-  const struct iovec* getOps() const {
-    assert(opCount_ > opIndex_);
-    return writeOps_ + opIndex_;
+  void append(WriteRequest* next) {
+    assert(next_ == nullptr);
+    next_ = next;
   }
 
-  uint32_t getOpCount() const {
-    assert(opCount_ > opIndex_);
-    return opCount_ - opIndex_;
+ protected:
+  // protected destructor, to ensure callers use destroy()
+  virtual ~WriteRequest() {}
+
+  AsyncSocket* socket_;         ///< parent socket
+  WriteRequest* next_;          ///< pointer to next WriteRequest
+  WriteCallback* callback_;     ///< completion callback
+  uint32_t totalBytesWritten_;  ///< total bytes written
+};
+
+/* The default WriteRequest implementation, used for write(), writev() and
+ * writeChain()
+ *
+ * A new BytesWriteRequest operation is allocated on the heap for all write
+ * operations that cannot be completed immediately.
+ */
+class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
+ public:
+  static BytesWriteRequest* newRequest(AsyncSocket* socket,
+                                       WriteCallback* callback,
+                                       const iovec* ops,
+                                       uint32_t opCount,
+                                       uint32_t partialWritten,
+                                       uint32_t bytesWritten,
+                                       unique_ptr<IOBuf>&& ioBuf,
+                                       WriteFlags flags) {
+    assert(opCount > 0);
+    // Since we put a variable size iovec array at the end
+    // of each BytesWriteRequest, we have to manually allocate the memory.
+    void* buf = malloc(sizeof(BytesWriteRequest) +
+                       (opCount * sizeof(struct iovec)));
+    if (buf == nullptr) {
+      throw std::bad_alloc();
+    }
+
+    return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
+                                      partialWritten, bytesWritten,
+                                      std::move(ioBuf), flags);
   }
 
-  void consume(uint32_t wholeOps, uint32_t partialBytes,
-               uint32_t totalBytesWritten) {
-    // Advance opIndex_ forward by wholeOps
-    opIndex_ += wholeOps;
+  void destroy() override {
+    this->~BytesWriteRequest();
+    free(this);
+  }
+
+  bool performWrite() override {
+    WriteFlags writeFlags = flags_;
+    if (getNext() != nullptr) {
+      writeFlags = writeFlags | WriteFlags::CORK;
+    }
+    bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags,
+                                          &opsWritten_, &partialBytes_);
+    return bytesWritten_ >= 0;
+  }
+
+  bool isComplete() override {
+    return opsWritten_ == getOpCount();
+  }
+
+  void consume() override {
+    // Advance opIndex_ forward by opsWritten_
+    opIndex_ += opsWritten_;
     assert(opIndex_ < opCount_);
 
     // If we've finished writing any IOBufs, release them
     if (ioBuf_) {
-      for (uint32_t i = wholeOps; i != 0; --i) {
+      for (uint32_t i = opsWritten_; i != 0; --i) {
         assert(ioBuf_);
         ioBuf_ = ioBuf_->pop();
       }
     }
 
-    // Move partialBytes forward into the current iovec buffer
+    // Move partialBytes_ forward into the current iovec buffer
     struct iovec* currentOp = writeOps_ + opIndex_;
-    assert((partialBytes < currentOp->iov_len) || (currentOp->iov_len == 0));
+    assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
     currentOp->iov_base =
-      reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes;
-    currentOp->iov_len -= partialBytes;
+      reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
+    currentOp->iov_len -= partialBytes_;
 
-    // Increment the bytesWritten_ count by totalBytesWritten
-    bytesWritten_ += totalBytesWritten;
-  }
-
-  void append(WriteRequest* next) {
-    assert(next_ == nullptr);
-    next_ = next;
+    // Increment the totalBytesWritten_ count by bytesWritten_;
+    totalBytesWritten_ += bytesWritten_;
   }
 
  private:
-  WriteRequest(WriteCallback* callback,
-               const struct iovec* ops,
-               uint32_t opCount,
-               unique_ptr<IOBuf>&& ioBuf,
-               WriteFlags flags)
-    : next_(nullptr)
-    , callback_(callback)
-    , bytesWritten_(0)
+  BytesWriteRequest(AsyncSocket* socket,
+                    WriteCallback* callback,
+                    const struct iovec* ops,
+                    uint32_t opCount,
+                    uint32_t partialBytes,
+                    uint32_t bytesWritten,
+                    unique_ptr<IOBuf>&& ioBuf,
+                    WriteFlags flags)
+    : AsyncSocket::WriteRequest(socket, nullptr, callback, 0)
     , opCount_(opCount)
     , opIndex_(0)
     , flags_(flags)
-    , ioBuf_(std::move(ioBuf)) {
+    , ioBuf_(std::move(ioBuf))
+    , opsWritten_(0)
+    , partialBytes_(partialBytes)
+    , bytesWritten_(bytesWritten) {
     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
   }
 
-  // Private destructor, to ensure callers use destroy()
-  ~WriteRequest() {}
+  // private destructor, to ensure callers use destroy()
+  virtual ~BytesWriteRequest() {}
+
+  const struct iovec* getOps() const {
+    assert(opCount_ > opIndex_);
+    return writeOps_ + opIndex_;
+  }
+
+  uint32_t getOpCount() const {
+    assert(opCount_ > opIndex_);
+    return opCount_ - opIndex_;
+  }
 
-  WriteRequest* next_;          ///< pointer to next WriteRequest
-  WriteCallback* callback_;     ///< completion callback
-  uint32_t bytesWritten_;       ///< bytes written
   uint32_t opCount_;            ///< number of entries in writeOps_
   uint32_t opIndex_;            ///< current index into writeOps_
   WriteFlags flags_;            ///< set for WriteFlags
   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
+
+  // for consume(), how much we wrote on the last write
+  uint32_t opsWritten_;         ///< complete ops written
+  uint32_t partialBytes_;       ///< partial bytes of incomplete op written
+  ssize_t bytesWritten_;        ///< bytes written altogether
+
   struct iovec writeOps_[];     ///< write operation(s) list
 };
 
@@ -687,16 +732,16 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
   // Create a new WriteRequest to add to the queue
   WriteRequest* req;
   try {
-    req = WriteRequest::newRequest(callback, vec + countWritten,
-                                   count - countWritten, std::move(ioBuf),
-                                   flags);
+    req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
+                                        count - countWritten, partialWritten,
+                                        bytesWritten, std::move(ioBuf), flags);
   } catch (const std::exception& ex) {
     // we mainly expect to catch std::bad_alloc here
     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
     return failWrite(__func__, callback, bytesWritten, tex);
   }
-  req->consume(0, partialWritten, bytesWritten);
+  req->consume();
   if (writeReqTail_ == nullptr) {
     assert(writeReqHead_ == nullptr);
     writeReqHead_ = writeReqTail_ = req;
@@ -1346,20 +1391,11 @@ void AsyncSocket::handleWrite() noexcept {
   // (See the comment in handleRead() explaining how this can happen.)
   EventBase* originalEventBase = eventBase_;
   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
-    uint32_t countWritten;
-    uint32_t partialWritten;
-    WriteFlags writeFlags = writeReqHead_->flags();
-    if (writeReqHead_->getNext() != nullptr) {
-      writeFlags = writeFlags | WriteFlags::CORK;
-    }
-    int bytesWritten = performWrite(writeReqHead_->getOps(),
-                                    writeReqHead_->getOpCount(),
-                                    writeFlags, &countWritten, &partialWritten);
-    if (bytesWritten < 0) {
+    if (!writeReqHead_->performWrite()) {
       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
                              withAddr("writev() failed"), errno);
       return failWrite(__func__, ex);
-    } else if (countWritten == writeReqHead_->getOpCount()) {
+    } else if (writeReqHead_->isComplete()) {
       // We finished this request
       WriteRequest* req = writeReqHead_;
       writeReqHead_ = req->getNext();
@@ -1424,7 +1460,7 @@ void AsyncSocket::handleWrite() noexcept {
       // We'll continue around the loop, trying to write another request
     } else {
       // Partial write.
-      writeReqHead_->consume(countWritten, partialWritten, bytesWritten);
+      writeReqHead_->consume();
       // Stop after a partial write; it's highly likely that a subsequent write
       // attempt will just return EAGAIN.
       //
@@ -1822,7 +1858,7 @@ void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
     WriteRequest* req = writeReqHead_;
     writeReqHead_ = req->getNext();
     WriteCallback* callback = req->getCallback();
-    uint32_t bytesWritten = req->getBytesWritten();
+    uint32_t bytesWritten = req->getTotalBytesWritten();
     req->destroy();
     if (callback) {
       callback->writeErr(bytesWritten, ex);
@@ -1859,7 +1895,7 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
     writeReqHead_ = req->getNext();
     WriteCallback* callback = req->getCallback();
     if (callback) {
-      callback->writeErr(req->getBytesWritten(), ex);
+      callback->writeErr(req->getTotalBytesWritten(), ex);
     }
     req->destroy();
   }
index e6209166b749e46f9007513d29ed97f8b42a248e..866c5d91284f56a409f7789c334a6ca40924b8c7 100644 (file)
@@ -517,6 +517,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   };
 
   class WriteRequest;
+  class BytesWriteRequest;
 
   class WriteTimeout : public AsyncTimeout {
    public: