X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.h;h=474225c060bfd66ab0bfd86e925c26037ee01ef0;hb=4c7a736d6529f22451a0ec965e093e7e318695e3;hp=99728dddcee12302547fcf4834bf3a5f431b2b97;hpb=6f8d37dc510dfbbf8beacc67b23d937bff69182d;p=folly.git diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 99728ddd..474225c0 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -141,7 +141,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { return timeout_; } - virtual void timeoutExpired() noexcept override { + void timeoutExpired() noexcept override { sslSocket_->timeoutExpired(timeout_); } @@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param deferSecurityNegotiation * unencrypted data can be sent before sslConn/Accept */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, int fd, - bool server = true, bool deferSecurityNegotiation = false); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + bool server = true, + bool deferSecurityNegotiation = false); + /** + * Create a server/client AsyncSSLSocket from an already connected + * AsyncSocket. + */ + AsyncSSLSocket( + const std::shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server = true, + bool deferSecurityNegotiation = false); /** * Helper function to create a server/client shared_ptr. @@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param fd File descriptor to take over (should be a connected socket). * @param serverName tlsext_hostname that will be sent in ClientHello. */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, - int fd, - const std::string& serverName, - bool deferSecurityNegotiation = false); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation = false); static std::shared_ptr newSocket( const std::shared_ptr& ctx, @@ -262,21 +275,21 @@ class AsyncSSLSocket : public virtual AsyncSocket { // See the documentation in TAsyncTransport.h // TODO: implement graceful shutdown in close() // TODO: implement detachSSL() that returns the SSL connection - virtual void closeNow() override; - virtual void shutdownWrite() override; - virtual void shutdownWriteNow() override; - virtual bool good() const override; - virtual bool connecting() const override; - virtual std::string getApplicationProtocol() noexcept override; - - virtual std::string getSecurityProtocol() const override { return "TLS"; } - - virtual void setEorTracking(bool track) override; - virtual size_t getRawBytesWritten() const override; - virtual size_t getRawBytesReceived() const override; - void enableClientHelloParsing(); + void closeNow() override; + void shutdownWrite() override; + void shutdownWriteNow() override; + bool good() const override; + bool connecting() const override; + std::string getApplicationProtocol() noexcept override; + + std::string getSecurityProtocol() const override { + return "TLS"; + } - void setPreReceivedData(std::unique_ptr data); + void setEorTracking(bool track) override; + size_t getRawBytesWritten() const override; + size_t getRawBytesReceived() const override; + void enableClientHelloParsing(); /** * Accept an SSL connection on the socket. @@ -502,21 +515,21 @@ class AsyncSSLSocket : public virtual AsyncSocket { /** * Get the certificate used for this SSL connection. May be null */ - virtual const X509* getSelfCert() const override; + const X509* getSelfCert() const override; - virtual void attachEventBase(EventBase* eventBase) override { + void attachEventBase(EventBase* eventBase) override { AsyncSocket::attachEventBase(eventBase); handshakeTimeout_.attachEventBase(eventBase); connectionTimeout_.attachEventBase(eventBase); } - virtual void detachEventBase() override { + void detachEventBase() override { AsyncSocket::detachEventBase(); handshakeTimeout_.detachEventBase(); connectionTimeout_.detachEventBase(); } - virtual bool isDetachable() const override { + bool isDetachable() const override { return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled(); } @@ -597,6 +610,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { std::string getSSLAlertsReceived() const; + /* + * Save an optional alert message generated during certificate verify + */ + void setSSLCertVerificationAlert(std::string alert); + + std::string getSSLCertVerificationAlert() const; + /** * Get the list of shared ciphers between the server and the client. * Works well for only SSLv2, not so good for SSLv3 or TLSv1. @@ -657,7 +677,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { /** * Returns the peer certificate, or nullptr if no peer certificate received. */ - virtual ssl::X509UniquePtr getPeerCert() const override { + ssl::X509UniquePtr getPeerCert() const override { if (!ssl_) { return nullptr; } @@ -716,7 +736,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { * destroy() instead. (See the documentation in DelayedDestruction.h for * more details.) */ - ~AsyncSSLSocket(); + ~AsyncSSLSocket() override; // Inherit event notification methods from AsyncSocket except // the following. @@ -731,7 +751,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { int* sslErrorOut, unsigned long* errErrorOut) noexcept; - virtual void checkForImmediateRead() noexcept override; + void checkForImmediateRead() noexcept override; // AsyncSocket calls this at the wrong time for SSL void handleInitialReadWrite() noexcept override {} @@ -791,8 +811,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { void invokeConnectSuccess() override; void scheduleConnectTimeout() override; - void cacheLocalPeerAddr(); - void startSSLConnect(); static void sslInfoCallback(const SSL *ssl, int type, int val); @@ -857,7 +875,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { bool sessionResumptionAttempted_{false}; std::chrono::milliseconds totalConnectTimeout_{0}; - std::unique_ptr preReceivedData_; + std::string sslVerificationAlert_; }; } // namespace