/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#pragma once
#include <memory>
-#include <sys/uio.h>
+#include <folly/io/IOBuf.h>
+#include <folly/io/async/AsyncSocketBase.h>
#include <folly/io/async/DelayedDestruction.h>
#include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncSocketBase.h>
-
-#include <openssl/ssl.h>
+#include <folly/portability/OpenSSL.h>
+#include <folly/portability/SysUio.h>
+#include <folly/ssl/OpenSSLPtrTypes.h>
constexpr bool kOpenSslModeMoveBufferOwnership =
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
class AsyncSocketException;
class EventBase;
-class IOBuf;
class SocketAddress;
/*
static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}
+/*
+ * compound assignment union operator
+ */
+inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
+ a = a | b;
+ return a;
+}
+
/*
* intersection operator
*/
static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
}
+/*
+ * compound assignment intersection operator
+ */
+inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
+ a = a & b;
+ return a;
+}
+
/*
* exclusion parameter
*/
*/
virtual bool readable() const = 0;
+ /**
+ * Determine if the transport is writable or not.
+ *
+ * @return true iff the transport is writable, false otherwise.
+ */
+ virtual bool writable() const {
+ // By default return good() - leave it to implementers to override.
+ return good();
+ }
+
/**
* Determine if the there is pending data on the transport.
*
*/
virtual void getLocalAddress(SocketAddress* address) const = 0;
- virtual void getAddress(SocketAddress* address) const {
+ /**
+ * Get the address of the remote endpoint to which this transport is
+ * connected.
+ *
+ * This function may throw AsyncSocketException on error.
+ *
+ * @return Return the local address
+ */
+ SocketAddress getLocalAddress() const {
+ SocketAddress addr;
+ getLocalAddress(&addr);
+ return addr;
+ }
+
+ void getAddress(SocketAddress* address) const override {
getLocalAddress(address);
}
*/
virtual void getPeerAddress(SocketAddress* address) const = 0;
+ /**
+ * Get the address of the remote endpoint to which this transport is
+ * connected.
+ *
+ * This function may throw AsyncSocketException on error.
+ *
+ * @return Return the remote endpoint's address
+ */
+ SocketAddress getPeerAddress() const {
+ SocketAddress addr;
+ getPeerAddress(&addr);
+ return addr;
+ }
+
+ /**
+ * Get the certificate used to authenticate the peer.
+ */
+ virtual ssl::X509UniquePtr getPeerCert() const { return nullptr; }
+
+ /**
+ * The local certificate used for this connection. May be null
+ */
+ virtual const X509* getSelfCert() const {
+ return nullptr;
+ }
+
/**
* @return True iff end of record tracking is enabled
*/
virtual size_t getAppBytesReceived() const = 0;
virtual size_t getRawBytesReceived() const = 0;
+ class BufferCallback {
+ public:
+ virtual ~BufferCallback() {}
+ virtual void onEgressBuffered() = 0;
+ virtual void onEgressBufferCleared() = 0;
+ };
+
+ /**
+ * Callback class to signal when a transport that did not have replay
+ * protection gains replay protection. This is needed for 0-RTT security
+ * protocols.
+ */
+ class ReplaySafetyCallback {
+ public:
+ virtual ~ReplaySafetyCallback() = default;
+
+ /**
+ * Called when the transport becomes replay safe.
+ */
+ virtual void onReplaySafe() = 0;
+ };
+
+ /**
+ * False if the transport does not have replay protection, but will in the
+ * future.
+ */
+ virtual bool isReplaySafe() const { return true; }
+
+ /**
+ * Set the ReplaySafeCallback on this transport.
+ *
+ * This should only be called if isReplaySafe() returns false.
+ */
+ virtual void setReplaySafetyCallback(ReplaySafetyCallback* callback) {
+ if (callback) {
+ CHECK(false) << "setReplaySafetyCallback() not supported";
+ }
+ }
+
protected:
- virtual ~AsyncTransport() = default;
+ ~AsyncTransport() override = default;
};
class AsyncReader {
return false;
}
+ /**
+ * Suggested buffer size, allocated for read operations,
+ * if callback is movable and supports folly::IOBuf
+ */
+
+ virtual size_t maxBufferSize() const {
+ return 64 * 1024; // 64K
+ }
+
/**
* readBufferAvailable() will be invoked when data has been successfully
* read.
*/
virtual void readBufferAvailable(std::unique_ptr<IOBuf> /*readBuf*/)
- noexcept {};
+ noexcept {}
/**
* readEOF() will be invoked when the transport is closed.
class AsyncWriter {
public:
- class BufferCallback {
- public:
- virtual ~BufferCallback() {}
- virtual void onEgressBuffered() = 0;
- };
-
class WriteCallback {
public:
virtual ~WriteCallback() = default;
// Write methods that aren't part of AsyncTransport
virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
- WriteFlags flags = WriteFlags::NONE,
- BufferCallback* bufCallback = nullptr) = 0;
+ WriteFlags flags = WriteFlags::NONE) = 0;
virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
- WriteFlags flags = WriteFlags::NONE,
- BufferCallback* bufCallback = nullptr) = 0;
+ WriteFlags flags = WriteFlags::NONE) = 0;
virtual void writeChain(WriteCallback* callback,
std::unique_ptr<IOBuf>&& buf,
- WriteFlags flags = WriteFlags::NONE,
- BufferCallback* bufCallback = nullptr) = 0;
+ WriteFlags flags = WriteFlags::NONE) = 0;
protected:
virtual ~AsyncWriter() = default;
// to keep compatibility.
using ReadCallback = AsyncReader::ReadCallback;
using WriteCallback = AsyncWriter::WriteCallback;
- using BufferCallback = AsyncWriter::BufferCallback;
- virtual void setReadCB(ReadCallback* callback) override = 0;
- virtual ReadCallback* getReadCallback() const override = 0;
- virtual void write(WriteCallback* callback, const void* buf, size_t bytes,
- WriteFlags flags = WriteFlags::NONE,
- BufferCallback* bufCallback = nullptr) override = 0;
- virtual void writev(WriteCallback* callback, const iovec* vec, size_t count,
- WriteFlags flags = WriteFlags::NONE,
- BufferCallback* bufCallback = nullptr) override = 0;
- virtual void writeChain(WriteCallback* callback,
- std::unique_ptr<IOBuf>&& buf,
- WriteFlags flags = WriteFlags::NONE,
- BufferCallback* bufCallback = nullptr) override = 0;
+ void setReadCB(ReadCallback* callback) override = 0;
+ ReadCallback* getReadCallback() const override = 0;
+ void write(
+ WriteCallback* callback,
+ const void* buf,
+ size_t bytes,
+ WriteFlags flags = WriteFlags::NONE) override = 0;
+ void writev(
+ WriteCallback* callback,
+ const iovec* vec,
+ size_t count,
+ WriteFlags flags = WriteFlags::NONE) override = 0;
+ void writeChain(
+ WriteCallback* callback,
+ std::unique_ptr<IOBuf>&& buf,
+ WriteFlags flags = WriteFlags::NONE) override = 0;
/**
* The transport wrapper may wrap another transport. This returns the
* transport that is wrapped. It returns nullptr if there is no wrapped
* transport.
*/
- virtual AsyncTransportWrapper* getWrappedTransport() {
+ virtual const AsyncTransportWrapper* getWrappedTransport() const {
return nullptr;
}
* the derived classes of the underlying transport.
*/
template <class T>
- T* getUnderlyingTransport() {
- AsyncTransportWrapper* current = this;
+ const T* getUnderlyingTransport() const {
+ const AsyncTransportWrapper* current = this;
while (current) {
- auto sock = dynamic_cast<T*>(current);
+ auto sock = dynamic_cast<const T*>(current);
if (sock) {
return sock;
}
return nullptr;
}
+ template <class T>
+ T* getUnderlyingTransport() {
+ return const_cast<T*>(static_cast<const AsyncTransportWrapper*>(this)
+ ->getUnderlyingTransport<T>());
+ }
+
/**
* Return the application protocol being used by the underlying transport
* protocol. This is useful for transports which are used to tunnel other
virtual std::string getApplicationProtocol() noexcept {
return "";
}
+
+ /**
+ * Returns the name of the security protocol being used.
+ */
+ virtual std::string getSecurityProtocol() const { return ""; }
};
} // folly