Enable EOR flag configuration for folly::AsyncSocket.
[folly.git] / folly / io / async / AsyncSocket.h
index 38939ee786e908dbc11dadcd453292cb84e32ba7..5aeb159c65fb635b8bc872688560ed6712348832 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #pragma once
 
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <glog/logging.h>
+#include <folly/Optional.h>
 #include <folly/SocketAddress.h>
-#include <folly/io/ShutdownSocketSet.h>
+#include <folly/detail/SocketFastOpen.h>
 #include <folly/io/IOBuf.h>
-#include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/ShutdownSocketSet.h>
 #include <folly/io/async/AsyncSocketException.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/AsyncTransport.h>
-#include <folly/io/async/EventHandler.h>
 #include <folly/io/async/DelayedDestruction.h>
+#include <folly/io/async/EventHandler.h>
+#include <folly/portability/Sockets.h>
+
+#include <sys/types.h>
 
+#include <chrono>
 #include <memory>
 #include <map>
 
@@ -61,6 +64,17 @@ namespace folly {
  * responding and no further progress can be made sending the data.
  */
 
+#if defined __linux__ && !defined SO_NO_TRANSPARENT_TLS
+#define SO_NO_TRANSPARENT_TLS 200
+#endif
+
+#ifdef _MSC_VER
+// We do a dynamic_cast on this, in
+// AsyncTransportWrapper::getUnderlyingTransport so be safe and
+// force displacements for it. See:
+// https://msdn.microsoft.com/en-us/library/7sf3txa8.aspx
+#pragma vtordisp(push, 2)
+#endif
 class AsyncSocket : virtual public AsyncTransportWrapper {
  public:
   typedef std::unique_ptr<AsyncSocket, Destructor> UniquePtr;
@@ -84,6 +98,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       noexcept = 0;
   };
 
+  class EvbChangeCallback {
+   public:
+    virtual ~EvbChangeCallback() = default;
+
+    // Called when the socket has been attached to a new EVB
+    // and is called from within that EVB thread
+    virtual void evbAttached(AsyncSocket* socket) = 0;
+
+    // Called when the socket is detached from an EVB and
+    // is called from the EVB thread being detached
+    virtual void evbDetached(AsyncSocket* socket) = 0;
+  };
+
   explicit AsyncSocket();
   /**
    * Create a new unconnected AsyncSocket.
@@ -250,15 +277,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    *                  does not succeed within this period,
    *                  callback->connectError() will be invoked.
    */
-  virtual void connect(ConnectCallback* callback,
-               const folly::SocketAddress& address,
-               int timeout = 0,
-               const OptionMap &options = emptyOptionMap,
-               const folly::SocketAddress& bindAddr = anyAddress()
-               ) noexcept;
-  void connect(ConnectCallback* callback, const std::string& ip, uint16_t port,
-               int timeout = 00,
-               const OptionMap &options = emptyOptionMap) noexcept;
+  virtual void connect(
+      ConnectCallback* callback,
+      const folly::SocketAddress& address,
+      int timeout = 0,
+      const OptionMap& options = emptyOptionMap,
+      const folly::SocketAddress& bindAddr = anyAddress()) noexcept;
+
+  void connect(
+      ConnectCallback* callback,
+      const std::string& ip,
+      uint16_t port,
+      int timeout = 0,
+      const OptionMap& options = emptyOptionMap) noexcept;
 
   /**
    * If a connect request is in-flight, cancels it and closes the socket
@@ -361,14 +392,28 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void getPeerAddress(
     folly::SocketAddress* address) const override;
 
-  bool isEorTrackingEnabled() const override { return false; }
+  bool isEorTrackingEnabled() const override {
+    return trackEor_;
+  }
 
-  void setEorTracking(bool /*track*/) override {}
+  void setEorTracking(bool track) override {
+    trackEor_ = track;
+  }
 
   bool connecting() const override {
     return (state_ == StateEnum::CONNECTING);
   }
 
+  virtual bool isClosedByPeer() const {
+    return (state_ == StateEnum::CLOSED &&
+            (readErr_ == READ_EOF || readErr_ == READ_ERROR));
+  }
+
+  virtual bool isClosedBySelf() const {
+    return (state_ == StateEnum::CLOSED &&
+            (readErr_ != READ_EOF && readErr_ != READ_ERROR));
+  }
+
   size_t getAppBytesWritten() const override {
     return appBytesWritten_;
   }
@@ -385,6 +430,42 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return getAppBytesReceived();
   }
 
+  std::chrono::nanoseconds getConnectTime() const {
+    return connectEndTime_ - connectStartTime_;
+  }
+
+  std::chrono::milliseconds getConnectTimeout() const {
+    return connectTimeout_;
+  }
+
+  bool getTFOAttempted() const {
+    return tfoAttempted_;
+  }
+
+  /**
+   * Returns whether or not the attempt to use TFO
+   * finished successfully. This does not necessarily
+   * mean TFO worked, just that trying to use TFO
+   * succeeded.
+   */
+  bool getTFOFinished() const {
+    return tfoFinished_;
+  }
+
+  /**
+   * Returns whether or not TFO attempt succeded on this
+   * connection.
+   * For servers this is pretty straightforward API and can
+   * be invoked right after the connection is accepted. This API
+   * will perform one syscall.
+   * This API is a bit tricky to use for clients, since clients
+   * only know this for sure after the SYN-ACK is returned. So it's
+   * appropriate to call this only after the first application
+   * data is read from the socket when the caller knows that
+   * the SYN has been ACKed by the server.
+   */
+  bool getTFOSucceded() const;
+
   // Methods controlling socket options
 
   /**
@@ -444,7 +525,6 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   #define SO_SET_NAMESPACE        41
   int setTCPProfile(int profd);
 
-
   /**
    * Generic API for reading a socket option.
    *
@@ -475,12 +555,75 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return setsockopt(fd_, level, optname, optval, sizeof(T));
   }
 
+  virtual void setPeek(bool peek) {
+    peek_ = peek;
+  }
+
+  /**
+   * Enables TFO behavior on the AsyncSocket if FOLLY_ALLOW_TFO
+   * is set.
+   */
+  void enableTFO() {
+    // No-op if folly does not allow tfo
+#if FOLLY_ALLOW_TFO
+    tfoEnabled_ = true;
+#endif
+  }
+
+  void disableTransparentTls() {
+    noTransparentTls_ = true;
+  }
+
   enum class StateEnum : uint8_t {
     UNINIT,
     CONNECTING,
     ESTABLISHED,
     CLOSED,
-    ERROR
+    ERROR,
+    FAST_OPEN,
+  };
+
+  void setBufferCallback(BufferCallback* cb);
+
+  // Callers should set this prior to connecting the socket for the safest
+  // behavior.
+  void setEvbChangedCallback(std::unique_ptr<EvbChangeCallback> cb) {
+    evbChangeCb_ = std::move(cb);
+  }
+
+  /**
+   * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
+   * If no data has been written, 0 is returned.
+   * exception is a more specific exception that cause a write error.
+   * Not all writes have exceptions associated with them thus writeReturn
+   * should be checked to determine whether the operation resulted in an error.
+   */
+  struct WriteResult {
+    explicit WriteResult(ssize_t ret) : writeReturn(ret) {}
+
+    WriteResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
+        : writeReturn(ret), exception(std::move(e)) {}
+
+    ssize_t writeReturn;
+    std::unique_ptr<const AsyncSocketException> exception;
+  };
+
+  /**
+   * readReturn is the number of bytes read, or READ_EOF on EOF, or
+   * READ_ERROR on error, or READ_BLOCKING if the operation will
+   * block.
+   * exception is a more specific exception that may have caused a read error.
+   * Not all read errors have exceptions associated with them thus readReturn
+   * should be checked to determine whether the operation resulted in an error.
+   */
+  struct ReadResult {
+    explicit ReadResult(ssize_t ret) : readReturn(ret) {}
+
+    ReadResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
+        : readReturn(ret), exception(std::move(e)) {}
+
+    ssize_t readReturn;
+    std::unique_ptr<const AsyncSocketException> exception;
   };
 
   /**
@@ -491,11 +634,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
       socket_(socket), callback_(callback) {}
 
-    virtual void start() {};
+    virtual void start() {}
 
     virtual void destroy() = 0;
 
-    virtual bool performWrite() = 0;
+    virtual WriteResult performWrite() = 0;
 
     virtual void consume() = 0;
 
@@ -523,7 +666,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     }
 
     void bytesWritten(size_t count) {
-      totalBytesWritten_ += count;
+      totalBytesWritten_ += uint32_t(count);
       socket_->appBytesWritten_ += count;
     }
 
@@ -542,6 +685,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     READ_EOF = 0,
     READ_ERROR = -1,
     READ_BLOCKING = -2,
+    READ_NO_ERROR = -3,
+  };
+
+  enum WriteResultEnum {
+    WRITE_ERROR = -1,
   };
 
   /**
@@ -632,11 +780,25 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     }
   }
 
+  /**
+   * Schedule handleInitalReadWrite to run in the next iteration.
+   */
+  void scheduleInitialReadWrite() noexcept {
+    if (good()) {
+      DestructorGuard dg(this);
+      eventBase_->runInLoop([this, dg] {
+        if (good()) {
+          handleInitialReadWrite();
+        }
+      });
+    }
+  }
+
   // event notification methods
   void ioReady(uint16_t events) noexcept;
   virtual void checkForImmediateRead() noexcept;
   virtual void handleInitialReadWrite() noexcept;
-  virtual void prepareReadBuffer(void** buf, size_t* buflen) noexcept;
+  virtual void prepareReadBuffer(void** buf, size_t* buflen);
   virtual void handleRead() noexcept;
   virtual void handleWrite() noexcept;
   virtual void handleConnect() noexcept;
@@ -648,11 +810,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * @param buf      The buffer to read data into.
    * @param buflen   The length of the buffer.
    *
-   * @return Returns the number of bytes read, or READ_EOF on EOF, or
-   * READ_ERROR on error, or READ_BLOCKING if the operation will
-   * block.
+   * @return Returns a read result. See read result for details.
    */
-  virtual ssize_t performRead(void** buf, size_t* buflen, size_t* offset);
+  virtual ReadResult performRead(void** buf, size_t* buflen, size_t* offset);
 
   /**
    * Populate an iovec array from an IOBuf and attempt to write it.
@@ -701,12 +861,30 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    *                          will contain the number of bytes written in the
    *                          partially written iovec entry.
    *
-   * @return Returns the total number of bytes written, or -1 on error.  If no
-   *     data can be written immediately, 0 is returned.
+   * @return Returns a WriteResult. See WriteResult for more details.
+   */
+  virtual WriteResult performWrite(
+      const iovec* vec,
+      uint32_t count,
+      WriteFlags flags,
+      uint32_t* countWritten,
+      uint32_t* partialWritten);
+
+  /**
+   * Sends the message over the socket using sendmsg
+   *
+   * @param msg       Message to send
+   * @param msg_flags Flags to pass to sendmsg
    */
-  virtual ssize_t performWrite(const iovec* vec, uint32_t count,
-                               WriteFlags flags, uint32_t* countWritten,
-                               uint32_t* partialWritten);
+  AsyncSocket::WriteResult
+  sendSocketMessage(int fd, struct msghdr* msg, int msg_flags);
+
+  virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
+
+  int socketConnect(const struct sockaddr* addr, socklen_t len);
+
+  virtual void scheduleConnectTimeout();
+  void registerForConnectEvents();
 
   bool updateEventRegistration();
 
@@ -730,6 +908,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   // error handling methods
   void startFail();
   void finishFail();
+  void finishFail(const AsyncSocketException& ex);
+  void invokeAllErrors(const AsyncSocketException& ex);
   void fail(const char* fn, const AsyncSocketException& ex);
   void failConnect(const char* fn, const AsyncSocketException& ex);
   void failRead(const char* fn, const AsyncSocketException& ex);
@@ -737,6 +917,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
                  const AsyncSocketException& ex);
   void failWrite(const char* fn, const AsyncSocketException& ex);
   void failAllWrites(const AsyncSocketException& ex);
+  virtual void invokeConnectErr(const AsyncSocketException& ex);
+  virtual void invokeConnectSuccess();
   void invalidState(ConnectCallback* callback);
   void invalidState(ReadCallback* callback);
   void invalidState(WriteCallback* callback);
@@ -747,11 +929,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   uint8_t shutdownFlags_;               ///< Shutdown state (ShutdownFlags)
   uint16_t eventFlags_;                 ///< EventBase::HandlerFlags settings
   int fd_;                              ///< The socket file descriptor
-  mutable
-    folly::SocketAddress addr_;    ///< The address we tried to connect to
+  mutable folly::SocketAddress addr_;    ///< The address we tried to connect to
+  mutable folly::SocketAddress localAddr_;
+                                        ///< The address we are connecting from
   uint32_t sendTimeout_;                ///< The send timeout, in milliseconds
   uint16_t maxReadsPerEvent_;           ///< Max reads per event loop iteration
-  EventBase* eventBase_;               ///< The EventBase
+  EventBase* eventBase_;                ///< The EventBase
   WriteTimeout writeTimeout_;           ///< A timeout for connect and write
   IoHandler ioHandler_;                 ///< A EventHandler to monitor the fd
   ImmediateReadCB immediateReadHandler_; ///< LoopCallback for checking read
@@ -764,7 +947,28 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   size_t appBytesReceived_;             ///< Num of bytes received from socket
   size_t appBytesWritten_;              ///< Num of bytes written to socket
   bool isBufferMovable_{false};
-};
 
+  bool peek_{false}; // Peek bytes.
+
+  int8_t readErr_{READ_NO_ERROR};      ///< The read error encountered, if any.
+
+  std::chrono::steady_clock::time_point connectStartTime_;
+  std::chrono::steady_clock::time_point connectEndTime_;
+
+  std::chrono::milliseconds connectTimeout_{0};
+
+  BufferCallback* bufferCallback_{nullptr};
+  bool tfoEnabled_{false};
+  bool tfoAttempted_{false};
+  bool tfoFinished_{false};
+  bool noTransparentTls_{false};
+  // Whether to track EOR or not.
+  bool trackEor_{false};
+
+  std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
+};
+#ifdef _MSC_VER
+#pragma vtordisp(pop)
+#endif
 
 } // folly