X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.h;h=474225c060bfd66ab0bfd86e925c26037ee01ef0;hp=6f8b14644bba8fa7875d021bfafa01a20917e436;hb=4c7a736d6529f22451a0ec965e093e7e318695e3;hpb=5c52b281c005e7d7a735bf2ffb6b3de14896cf2e diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 6f8b1464..474225c0 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -1,5 +1,5 @@ /* - * Copyright 2016 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include #include @@ -32,6 +32,7 @@ #include #include #include +#include #include namespace folly { @@ -122,32 +123,31 @@ class AsyncSSLSocket : public virtual AsyncSocket { noexcept = 0; }; - class HandshakeTimeout : public AsyncTimeout { + class Timeout : public AsyncTimeout { public: - HandshakeTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) - : AsyncTimeout(eventBase) - , sslSocket_(sslSocket) {} + Timeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) + : AsyncTimeout(eventBase), sslSocket_(sslSocket) {} - virtual void timeoutExpired() noexcept { - sslSocket_->timeoutExpired(); + bool scheduleTimeout(TimeoutManager::timeout_type timeout) { + timeout_ = timeout; + return AsyncTimeout::scheduleTimeout(timeout); } - private: - AsyncSSLSocket* sslSocket_; - }; + bool scheduleTimeout(uint32_t timeoutMs) { + return scheduleTimeout(std::chrono::milliseconds{timeoutMs}); + } - // Timer for if we fallback from SSL connects to TCP connects - class ConnectionTimeout : public AsyncTimeout { - public: - ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) - : AsyncTimeout(eventBase), sslSocket_(sslSocket) {} + TimeoutManager::timeout_type getTimeout() { + return timeout_; + } - virtual void timeoutExpired() noexcept override { - sslSocket_->timeoutExpired(); + void timeoutExpired() noexcept override { + sslSocket_->timeoutExpired(timeout_); } private: AsyncSSLSocket* sslSocket_; + TimeoutManager::timeout_type timeout_; }; /** @@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param deferSecurityNegotiation * unencrypted data can be sent before sslConn/Accept */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, int fd, - bool server = true, bool deferSecurityNegotiation = false); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + bool server = true, + bool deferSecurityNegotiation = false); + /** + * Create a server/client AsyncSSLSocket from an already connected + * AsyncSocket. + */ + AsyncSSLSocket( + const std::shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server = true, + bool deferSecurityNegotiation = false); /** * Helper function to create a server/client shared_ptr. @@ -202,7 +214,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Create a client AsyncSSLSocket with tlsext_servername in * the Client Hello message. @@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param fd File descriptor to take over (should be a connected socket). * @param serverName tlsext_hostname that will be sent in ClientHello. */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, - int fd, - const std::string& serverName, - bool deferSecurityNegotiation = false); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation = false); static std::shared_ptr newSocket( const std::shared_ptr& ctx, @@ -242,7 +255,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation), Destructor()); } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI /** * TODO: implement support for SSL renegotiation. @@ -262,19 +275,20 @@ class AsyncSSLSocket : public virtual AsyncSocket { // See the documentation in TAsyncTransport.h // TODO: implement graceful shutdown in close() // TODO: implement detachSSL() that returns the SSL connection - virtual void closeNow() override; - virtual void shutdownWrite() override; - virtual void shutdownWriteNow() override; - virtual bool good() const override; - virtual bool connecting() const override; - virtual std::string getApplicationProtocol() noexcept override; - - virtual std::string getSecurityProtocol() const override { return "TLS"; } - - bool isEorTrackingEnabled() const override; - virtual void setEorTracking(bool track) override; - virtual size_t getRawBytesWritten() const override; - virtual size_t getRawBytesReceived() const override; + void closeNow() override; + void shutdownWrite() override; + void shutdownWriteNow() override; + bool good() const override; + bool connecting() const override; + std::string getApplicationProtocol() noexcept override; + + std::string getSecurityProtocol() const override { + return "TLS"; + } + + void setEorTracking(bool track) override; + size_t getRawBytesWritten() const override; + size_t getRawBytesReceived() const override; void enableClientHelloParsing(); /** @@ -292,9 +306,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { * context by default, can be set explcitly to override the * method in the context */ - virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0, + virtual void sslAccept( + HandshakeCB* callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(), const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = - folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); /** * Invoke SSL accept following an asynchronous session cache lookup @@ -313,6 +329,29 @@ class AsyncSSLSocket : public virtual AsyncSocket { const folly::SocketAddress& bindAddr = anyAddress()) noexcept override; + /** + * A variant of connect that allows the caller to specify + * the timeout for the regular connect and the ssl connect + * separately. + * connectTimeout is specified as the time to establish a TCP + * connection. + * totalConnectTimeout defines the + * time it takes from starting the TCP connection to the time + * the ssl connection is established. The reason the timeout is + * defined this way is because user's rarely need to specify the SSL + * timeout independently of the connect timeout. It allows us to + * bound the time for a connect and SSL connection in + * a finer grained manner than if timeout was just defined + * independently for SSL. + */ + virtual void connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + std::chrono::milliseconds connectTimeout, + std::chrono::milliseconds totalConnectTimeout, + const OptionMap& options = emptyOptionMap, + const folly::SocketAddress& bindAddr = anyAddress()) noexcept; + using AsyncSocket::connect; /** @@ -332,9 +371,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { * SSL_VERIFY_PEER and invokes * HandshakeCB::handshakeVer(). */ - virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0, - const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = - folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + virtual void sslConn( + HandshakeCB* callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(), + const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); enum SSLStateEnum { STATE_UNINIT, @@ -474,21 +515,21 @@ class AsyncSSLSocket : public virtual AsyncSocket { /** * Get the certificate used for this SSL connection. May be null */ - virtual const X509* getSelfCert() const override; + const X509* getSelfCert() const override; - virtual void attachEventBase(EventBase* eventBase) override { + void attachEventBase(EventBase* eventBase) override { AsyncSocket::attachEventBase(eventBase); handshakeTimeout_.attachEventBase(eventBase); connectionTimeout_.attachEventBase(eventBase); } - virtual void detachEventBase() override { + void detachEventBase() override { AsyncSocket::detachEventBase(); handshakeTimeout_.detachEventBase(); connectionTimeout_.detachEventBase(); } - virtual bool isDetachable() const override { + bool isDetachable() const override { return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled(); } @@ -514,7 +555,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { void detachSSLContext(); #endif -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Switch the SSLContext to continue the SSL handshake. * It can only be used in server mode. @@ -537,9 +578,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { * ClientHello message. */ void setServerName(std::string serverName) noexcept; -#endif +#endif // FOLLY_OPENSSL_HAS_SNI - void timeoutExpired() noexcept; + void timeoutExpired(std::chrono::milliseconds timeout) noexcept; /** * Get the list of supported ciphers sent by the client in the client's @@ -561,8 +602,21 @@ class AsyncSSLSocket : public virtual AsyncSocket { std::string getSSLClientSigAlgs() const; + /** + * Get the list of versions in the supported versions extension (used to + * negotiate TLS 1.3). + */ + std::string getSSLClientSupportedVersions() const; + std::string getSSLAlertsReceived() const; + /* + * Save an optional alert message generated during certificate verify + */ + void setSSLCertVerificationAlert(std::string alert); + + std::string getSSLCertVerificationAlert() const; + /** * Get the list of shared ciphers between the server and the client. * Works well for only SSLv2, not so good for SSLv3 or TLSv1. @@ -575,9 +629,17 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ void getSSLServerCiphers(std::string& serverCiphers) const; + /** + * Method to check if peer verfication is set. + * + * @return true if peer verification is required. + */ + bool needsPeerVerification() const; + static int getSSLExDataIndex(); static AsyncSSLSocket* getFromSSL(const SSL *ssl); static int bioWrite(BIO* b, const char* in, int inl); + static int bioRead(BIO* b, char* out, int outl); void resetClientHelloParsing(SSL *ssl); static void clientHelloParsingCallback(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg); @@ -615,7 +677,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { /** * Returns the peer certificate, or nullptr if no peer certificate received. */ - virtual ssl::X509UniquePtr getPeerCert() const override { + ssl::X509UniquePtr getPeerCert() const override { if (!ssl_) { return nullptr; } @@ -651,6 +713,16 @@ class AsyncSSLSocket : public virtual AsyncSocket { return sessionResumptionAttempted_; } + /** + * If the SSL socket was used to connect as well + * as establish an SSL connection, this gives the total + * timeout for the connect + SSL connection that was + * set. + */ + std::chrono::milliseconds getTotalConnectTimeout() const { + return totalConnectTimeout_; + } + private: void init(); @@ -664,7 +736,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { * destroy() instead. (See the documentation in DelayedDestruction.h for * more details.) */ - ~AsyncSSLSocket(); + ~AsyncSSLSocket() override; // Inherit event notification methods from AsyncSocket except // the following. @@ -679,7 +751,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { int* sslErrorOut, unsigned long* errErrorOut) noexcept; - virtual void checkForImmediateRead() noexcept override; + void checkForImmediateRead() noexcept override; // AsyncSocket calls this at the wrong time for SSL void handleInitialReadWrite() noexcept override {} @@ -739,8 +811,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { void invokeConnectSuccess() override; void scheduleConnectTimeout() override; - void cacheLocalPeerAddr(); - void startSSLConnect(); static void sslInfoCallback(const SSL *ssl, int type, int val); @@ -761,14 +831,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { HandshakeCB* handshakeCallback_{nullptr}; SSL* ssl_{nullptr}; SSL_SESSION *sslSession_{nullptr}; - HandshakeTimeout handshakeTimeout_; - ConnectionTimeout connectionTimeout_; + Timeout handshakeTimeout_; + Timeout connectionTimeout_; // whether the SSL session was resumed using session ID or not bool sessionIDResumed_{false}; - // Whether to track EOR or not. - bool trackEor_{false}; - // The app byte num that we are tracking for the MSG_EOR // Only one app EOR byte can be tracked. size_t appEorByteNo_{0}; @@ -780,7 +847,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { // When openssl is about to sendmsg() across the minEorRawBytesNo_, // it will pass MSG_EOR to sendmsg(). size_t minEorRawByteNo_{0}; -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI std::shared_ptr handshakeCtx_; std::string tlsextHostname_; #endif @@ -804,8 +871,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { // Time taken to complete the ssl handshake. std::chrono::steady_clock::time_point handshakeStartTime_; std::chrono::steady_clock::time_point handshakeEndTime_; - uint64_t handshakeConnectTimeout_{0}; + std::chrono::milliseconds handshakeConnectTimeout_{0}; bool sessionResumptionAttempted_{false}; + std::chrono::milliseconds totalConnectTimeout_{0}; + + std::string sslVerificationAlert_; }; } // namespace