X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncTransport.h;h=337fba960c71eb2624e5ea91afdf631ec91fbead;hb=520e20a8baebc98a5ec84d67865da4cf9819f88a;hp=031b88e41e733b06d4f72e1b9198c66328683da2;hpb=7749a46977a772b1f8d310c055875a90bed3efa9;p=folly.git diff --git a/folly/io/async/AsyncTransport.h b/folly/io/async/AsyncTransport.h index 031b88e4..337fba96 100644 --- a/folly/io/async/AsyncTransport.h +++ b/folly/io/async/AsyncTransport.h @@ -1,5 +1,5 @@ /* - * Copyright 2015 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,14 @@ #pragma once #include -#include +#include +#include #include #include -#include - -#include +#include +#include +#include constexpr bool kOpenSslModeMoveBufferOwnership = #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP @@ -37,7 +38,6 @@ namespace folly { class AsyncSocketException; class EventBase; -class IOBuf; class SocketAddress; /* @@ -70,6 +70,14 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) { static_cast(a) | static_cast(b)); } +/* + * compound assignment union operator + */ +inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) { + a = a | b; + return a; +} + /* * intersection operator */ @@ -78,6 +86,14 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) { static_cast(a) & static_cast(b)); } +/* + * compound assignment intersection operator + */ +inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) { + a = a & b; + return a; +} + /* * exclusion parameter */ @@ -222,6 +238,16 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase { */ virtual bool readable() const = 0; + /** + * Determine if the transport is writable or not. + * + * @return true iff the transport is writable, false otherwise. + */ + virtual bool writable() const { + // By default return good() - leave it to implementers to override. + return good(); + } + /** * Determine if the there is pending data on the transport. * @@ -305,7 +331,21 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase { */ virtual void getLocalAddress(SocketAddress* address) const = 0; - virtual void getAddress(SocketAddress* address) const { + /** + * Get the address of the remote endpoint to which this transport is + * connected. + * + * This function may throw AsyncSocketException on error. + * + * @return Return the local address + */ + SocketAddress getLocalAddress() const { + SocketAddress addr; + getLocalAddress(&addr); + return addr; + } + + void getAddress(SocketAddress* address) const override { getLocalAddress(address); } @@ -320,6 +360,32 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase { */ virtual void getPeerAddress(SocketAddress* address) const = 0; + /** + * Get the address of the remote endpoint to which this transport is + * connected. + * + * This function may throw AsyncSocketException on error. + * + * @return Return the remote endpoint's address + */ + SocketAddress getPeerAddress() const { + SocketAddress addr; + getPeerAddress(&addr); + return addr; + } + + /** + * Get the certificate used to authenticate the peer. + */ + virtual ssl::X509UniquePtr getPeerCert() const { return nullptr; } + + /** + * The local certificate used for this connection. May be null + */ + virtual const X509* getSelfCert() const { + return nullptr; + } + /** * @return True iff end of record tracking is enabled */ @@ -332,8 +398,47 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase { virtual size_t getAppBytesReceived() const = 0; virtual size_t getRawBytesReceived() const = 0; + class BufferCallback { + public: + virtual ~BufferCallback() {} + virtual void onEgressBuffered() = 0; + virtual void onEgressBufferCleared() = 0; + }; + + /** + * Callback class to signal when a transport that did not have replay + * protection gains replay protection. This is needed for 0-RTT security + * protocols. + */ + class ReplaySafetyCallback { + public: + virtual ~ReplaySafetyCallback() = default; + + /** + * Called when the transport becomes replay safe. + */ + virtual void onReplaySafe() = 0; + }; + + /** + * False if the transport does not have replay protection, but will in the + * future. + */ + virtual bool isReplaySafe() const { return true; } + + /** + * Set the ReplaySafeCallback on this transport. + * + * This should only be called if isReplaySafe() returns false. + */ + virtual void setReplaySafetyCallback(ReplaySafetyCallback* callback) { + if (callback) { + CHECK(false) << "setReplaySafetyCallback() not supported"; + } + } + protected: - virtual ~AsyncTransport() = default; + ~AsyncTransport() override = default; }; class AsyncReader { @@ -419,6 +524,15 @@ class AsyncReader { return false; } + /** + * Suggested buffer size, allocated for read operations, + * if callback is movable and supports folly::IOBuf + */ + + virtual size_t maxBufferSize() const { + return 64 * 1024; // 64K + } + /** * readBufferAvailable() will be invoked when data has been successfully * read. @@ -432,7 +546,7 @@ class AsyncReader { */ virtual void readBufferAvailable(std::unique_ptr /*readBuf*/) - noexcept {}; + noexcept {} /** * readEOF() will be invoked when the transport is closed. @@ -464,12 +578,6 @@ class AsyncReader { class AsyncWriter { public: - class BufferCallback { - public: - virtual ~BufferCallback() {} - virtual void onEgressBuffered() = 0; - }; - class WriteCallback { public: virtual ~WriteCallback() = default; @@ -499,15 +607,12 @@ class AsyncWriter { // Write methods that aren't part of AsyncTransport virtual void write(WriteCallback* callback, const void* buf, size_t bytes, - WriteFlags flags = WriteFlags::NONE, - BufferCallback* bufCallback = nullptr) = 0; + WriteFlags flags = WriteFlags::NONE) = 0; virtual void writev(WriteCallback* callback, const iovec* vec, size_t count, - WriteFlags flags = WriteFlags::NONE, - BufferCallback* bufCallback = nullptr) = 0; + WriteFlags flags = WriteFlags::NONE) = 0; virtual void writeChain(WriteCallback* callback, std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE, - BufferCallback* bufCallback = nullptr) = 0; + WriteFlags flags = WriteFlags::NONE) = 0; protected: virtual ~AsyncWriter() = default; @@ -525,28 +630,55 @@ class AsyncTransportWrapper : virtual public AsyncTransport, // to keep compatibility. using ReadCallback = AsyncReader::ReadCallback; using WriteCallback = AsyncWriter::WriteCallback; - using BufferCallback = AsyncWriter::BufferCallback; - virtual void setReadCB(ReadCallback* callback) override = 0; - virtual ReadCallback* getReadCallback() const override = 0; - virtual void write(WriteCallback* callback, const void* buf, size_t bytes, - WriteFlags flags = WriteFlags::NONE, - BufferCallback* bufCallback = nullptr) override = 0; - virtual void writev(WriteCallback* callback, const iovec* vec, size_t count, - WriteFlags flags = WriteFlags::NONE, - BufferCallback* bufCallback = nullptr) override = 0; - virtual void writeChain(WriteCallback* callback, - std::unique_ptr&& buf, - WriteFlags flags = WriteFlags::NONE, - BufferCallback* bufCallback = nullptr) override = 0; + void setReadCB(ReadCallback* callback) override = 0; + ReadCallback* getReadCallback() const override = 0; + void write( + WriteCallback* callback, + const void* buf, + size_t bytes, + WriteFlags flags = WriteFlags::NONE) override = 0; + void writev( + WriteCallback* callback, + const iovec* vec, + size_t count, + WriteFlags flags = WriteFlags::NONE) override = 0; + void writeChain( + WriteCallback* callback, + std::unique_ptr&& buf, + WriteFlags flags = WriteFlags::NONE) override = 0; /** * The transport wrapper may wrap another transport. This returns the * transport that is wrapped. It returns nullptr if there is no wrapped * transport. */ - virtual AsyncTransportWrapper* getWrappedTransport() { + virtual const AsyncTransportWrapper* getWrappedTransport() const { return nullptr; } + /** + * In many cases when we need to set socket properties or otherwise access the + * underlying transport from a wrapped transport. This method allows access to + * the derived classes of the underlying transport. + */ + template + const T* getUnderlyingTransport() const { + const AsyncTransportWrapper* current = this; + while (current) { + auto sock = dynamic_cast(current); + if (sock) { + return sock; + } + current = current->getWrappedTransport(); + } + return nullptr; + } + + template + T* getUnderlyingTransport() { + return const_cast(static_cast(this) + ->getUnderlyingTransport()); + } + /** * Return the application protocol being used by the underlying transport * protocol. This is useful for transports which are used to tunnel other @@ -555,6 +687,11 @@ class AsyncTransportWrapper : virtual public AsyncTransport, virtual std::string getApplicationProtocol() noexcept { return ""; } + + /** + * Returns the name of the security protocol being used. + */ + virtual std::string getSecurityProtocol() const { return ""; } }; } // folly