Return if we handle any error messages to avoid unnecessarily calling recv/send
[folly.git] / folly / io / async / AsyncSocket.h
index e99300fb238a6ae491f453bdbc886b405037adda..c85ac6b87bd51392d9a667cffadbb8c5fe886bff 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2017-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #pragma once
 
 #include <folly/Optional.h>
@@ -222,7 +221,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    */
   explicit AsyncSocket(EventBase* evb);
 
-  void setShutdownSocketSet(ShutdownSocketSet* ss);
+  void setShutdownSocketSet(const std::weak_ptr<ShutdownSocketSet>& wSS);
 
   /**
    * Create a new AsyncSocket and begin the connection process.
@@ -261,8 +260,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    *
    * @param evb EventBase that will manage this socket.
    * @param fd  File descriptor to take over (should be a connected socket).
+   * @param zeroCopyBufId Zerocopy buf id to start with.
    */
-  AsyncSocket(EventBase* evb, int fd);
+  AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0);
 
   /**
    * Create an AsyncSocket from a different, already connected AsyncSocket.
@@ -504,21 +504,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void setReadCB(ReadCallback* callback) override;
   ReadCallback* getReadCallback() const override;
 
-  static const size_t kDefaultZeroCopyThreshold = 32768; // 32KB
-
   bool setZeroCopy(bool enable);
   bool getZeroCopy() const {
     return zeroCopyEnabled_;
   }
 
-  void setZeroCopyWriteChainThreshold(size_t threshold);
-  size_t getZeroCopyWriteChainThreshold() const {
-    return zeroCopyWriteChainThreshold_;
+  uint32_t getZeroCopyBufId() const {
+    return zeroCopyBufId_;
   }
 
-  bool isZeroCopyMsg(const cmsghdr& cmsg) const;
-  void processZeroCopyMsg(const cmsghdr& cmsg);
-
   void write(WriteCallback* callback, const void* buf, size_t bytes,
              WriteFlags flags = WriteFlags::NONE) override;
   void writev(WriteCallback* callback, const iovec* vec, size_t count,
@@ -808,6 +802,18 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    */
   void cacheAddresses();
 
+  /**
+   * Returns true if there is any zero copy write in progress
+   * Needs to be called from within the socket's EVB thread
+   */
+  bool isZeroCopyWriteInProgress() const noexcept;
+
+  /**
+   * Tries to process the msg error queue
+   * And returns true if there are no more zero copy writes in progress
+   */
+  bool processZeroCopyWriteInProgress() noexcept;
+
   /**
    * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
    * If no data has been written, 0 is returned.
@@ -1016,7 +1022,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   virtual void checkForImmediateRead() noexcept;
   virtual void handleInitialReadWrite() noexcept;
   virtual void prepareReadBuffer(void** buf, size_t* buflen);
-  virtual void handleErrMessages() noexcept;
+  virtual size_t handleErrMessages() noexcept;
   virtual void handleRead() noexcept;
   virtual void handleWrite() noexcept;
   virtual void handleConnect() noexcept;
@@ -1149,30 +1155,33 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void cachePeerAddress() const;
 
   bool isZeroCopyRequest(WriteFlags flags);
-  uint32_t getNextZeroCopyBuffId() {
-    return zeroCopyBuffId_++;
+
+  bool isZeroCopyMsg(const cmsghdr& cmsg) const;
+  void processZeroCopyMsg(const cmsghdr& cmsg);
+
+  uint32_t getNextZeroCopyBufId() {
+    return zeroCopyBufId_++;
   }
-  void adjustZeroCopyFlags(folly::IOBuf* buf, folly::WriteFlags& flags);
-  void adjustZeroCopyFlags(
-      const iovec* vec,
-      uint32_t count,
-      folly::WriteFlags& flags);
-  void addZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf);
-  void addZeroCopyBuff(folly::IOBuf* ptr);
-  void setZeroCopyBuff(std::unique_ptr<folly::IOBuf>&& buf);
-  bool containsZeroCopyBuff(folly::IOBuf* ptr);
-  void releaseZeroCopyBuff(uint32_t id);
+  void adjustZeroCopyFlags(folly::WriteFlags& flags);
+  void addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf);
+  void addZeroCopyBuf(folly::IOBuf* ptr);
+  void setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf);
+  bool containsZeroCopyBuf(folly::IOBuf* ptr);
+  void releaseZeroCopyBuf(uint32_t id);
 
   // a folly::IOBuf can be used in multiple partial requests
-  // so we keep a map that maps a buffer id to a raw folly::IOBuf ptr
-  // and one more map that adds a ref count for a folly::IOBuf that is either
+  // there is a that maps a buffer id to a raw folly::IOBuf ptr
+  // and another one that adds a ref count for a folly::IOBuf that is either
   // the original ptr or nullptr
-  uint32_t zeroCopyBuffId_{0};
+  uint32_t zeroCopyBufId_{0};
+
+  struct IOBufInfo {
+    uint32_t count_{0};
+    std::unique_ptr<folly::IOBuf> buf_;
+  };
+
   std::unordered_map<uint32_t, folly::IOBuf*> idZeroCopyBufPtrMap_;
-  std::unordered_map<
-      folly::IOBuf*,
-      std::pair<uint32_t, std::unique_ptr<folly::IOBuf>>>
-      idZeroCopyBufPtrToBufMap_;
+  std::unordered_map<folly::IOBuf*, IOBufInfo> idZeroCopyBufInfoMap_;
 
   StateEnum state_;                      ///< StateEnum describing current state
   uint8_t shutdownFlags_;                ///< Shutdown state (ShutdownFlags)
@@ -1183,6 +1192,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
                                          ///< The address we are connecting from
   uint32_t sendTimeout_;                 ///< The send timeout, in milliseconds
   uint16_t maxReadsPerEvent_;            ///< Max reads per event loop iteration
+
+  bool isBufferMovable_{false};
+
+  int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any
+
   EventBase* eventBase_;                 ///< The EventBase
   WriteTimeout writeTimeout_;            ///< A timeout for connect and write
   IoHandler ioHandler_;                  ///< A EventHandler to monitor the fd
@@ -1195,22 +1209,21 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   ReadCallback* readCallback_;           ///< ReadCallback
   WriteRequest* writeReqHead_;           ///< Chain of WriteRequests
   WriteRequest* writeReqTail_;           ///< End of WriteRequest chain
-  ShutdownSocketSet* shutdownSocketSet_;
+  std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_;
   size_t appBytesReceived_;              ///< Num of bytes received from socket
   size_t appBytesWritten_;               ///< Num of bytes written to socket
-  bool isBufferMovable_{false};
 
   // Pre-received data, to be returned to read callback before any data from the
   // socket.
   std::unique_ptr<IOBuf> preReceivedData_;
 
-  int8_t readErr_{READ_NO_ERROR};        ///< The read error encountered, if any
-
   std::chrono::steady_clock::time_point connectStartTime_;
   std::chrono::steady_clock::time_point connectEndTime_;
 
   std::chrono::milliseconds connectTimeout_{0};
 
+  std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
+
   BufferCallback* bufferCallback_{nullptr};
   bool tfoEnabled_{false};
   bool tfoAttempted_{false};
@@ -1221,9 +1234,6 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   bool trackEor_{false};
   bool zeroCopyEnabled_{false};
   bool zeroCopyVal_{false};
-  size_t zeroCopyWriteChainThreshold_{kDefaultZeroCopyThreshold};
-
-  std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
 };
 #ifdef _MSC_VER
 #pragma vtordisp(pop)