Adding 'compound assignment union operator' for folly::WriteFlags enum class
[folly.git] / folly / io / async / AsyncTransport.h
index 031b88e41e733b06d4f72e1b9198c66328683da2..8af868f7b79e586bedf18a3d706e67eac2658363 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #pragma once
 
 #include <memory>
-#include <sys/uio.h>
 
+#include <folly/io/IOBuf.h>
+#include <folly/io/async/AsyncSocketBase.h>
 #include <folly/io/async/DelayedDestruction.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncSocketBase.h>
+#include <folly/io/async/ssl/OpenSSLPtrTypes.h>
+#include <folly/portability/SysUio.h>
 
 #include <openssl/ssl.h>
 
@@ -37,7 +39,6 @@ namespace folly {
 
 class AsyncSocketException;
 class EventBase;
-class IOBuf;
 class SocketAddress;
 
 /*
@@ -70,6 +71,14 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) {
     static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
 }
 
+/*
+ * compound assignment union operator
+ */
+inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
+  a = a | b;
+  return a;
+}
+
 /*
  * intersection operator
  */
@@ -78,6 +87,14 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) {
     static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
 }
 
+/*
+ * compound assignment intersection operator
+ */
+inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
+  a = a & b;
+  return a;
+}
+
 /*
  * exclusion parameter
  */
@@ -320,6 +337,18 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
    */
   virtual void getPeerAddress(SocketAddress* address) const = 0;
 
+  /**
+   * Get the certificate used to authenticate the peer.
+   */
+  virtual ssl::X509UniquePtr getPeerCert() const { return nullptr; }
+
+  /**
+   * The local certificate used for this connection. May be null
+   */
+  virtual const X509* getSelfCert() const {
+    return nullptr;
+  }
+
   /**
    * @return True iff end of record tracking is enabled
    */
@@ -332,6 +361,45 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
   virtual size_t getAppBytesReceived() const = 0;
   virtual size_t getRawBytesReceived() const = 0;
 
+  class BufferCallback {
+   public:
+    virtual ~BufferCallback() {}
+    virtual void onEgressBuffered() = 0;
+    virtual void onEgressBufferCleared() = 0;
+  };
+
+  /**
+   * Callback class to signal when a transport that did not have replay
+   * protection gains replay protection. This is needed for 0-RTT security
+   * protocols.
+   */
+  class ReplaySafetyCallback {
+   public:
+    virtual ~ReplaySafetyCallback() = default;
+
+    /**
+     * Called when the transport becomes replay safe.
+     */
+    virtual void onReplaySafe() = 0;
+  };
+
+  /**
+   * False if the transport does not have replay protection, but will in the
+   * future.
+   */
+  virtual bool isReplaySafe() const { return true; }
+
+  /**
+   * Set the ReplaySafeCallback on this transport.
+   *
+   * This should only be called if isReplaySafe() returns false.
+   */
+  virtual void setReplaySafetyCallback(ReplaySafetyCallback* callback) {
+    if (callback) {
+      CHECK(false) << "setReplaySafetyCallback() not supported";
+    }
+  }
+
  protected:
   virtual ~AsyncTransport() = default;
 };
@@ -419,6 +487,15 @@ class AsyncReader {
       return false;
     }
 
+    /**
+     * Suggested buffer size, allocated for read operations,
+     * if callback is movable and supports folly::IOBuf
+     */
+
+    virtual size_t maxBufferSize() const {
+      return 64 * 1024; // 64K
+    }
+
     /**
      * readBufferAvailable() will be invoked when data has been successfully
      * read.
@@ -432,7 +509,7 @@ class AsyncReader {
      */
 
     virtual void readBufferAvailable(std::unique_ptr<IOBuf> /*readBuf*/)
-      noexcept {};
+      noexcept {}
 
     /**
      * readEOF() will be invoked when the transport is closed.
@@ -464,12 +541,6 @@ class AsyncReader {
 
 class AsyncWriter {
  public:
-  class BufferCallback {
-   public:
-    virtual ~BufferCallback() {}
-    virtual void onEgressBuffered() = 0;
-  };
-
   class WriteCallback {
    public:
     virtual ~WriteCallback() = default;
@@ -499,15 +570,12 @@ class AsyncWriter {
 
   // Write methods that aren't part of AsyncTransport
   virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
-                     WriteFlags flags = WriteFlags::NONE,
-                     BufferCallback* bufCallback = nullptr) = 0;
+                     WriteFlags flags = WriteFlags::NONE) = 0;
   virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
-                      WriteFlags flags = WriteFlags::NONE,
-                      BufferCallback* bufCallback = nullptr) = 0;
+                      WriteFlags flags = WriteFlags::NONE) = 0;
   virtual void writeChain(WriteCallback* callback,
                           std::unique_ptr<IOBuf>&& buf,
-                          WriteFlags flags = WriteFlags::NONE,
-                          BufferCallback* bufCallback = nullptr) = 0;
+                          WriteFlags flags = WriteFlags::NONE) = 0;
 
  protected:
   virtual ~AsyncWriter() = default;
@@ -525,28 +593,48 @@ class AsyncTransportWrapper : virtual public AsyncTransport,
   // to keep compatibility.
   using ReadCallback    = AsyncReader::ReadCallback;
   using WriteCallback   = AsyncWriter::WriteCallback;
-  using BufferCallback  = AsyncWriter::BufferCallback;
   virtual void setReadCB(ReadCallback* callback) override = 0;
   virtual ReadCallback* getReadCallback() const override = 0;
   virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
-                     WriteFlags flags = WriteFlags::NONE,
-                     BufferCallback* bufCallback = nullptr) override = 0;
+                     WriteFlags flags = WriteFlags::NONE) override = 0;
   virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
-                      WriteFlags flags = WriteFlags::NONE,
-                      BufferCallback* bufCallback = nullptr) override = 0;
+                      WriteFlags flags = WriteFlags::NONE) override = 0;
   virtual void writeChain(WriteCallback* callback,
                           std::unique_ptr<IOBuf>&& buf,
-                          WriteFlags flags = WriteFlags::NONE,
-                          BufferCallback* bufCallback = nullptr) override = 0;
+                          WriteFlags flags = WriteFlags::NONE) override = 0;
   /**
    * The transport wrapper may wrap another transport. This returns the
    * transport that is wrapped. It returns nullptr if there is no wrapped
    * transport.
    */
-  virtual AsyncTransportWrapper* getWrappedTransport() {
+  virtual const AsyncTransportWrapper* getWrappedTransport() const {
+    return nullptr;
+  }
+
+  /**
+   * In many cases when we need to set socket properties or otherwise access the
+   * underlying transport from a wrapped transport. This method allows access to
+   * the derived classes of the underlying transport.
+   */
+  template <class T>
+  const T* getUnderlyingTransport() const {
+    const AsyncTransportWrapper* current = this;
+    while (current) {
+      auto sock = dynamic_cast<const T*>(current);
+      if (sock) {
+        return sock;
+      }
+      current = current->getWrappedTransport();
+    }
     return nullptr;
   }
 
+  template <class T>
+  T* getUnderlyingTransport() {
+    return const_cast<T*>(static_cast<const AsyncTransportWrapper*>(this)
+        ->getUnderlyingTransport<T>());
+  }
+
   /**
    * Return the application protocol being used by the underlying transport
    * protocol. This is useful for transports which are used to tunnel other
@@ -555,6 +643,11 @@ class AsyncTransportWrapper : virtual public AsyncTransport,
   virtual std::string getApplicationProtocol() noexcept {
     return "";
   }
+
+  /**
+   * Returns the name of the security protocol being used.
+   */
+  virtual std::string getSecurityProtocol() const { return ""; }
 };
 
 } // folly