Fix copyright lines
[folly.git] / folly / io / async / AsyncTransport.h
index 0b929c16ff2bde5f0731196c42248aed81d4db53..050c22bf9ae719e6a363d0c79f6e35d27b4dbcb4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2014 Facebook, Inc.
+ * Copyright 2014-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #pragma once
 
+#include <memory>
+
+#include <folly/io/IOBuf.h>
+#include <folly/io/async/AsyncSocketBase.h>
+#include <folly/io/async/DelayedDestruction.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/portability/OpenSSL.h>
+#include <folly/portability/SysUio.h>
+#include <folly/ssl/OpenSSLPtrTypes.h>
+
+constexpr bool kOpenSslModeMoveBufferOwnership =
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+  true
+#else
+  false
+#endif
+;
+
 namespace folly {
 
+class AsyncSocketException;
+class EventBase;
+class SocketAddress;
+
 /*
  * flags given by the application for write* calls
  */
@@ -34,6 +56,14 @@ enum class WriteFlags : uint32_t {
    * will be acknowledged.
    */
   EOR = 0x02,
+  /*
+   * this indicates that only the write side of socket should be shutdown
+   */
+  WRITE_SHUTDOWN = 0x04,
+  /*
+   * use msg zerocopy if allowed
+   */
+  WRITE_MSG_ZEROCOPY = 0x08,
 };
 
 /*
@@ -44,6 +74,14 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) {
     static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
 }
 
+/*
+ * compound assignment union operator
+ */
+inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
+  a = a | b;
+  return a;
+}
+
 /*
  * intersection operator
  */
@@ -52,6 +90,14 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) {
     static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
 }
 
+/*
+ * compound assignment intersection operator
+ */
+inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
+  a = a & b;
+  return a;
+}
+
 /*
  * exclusion parameter
  */
@@ -101,7 +147,7 @@ inline bool isSet(WriteFlags a, WriteFlags b) {
  * timeout, since most callers want to give up if the remote end stops
  * responding and no further progress can be made sending the data.
  */
-class AsyncTransport : public DelayedDestruction {
+class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
  public:
   typedef std::unique_ptr<AsyncTransport, Destructor> UniquePtr;
 
@@ -196,6 +242,16 @@ class AsyncTransport : public DelayedDestruction {
    */
   virtual bool readable() const = 0;
 
+  /**
+   * Determine if the transport is writable or not.
+   *
+   * @return  true iff the transport is writable, false otherwise.
+   */
+  virtual bool writable() const {
+    // By default return good() - leave it to implementers to override.
+    return good();
+  }
+
   /**
    * Determine if the there is pending data on the transport.
    *
@@ -204,6 +260,7 @@ class AsyncTransport : public DelayedDestruction {
   virtual bool isPending() const {
     return readable();
   }
+
   /**
    * Determine if transport is connected to the endpoint
    *
@@ -246,14 +303,6 @@ class AsyncTransport : public DelayedDestruction {
    */
   virtual bool isDetachable() const = 0;
 
-  /**
-   * Get the EventBase used by this transport.
-   *
-   * Returns nullptr if this transport is not currently attached to a
-   * EventBase.
-   */
-  virtual EventBase* getEventBase() const = 0;
-
   /**
    * Set the send timeout.
    *
@@ -284,7 +333,25 @@ class AsyncTransport : public DelayedDestruction {
    * @param address  The local address will be stored in the specified
    *                 SocketAddress.
    */
-  virtual void getLocalAddress(folly::SocketAddress* address) const = 0;
+  virtual void getLocalAddress(SocketAddress* address) const = 0;
+
+  /**
+   * Get the address of the remote endpoint to which this transport is
+   * connected.
+   *
+   * This function may throw AsyncSocketException on error.
+   *
+   * @return         Return the local address
+   */
+  SocketAddress getLocalAddress() const {
+    SocketAddress addr;
+    getLocalAddress(&addr);
+    return addr;
+  }
+
+  void getAddress(SocketAddress* address) const override {
+    getLocalAddress(address);
+  }
 
   /**
    * Get the address of the remote endpoint to which this transport is
@@ -295,7 +362,49 @@ class AsyncTransport : public DelayedDestruction {
    * @param address  The remote endpoint's address will be stored in the
    *                 specified SocketAddress.
    */
-  virtual void getPeerAddress(folly::SocketAddress* address) const = 0;
+  virtual void getPeerAddress(SocketAddress* address) const = 0;
+
+  /**
+   * Get the address of the remote endpoint to which this transport is
+   * connected.
+   *
+   * This function may throw AsyncSocketException on error.
+   *
+   * @return         Return the remote endpoint's address
+   */
+  SocketAddress getPeerAddress() const {
+    SocketAddress addr;
+    getPeerAddress(&addr);
+    return addr;
+  }
+
+  /**
+   * Get the certificate used to authenticate the peer.
+   */
+  virtual ssl::X509UniquePtr getPeerCert() const { return nullptr; }
+
+  /**
+   * The local certificate used for this connection. May be null
+   */
+  virtual const X509* getSelfCert() const {
+    return nullptr;
+  }
+
+  /**
+   * Return the application protocol being used by the underlying transport
+   * protocol. This is useful for transports which are used to tunnel other
+   * protocols.
+   */
+  virtual std::string getApplicationProtocol() noexcept {
+    return "";
+  }
+
+  /**
+   * Returns the name of the security protocol being used.
+   */
+  virtual std::string getSecurityProtocol() const {
+    return "";
+  }
 
   /**
    * @return True iff end of record tracking is enabled
@@ -309,9 +418,286 @@ class AsyncTransport : public DelayedDestruction {
   virtual size_t getAppBytesReceived() const = 0;
   virtual size_t getRawBytesReceived() const = 0;
 
+  class BufferCallback {
+   public:
+    virtual ~BufferCallback() {}
+    virtual void onEgressBuffered() = 0;
+    virtual void onEgressBufferCleared() = 0;
+  };
+
+  /**
+   * Callback class to signal when a transport that did not have replay
+   * protection gains replay protection. This is needed for 0-RTT security
+   * protocols.
+   */
+  class ReplaySafetyCallback {
+   public:
+    virtual ~ReplaySafetyCallback() = default;
+
+    /**
+     * Called when the transport becomes replay safe.
+     */
+    virtual void onReplaySafe() = 0;
+  };
+
+  /**
+   * False if the transport does not have replay protection, but will in the
+   * future.
+   */
+  virtual bool isReplaySafe() const { return true; }
+
+  /**
+   * Set the ReplaySafeCallback on this transport.
+   *
+   * This should only be called if isReplaySafe() returns false.
+   */
+  virtual void setReplaySafetyCallback(ReplaySafetyCallback* callback) {
+    if (callback) {
+      CHECK(false) << "setReplaySafetyCallback() not supported";
+    }
+  }
+
+ protected:
+  ~AsyncTransport() override = default;
+};
+
+class AsyncReader {
+ public:
+  class ReadCallback {
+   public:
+    virtual ~ReadCallback() = default;
+
+    /**
+     * When data becomes available, getReadBuffer() will be invoked to get the
+     * buffer into which data should be read.
+     *
+     * This method allows the ReadCallback to delay buffer allocation until
+     * data becomes available.  This allows applications to manage large
+     * numbers of idle connections, without having to maintain a separate read
+     * buffer for each idle connection.
+     *
+     * It is possible that in some cases, getReadBuffer() may be called
+     * multiple times before readDataAvailable() is invoked.  In this case, the
+     * data will be written to the buffer returned from the most recent call to
+     * readDataAvailable().  If the previous calls to readDataAvailable()
+     * returned different buffers, the ReadCallback is responsible for ensuring
+     * that they are not leaked.
+     *
+     * If getReadBuffer() throws an exception, returns a nullptr buffer, or
+     * returns a 0 length, the ReadCallback will be uninstalled and its
+     * readError() method will be invoked.
+     *
+     * getReadBuffer() is not allowed to change the transport state before it
+     * returns.  (For example, it should never uninstall the read callback, or
+     * set a different read callback.)
+     *
+     * @param bufReturn getReadBuffer() should update *bufReturn to contain the
+     *                  address of the read buffer.  This parameter will never
+     *                  be nullptr.
+     * @param lenReturn getReadBuffer() should update *lenReturn to contain the
+     *                  maximum number of bytes that may be written to the read
+     *                  buffer.  This parameter will never be nullptr.
+     */
+    virtual void getReadBuffer(void** bufReturn, size_t* lenReturn) = 0;
+
+    /**
+     * readDataAvailable() will be invoked when data has been successfully read
+     * into the buffer returned by the last call to getReadBuffer().
+     *
+     * The read callback remains installed after readDataAvailable() returns.
+     * It must be explicitly uninstalled to stop receiving read events.
+     * getReadBuffer() will be called at least once before each call to
+     * readDataAvailable().  getReadBuffer() will also be called before any
+     * call to readEOF().
+     *
+     * @param len       The number of bytes placed in the buffer.
+     */
+
+    virtual void readDataAvailable(size_t len) noexcept = 0;
+
+    /**
+     * When data becomes available, isBufferMovable() will be invoked to figure
+     * out which API will be used, readBufferAvailable() or
+     * readDataAvailable(). If isBufferMovable() returns true, that means
+     * ReadCallback supports the IOBuf ownership transfer and
+     * readBufferAvailable() will be used.  Otherwise, not.
+
+     * By default, isBufferMovable() always return false. If
+     * readBufferAvailable() is implemented and to be invoked, You should
+     * overwrite isBufferMovable() and return true in the inherited class.
+     *
+     * This method allows the AsyncSocket/AsyncSSLSocket do buffer allocation by
+     * itself until data becomes available.  Compared with the pre/post buffer
+     * allocation in getReadBuffer()/readDataAvailabe(), readBufferAvailable()
+     * has two advantages.  First, this can avoid memcpy. E.g., in
+     * AsyncSSLSocket, the decrypted data was copied from the openssl internal
+     * buffer to the readbuf buffer.  With the buffer ownership transfer, the
+     * internal buffer can be directly "moved" to ReadCallback. Second, the
+     * memory allocation can be more precise.  The reason is
+     * AsyncSocket/AsyncSSLSocket can allocate the memory of precise size
+     * because they have more context about the available data than
+     * ReadCallback.  Think about the getReadBuffer() pre-allocate 4072 bytes
+     * buffer, but the available data is always 16KB (max OpenSSL record size).
+     */
+
+    virtual bool isBufferMovable() noexcept {
+      return false;
+    }
+
+    /**
+     * Suggested buffer size, allocated for read operations,
+     * if callback is movable and supports folly::IOBuf
+     */
+
+    virtual size_t maxBufferSize() const {
+      return 64 * 1024; // 64K
+    }
+
+    /**
+     * readBufferAvailable() will be invoked when data has been successfully
+     * read.
+     *
+     * Note that only either readBufferAvailable() or readDataAvailable() will
+     * be invoked according to the return value of isBufferMovable(). The timing
+     * and aftereffect of readBufferAvailable() are the same as
+     * readDataAvailable()
+     *
+     * @param readBuf The unique pointer of read buffer.
+     */
+
+    virtual void readBufferAvailable(std::unique_ptr<IOBuf> /*readBuf*/)
+      noexcept {}
+
+    /**
+     * readEOF() will be invoked when the transport is closed.
+     *
+     * The read callback will be automatically uninstalled immediately before
+     * readEOF() is invoked.
+     */
+    virtual void readEOF() noexcept = 0;
+
+    /**
+     * readError() will be invoked if an error occurs reading from the
+     * transport.
+     *
+     * The read callback will be automatically uninstalled immediately before
+     * readError() is invoked.
+     *
+     * @param ex        An exception describing the error that occurred.
+     */
+    virtual void readErr(const AsyncSocketException& ex) noexcept = 0;
+  };
+
+  // Read methods that aren't part of AsyncTransport.
+  virtual void setReadCB(ReadCallback* callback) = 0;
+  virtual ReadCallback* getReadCallback() const = 0;
+
+ protected:
+  virtual ~AsyncReader() = default;
+};
+
+class AsyncWriter {
+ public:
+  class WriteCallback {
+   public:
+    virtual ~WriteCallback() = default;
+
+    /**
+     * writeSuccess() will be invoked when all of the data has been
+     * successfully written.
+     *
+     * Note that this mainly signals that the buffer containing the data to
+     * write is no longer needed and may be freed or re-used.  It does not
+     * guarantee that the data has been fully transmitted to the remote
+     * endpoint.  For example, on socket-based transports, writeSuccess() only
+     * indicates that the data has been given to the kernel for eventual
+     * transmission.
+     */
+    virtual void writeSuccess() noexcept = 0;
+
+    /**
+     * writeError() will be invoked if an error occurs writing the data.
+     *
+     * @param bytesWritten      The number of bytes that were successfull
+     * @param ex                An exception describing the error that occurred.
+     */
+    virtual void writeErr(size_t bytesWritten,
+                          const AsyncSocketException& ex) noexcept = 0;
+  };
+
+  // Write methods that aren't part of AsyncTransport
+  virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
+                     WriteFlags flags = WriteFlags::NONE) = 0;
+  virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
+                      WriteFlags flags = WriteFlags::NONE) = 0;
+  virtual void writeChain(WriteCallback* callback,
+                          std::unique_ptr<IOBuf>&& buf,
+                          WriteFlags flags = WriteFlags::NONE) = 0;
+
  protected:
-  virtual ~AsyncTransport() {}
+  virtual ~AsyncWriter() = default;
 };
 
+// Transitional intermediate interface. This is deprecated.
+// Wrapper around folly::AsyncTransport, that includes read/write callbacks
+class AsyncTransportWrapper : virtual public AsyncTransport,
+                              virtual public AsyncReader,
+                              virtual public AsyncWriter {
+ public:
+  using UniquePtr = std::unique_ptr<AsyncTransportWrapper, Destructor>;
+
+  // Alias for inherited members from AsyncReader and AsyncWriter
+  // to keep compatibility.
+  using ReadCallback    = AsyncReader::ReadCallback;
+  using WriteCallback   = AsyncWriter::WriteCallback;
+  void setReadCB(ReadCallback* callback) override = 0;
+  ReadCallback* getReadCallback() const override = 0;
+  void write(
+      WriteCallback* callback,
+      const void* buf,
+      size_t bytes,
+      WriteFlags flags = WriteFlags::NONE) override = 0;
+  void writev(
+      WriteCallback* callback,
+      const iovec* vec,
+      size_t count,
+      WriteFlags flags = WriteFlags::NONE) override = 0;
+  void writeChain(
+      WriteCallback* callback,
+      std::unique_ptr<IOBuf>&& buf,
+      WriteFlags flags = WriteFlags::NONE) override = 0;
+  /**
+   * The transport wrapper may wrap another transport. This returns the
+   * transport that is wrapped. It returns nullptr if there is no wrapped
+   * transport.
+   */
+  virtual const AsyncTransportWrapper* getWrappedTransport() const {
+    return nullptr;
+  }
+
+  /**
+   * In many cases when we need to set socket properties or otherwise access the
+   * underlying transport from a wrapped transport. This method allows access to
+   * the derived classes of the underlying transport.
+   */
+  template <class T>
+  const T* getUnderlyingTransport() const {
+    const AsyncTransportWrapper* current = this;
+    while (current) {
+      auto sock = dynamic_cast<const T*>(current);
+      if (sock) {
+        return sock;
+      }
+      current = current->getWrappedTransport();
+    }
+    return nullptr;
+  }
+
+  template <class T>
+  T* getUnderlyingTransport() {
+    return const_cast<T*>(static_cast<const AsyncTransportWrapper*>(this)
+        ->getUnderlyingTransport<T>());
+  }
+};
 
-} // folly
+} // namespace folly