X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.h;h=474225c060bfd66ab0bfd86e925c26037ee01ef0;hp=8f3c8bd086870ea83901bc8fa378f1c31ecb031d;hb=4c7a736d6529f22451a0ec965e093e7e318695e3;hpb=ff9b70f3cd1f05fb8e8c4351248cd9f748c2644a diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 8f3c8bd0..474225c0 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -1,5 +1,5 @@ /* - * Copyright 2014 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,37 +16,27 @@ #pragma once -#include #include -#include #include #include #include -#include #include +#include #include +#include +#include +#include +#include #include #include #include - -using folly::io::Cursor; -using std::unique_ptr; +#include +#include namespace folly { -class SSLException: public folly::AsyncSocketException { - public: - SSLException(int sslError, int errno_copy); - - int getSSLError() const { return error_; } - - protected: - int error_; - char msg_[256]; -}; - /** * A class for performing asynchronous I/O on an SSL connection. * @@ -80,10 +70,11 @@ class SSLException: public folly::AsyncSocketException { class AsyncSSLSocket : public virtual AsyncSocket { public: typedef std::unique_ptr UniquePtr; + using X509_deleter = folly::static_function_deleter; class HandshakeCB { public: - virtual ~HandshakeCB() {} + virtual ~HandshakeCB() = default; /** * handshakeVer() is invoked during handshaking to give the @@ -99,9 +90,9 @@ class AsyncSSLSocket : public virtual AsyncSocket { * See the passages on verify_callback in SSL_CTX_set_verify(3) * for more details. */ - virtual bool handshakeVer(AsyncSSLSocket* sock, + virtual bool handshakeVer(AsyncSSLSocket* /*sock*/, bool preverifyOk, - X509_STORE_CTX* ctx) noexcept { + X509_STORE_CTX* /*ctx*/) noexcept { return preverifyOk; } @@ -132,37 +123,38 @@ class AsyncSSLSocket : public virtual AsyncSocket { noexcept = 0; }; - class HandshakeTimeout : public AsyncTimeout { + class Timeout : public AsyncTimeout { public: - HandshakeTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) - : AsyncTimeout(eventBase) - , sslSocket_(sslSocket) {} + Timeout(AsyncSSLSocket* sslSocket, EventBase* eventBase) + : AsyncTimeout(eventBase), sslSocket_(sslSocket) {} - virtual void timeoutExpired() noexcept { - sslSocket_->timeoutExpired(); + bool scheduleTimeout(TimeoutManager::timeout_type timeout) { + timeout_ = timeout; + return AsyncTimeout::scheduleTimeout(timeout); } - private: - AsyncSSLSocket* sslSocket_; - }; + bool scheduleTimeout(uint32_t timeoutMs) { + return scheduleTimeout(std::chrono::milliseconds{timeoutMs}); + } + TimeoutManager::timeout_type getTimeout() { + return timeout_; + } - /** - * These are passed to the application via errno, packed in an SSL err which - * are outside the valid errno range. The values are chosen to be unique - * against values in ssl.h - */ - enum SSLError { - SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900, - SSL_INVALID_RENEGOTIATION = 901, - SSL_EARLY_WRITE = 902 + void timeoutExpired() noexcept override { + sslSocket_->timeoutExpired(timeout_); + } + + private: + AsyncSSLSocket* sslSocket_; + TimeoutManager::timeout_type timeout_; }; /** * Create a client AsyncSSLSocket */ AsyncSSLSocket(const std::shared_ptr &ctx, - EventBase* evb); + EventBase* evb, bool deferSecurityNegotiation = false); /** * Create a server/client AsyncSSLSocket from an already connected @@ -178,19 +170,35 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param evb EventBase that will manage this socket. * @param fd File descriptor to take over (should be a connected socket). * @param server Is socket in server mode? + * @param deferSecurityNegotiation + * unencrypted data can be sent before sslConn/Accept */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, int fd, bool server = true); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + bool server = true, + bool deferSecurityNegotiation = false); + /** + * Create a server/client AsyncSSLSocket from an already connected + * AsyncSocket. + */ + AsyncSSLSocket( + const std::shared_ptr& ctx, + AsyncSocket::UniquePtr oldAsyncSocket, + bool server = true, + bool deferSecurityNegotiation = false); /** * Helper function to create a server/client shared_ptr. */ static std::shared_ptr newSocket( const std::shared_ptr& ctx, - EventBase* evb, int fd, bool server=true) { + EventBase* evb, int fd, bool server=true, + bool deferSecurityNegotiation = false) { return std::shared_ptr( - new AsyncSSLSocket(ctx, evb, fd, server), + new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation), Destructor()); } @@ -199,21 +207,22 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ static std::shared_ptr newSocket( const std::shared_ptr& ctx, - EventBase* evb) { + EventBase* evb, bool deferSecurityNegotiation = false) { return std::shared_ptr( - new AsyncSSLSocket(ctx, evb), + new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation), Destructor()); } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Create a client AsyncSSLSocket with tlsext_servername in * the Client Hello message. */ AsyncSSLSocket(const std::shared_ptr &ctx, EventBase* evb, - const std::string& serverName); + const std::string& serverName, + bool deferSecurityNegotiation = false); /** * Create a client AsyncSSLSocket from an already connected @@ -230,20 +239,23 @@ class AsyncSSLSocket : public virtual AsyncSocket { * @param fd File descriptor to take over (should be a connected socket). * @param serverName tlsext_hostname that will be sent in ClientHello. */ - AsyncSSLSocket(const std::shared_ptr& ctx, - EventBase* evb, - int fd, - const std::string& serverName); + AsyncSSLSocket( + const std::shared_ptr& ctx, + EventBase* evb, + int fd, + const std::string& serverName, + bool deferSecurityNegotiation = false); static std::shared_ptr newSocket( const std::shared_ptr& ctx, EventBase* evb, - const std::string& serverName) { + const std::string& serverName, + bool deferSecurityNegotiation = false) { return std::shared_ptr( - new AsyncSSLSocket(ctx, evb, serverName), + new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation), Destructor()); } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI /** * TODO: implement support for SSL renegotiation. @@ -263,16 +275,20 @@ class AsyncSSLSocket : public virtual AsyncSocket { // See the documentation in TAsyncTransport.h // TODO: implement graceful shutdown in close() // TODO: implement detachSSL() that returns the SSL connection - virtual void closeNow(); - virtual void shutdownWrite(); - virtual void shutdownWriteNow(); - virtual bool good() const; - virtual bool connecting() const; - - bool isEorTrackingEnabled() const override; - virtual void setEorTracking(bool track); - virtual size_t getRawBytesWritten() const; - virtual size_t getRawBytesReceived() const; + void closeNow() override; + void shutdownWrite() override; + void shutdownWriteNow() override; + bool good() const override; + bool connecting() const override; + std::string getApplicationProtocol() noexcept override; + + std::string getSecurityProtocol() const override { + return "TLS"; + } + + void setEorTracking(bool track) override; + size_t getRawBytesWritten() const override; + size_t getRawBytesReceived() const override; void enableClientHelloParsing(); /** @@ -290,9 +306,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { * context by default, can be set explcitly to override the * method in the context */ - virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0, + virtual void sslAccept( + HandshakeCB* callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(), const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = - folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); /** * Invoke SSL accept following an asynchronous session cache lookup @@ -308,17 +326,40 @@ class AsyncSSLSocket : public virtual AsyncSocket { const folly::SocketAddress& address, int timeout = 0, const OptionMap &options = emptyOptionMap, - const folly::SocketAddress& bindAddr = anyAddress) - noexcept; + const folly::SocketAddress& bindAddr = anyAddress()) + noexcept override; + + /** + * A variant of connect that allows the caller to specify + * the timeout for the regular connect and the ssl connect + * separately. + * connectTimeout is specified as the time to establish a TCP + * connection. + * totalConnectTimeout defines the + * time it takes from starting the TCP connection to the time + * the ssl connection is established. The reason the timeout is + * defined this way is because user's rarely need to specify the SSL + * timeout independently of the connect timeout. It allows us to + * bound the time for a connect and SSL connection in + * a finer grained manner than if timeout was just defined + * independently for SSL. + */ + virtual void connect( + ConnectCallback* callback, + const folly::SocketAddress& address, + std::chrono::milliseconds connectTimeout, + std::chrono::milliseconds totalConnectTimeout, + const OptionMap& options = emptyOptionMap, + const folly::SocketAddress& bindAddr = anyAddress()) noexcept; using AsyncSocket::connect; /** * Initiate an SSL connection on the socket - * THe callback will be invoked and uninstalled when an SSL connection + * The callback will be invoked and uninstalled when an SSL connection * has been establshed on the underlying socket. - * The verification option verifyPeer is applied if its passed explicitly. - * If its not, the options in SSLContext set on the underying SSLContext + * The verification option verifyPeer is applied if it's passed explicitly. + * If it's not, the options in SSLContext set on the underlying SSLContext * are applied. * * @param callback callback object to invoke on success/failure @@ -330,15 +371,18 @@ class AsyncSSLSocket : public virtual AsyncSocket { * SSL_VERIFY_PEER and invokes * HandshakeCB::handshakeVer(). */ - virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0, - const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = - folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + virtual void sslConn( + HandshakeCB* callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(), + const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); enum SSLStateEnum { STATE_UNINIT, + STATE_UNENCRYPTED, STATE_ACCEPTING, STATE_CACHE_LOOKUP, - STATE_RSA_ASYNC_PENDING, + STATE_ASYNC_PENDING, STATE_CONNECTING, STATE_ESTABLISHED, STATE_REMOTE_CLOSED, /// remote end closed; we can still write @@ -357,6 +401,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ SSL_SESSION *getSSLSession(); + /** + * Get a handle to the SSL struct. + */ + const SSL* getSSL() const; + /** * Set the SSL session to be used during sslConn. AsyncSSLSocket will * hold a reference to the session until it is destroyed or released by the @@ -369,7 +418,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { /** * Get the name of the protocol selected by the client during - * Next Protocol Negotiation (NPN) + * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation + * (ALPN) * * Throw an exception if openssl does not support NPN * @@ -379,13 +429,17 @@ class AsyncSSLSocket : public virtual AsyncSocket { * Note: the AsyncSSLSocket retains ownership * of this string. * @param protoNameLen Length of the name. + * @param protoType Whether this was an NPN or ALPN negotiation */ - virtual void getSelectedNextProtocol(const unsigned char** protoName, - unsigned* protoLen) const; + virtual void getSelectedNextProtocol( + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType = nullptr) const; /** * Get the name of the protocol selected by the client during - * Next Protocol Negotiation (NPN) + * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation + * (ALPN) * * @param protoName Name of the protocol (not guaranteed to be * null terminated); will be set to nullptr if @@ -393,16 +447,19 @@ class AsyncSSLSocket : public virtual AsyncSocket { * Note: the AsyncSSLSocket retains ownership * of this string. * @param protoNameLen Length of the name. + * @param protoType Whether this was an NPN or ALPN negotiation * @return false if openssl does not support NPN */ - virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName, - unsigned* protoLen) const; + virtual bool getSelectedNextProtocolNoThrow( + const unsigned char** protoName, + unsigned* protoLen, + SSLContext::NextProtocolType* protoType = nullptr) const; /** * Determine if the session specified during setSSLSession was reused * or if the server rejected it and issued a new session. */ - bool getSSLSessionReused() const; + virtual bool getSSLSessionReused() const; /** * true if the session was resumed using session ID @@ -418,7 +475,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { * Returns the cipher used or the constant value "NONE" when no SSL session * has been established. */ - const char *getNegotiatedCipherName() const; + virtual const char* getNegotiatedCipherName() const; /** * Get the server name for this SSL connection. @@ -445,38 +502,35 @@ class AsyncSSLSocket : public virtual AsyncSocket { int getSSLVersion() const; /** - * Get the certificate size used for this SSL connection. + * Get the signature algorithm used in the cert that is used for this + * connection. */ - int getSSLCertSize() const; + const char *getSSLCertSigAlgName() const; - /* Get the number of bytes read from the wire (including protocol - * overhead). Returns 0 once the connection has been closed. + /** + * Get the certificate size used for this SSL connection. */ - unsigned long getBytesRead() const { - if (ssl_ != nullptr) { - return BIO_number_read(SSL_get_rbio(ssl_)); - } - return 0; - } + int getSSLCertSize() const; - /* Get the number of bytes written to the wire (including protocol - * overhead). Returns 0 once the connection has been closed. + /** + * Get the certificate used for this SSL connection. May be null */ - unsigned long getBytesWritten() const { - if (ssl_ != nullptr) { - return BIO_number_written(SSL_get_wbio(ssl_)); - } - return 0; - } + const X509* getSelfCert() const override; - virtual void attachEventBase(EventBase* eventBase) { + void attachEventBase(EventBase* eventBase) override { AsyncSocket::attachEventBase(eventBase); handshakeTimeout_.attachEventBase(eventBase); + connectionTimeout_.attachEventBase(eventBase); } - virtual void detachEventBase() { + void detachEventBase() override { AsyncSocket::detachEventBase(); handshakeTimeout_.detachEventBase(); + connectionTimeout_.detachEventBase(); + } + + bool isDetachable() const override { + return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled(); } virtual void attachTimeoutManager(TimeoutManager* manager) { @@ -501,7 +555,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { void detachSSLContext(); #endif -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Switch the SSLContext to continue the SSL handshake. * It can only be used in server mode. @@ -524,151 +578,195 @@ class AsyncSSLSocket : public virtual AsyncSocket { * ClientHello message. */ void setServerName(std::string serverName) noexcept; -#endif +#endif // FOLLY_OPENSSL_HAS_SNI - void timeoutExpired() noexcept; + void timeoutExpired(std::chrono::milliseconds timeout) noexcept; /** * Get the list of supported ciphers sent by the client in the client's * preference order. */ - void getSSLClientCiphers(std::string& clientCiphers) { - std::stringstream ciphersStream; - std::string cipherName; - - if (parseClientHello_ == false - || clientHelloInfo_->clientHelloCipherSuites_.empty()) { - clientCiphers = ""; - return; - } - - for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_) - { - // OpenSSL expects code as a big endian char array - auto cipherCode = htons(originalCipherCode); - -#if defined(SSL_OP_NO_TLSv1_2) - const SSL_CIPHER* cipher = - TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#elif defined(SSL_OP_NO_TLSv1_1) - const SSL_CIPHER* cipher = - TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#elif defined(SSL_OP_NO_TLSv1) - const SSL_CIPHER* cipher = - TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#else - const SSL_CIPHER* cipher = - SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#endif - - if (cipher == nullptr) { - ciphersStream << std::setfill('0') << std::setw(4) << std::hex - << originalCipherCode << ":"; - } else { - ciphersStream << SSL_CIPHER_get_name(cipher) << ":"; - } - } - - clientCiphers = ciphersStream.str(); - clientCiphers.erase(clientCiphers.end() - 1); - } + void getSSLClientCiphers( + std::string& clientCiphers, + bool convertToString = true) const; /** * Get the list of compression methods sent by the client in TLS Hello. */ - std::string getSSLClientComprMethods() { - if (!parseClientHello_) { - return ""; - } - return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_); - } + std::string getSSLClientComprMethods() const; /** * Get the list of TLS extensions sent by the client in the TLS Hello. */ - std::string getSSLClientExts() { - if (!parseClientHello_) { - return ""; - } - return folly::join(":", clientHelloInfo_->clientHelloExtensions_); - } + std::string getSSLClientExts() const; + + std::string getSSLClientSigAlgs() const; + + /** + * Get the list of versions in the supported versions extension (used to + * negotiate TLS 1.3). + */ + std::string getSSLClientSupportedVersions() const; + + std::string getSSLAlertsReceived() const; + + /* + * Save an optional alert message generated during certificate verify + */ + void setSSLCertVerificationAlert(std::string alert); + + std::string getSSLCertVerificationAlert() const; /** * Get the list of shared ciphers between the server and the client. * Works well for only SSLv2, not so good for SSLv3 or TLSv1. */ - void getSSLSharedCiphers(std::string& sharedCiphers) { - char ciphersBuffer[1024]; - ciphersBuffer[0] = '\0'; - SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1); - sharedCiphers = ciphersBuffer; - } + void getSSLSharedCiphers(std::string& sharedCiphers) const; /** * Get the list of ciphers supported by the server in the server's * preference order. */ - void getSSLServerCiphers(std::string& serverCiphers) { - serverCiphers = SSL_get_cipher_list(ssl_, 0); - int i = 1; - const char *cipher; - while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) { - serverCiphers.append(":"); - serverCiphers.append(cipher); - i++; - } - } + void getSSLServerCiphers(std::string& serverCiphers) const; + + /** + * Method to check if peer verfication is set. + * + * @return true if peer verification is required. + */ + bool needsPeerVerification() const; static int getSSLExDataIndex(); static AsyncSSLSocket* getFromSSL(const SSL *ssl); - static int eorAwareBioWrite(BIO *b, const char *in, int inl); + static int bioWrite(BIO* b, const char* in, int inl); + static int bioRead(BIO* b, char* out, int outl); void resetClientHelloParsing(SSL *ssl); static void clientHelloParsingCallback(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg); - - struct ClientHelloInfo { - folly::IOBufQueue clientHelloBuf_; - uint8_t clientHelloMajorVersion_; - uint8_t clientHelloMinorVersion_; - std::vector clientHelloCipherSuites_; - std::vector clientHelloCompressionMethods_; - std::vector clientHelloExtensions_; - }; + static const char* getSSLServerNameFromSSL(SSL* ssl); // For unit-tests - ClientHelloInfo* getClientHelloInfo() { + ssl::ClientHelloInfo* getClientHelloInfo() const { return clientHelloInfo_.get(); } + /** + * Returns the time taken to complete a handshake. + */ + virtual std::chrono::nanoseconds getHandshakeTime() const { + return handshakeEndTime_ - handshakeStartTime_; + } + + void setMinWriteSize(size_t minWriteSize) { + minWriteSize_ = minWriteSize; + } + + size_t getMinWriteSize() const { + return minWriteSize_; + } + + void setReadCB(ReadCallback* callback) override; + + /** + * Tries to enable the buffer movable experimental feature in openssl. + * This is not guaranteed to succeed in case openssl does not have + * the experimental feature built in. + */ + void setBufferMovableEnabled(bool enabled); + + /** + * Returns the peer certificate, or nullptr if no peer certificate received. + */ + ssl::X509UniquePtr getPeerCert() const override { + if (!ssl_) { + return nullptr; + } + + X509* cert = SSL_get_peer_certificate(ssl_); + return ssl::X509UniquePtr(cert); + } + + /** + * Force AsyncSSLSocket object to cache local and peer socket addresses. + * If called with "true" before connect() this function forces full local + * and remote socket addresses to be cached in the socket object and available + * through getLocalAddress()/getPeerAddress() methods even after the socket is + * closed. + */ + void forceCacheAddrOnFailure(bool force) { cacheAddrOnFailure_ = force; } + + const std::string& getServiceIdentity() const { return serviceIdentity_; } + + void setServiceIdentity(std::string serviceIdentity) { + serviceIdentity_ = std::move(serviceIdentity); + } + + void setCertCacheHit(bool hit) { + certCacheHit_ = hit; + } + + bool getCertCacheHit() const { + return certCacheHit_; + } + + bool sessionResumptionAttempted() const { + return sessionResumptionAttempted_; + } + + /** + * If the SSL socket was used to connect as well + * as establish an SSL connection, this gives the total + * timeout for the connect + SSL connection that was + * set. + */ + std::chrono::milliseconds getTotalConnectTimeout() const { + return totalConnectTimeout_; + } + + private: + + void init(); + protected: /** * Protected destructor. * * Users of AsyncSSLSocket must never delete it directly. Instead, invoke - * destroy() instead. (See the documentation in TDelayedDestruction.h for + * destroy() instead. (See the documentation in DelayedDestruction.h for * more details.) */ - ~AsyncSSLSocket(); + ~AsyncSSLSocket() override; // Inherit event notification methods from AsyncSocket except // the following. - - void handleRead() noexcept; - void handleWrite() noexcept; + void prepareReadBuffer(void** buf, size_t* buflen) override; + void handleRead() noexcept override; + void handleWrite() noexcept override; void handleAccept() noexcept; - void handleConnect() noexcept; + void handleConnect() noexcept override; void invalidState(HandshakeCB* callback); - bool willBlock(int ret, int *errorOut) noexcept; + bool willBlock(int ret, + int* sslErrorOut, + unsigned long* errErrorOut) noexcept; - virtual void checkForImmediateRead() noexcept; + void checkForImmediateRead() noexcept override; // AsyncSocket calls this at the wrong time for SSL - void handleInitialReadWrite() noexcept {} + void handleInitialReadWrite() noexcept override {} + + WriteResult interpretSSLError(int rc, int error); + ReadResult performRead(void** buf, size_t* buflen, size_t* offset) override; + WriteResult performWrite( + const iovec* vec, + uint32_t count, + WriteFlags flags, + uint32_t* countWritten, + uint32_t* partialWritten) override; - ssize_t performRead(void* buf, size_t buflen); - ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags, - uint32_t* countWritten, uint32_t* partialWritten); + ssize_t performWriteIovec(const iovec* vec, uint32_t count, + WriteFlags flags, uint32_t* countWritten, + uint32_t* partialWritten); // This virtual wrapper around SSL_write exists solely for testing/mockability virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) { @@ -685,6 +783,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ void applyVerificationOptions(SSL * ssl); + /** + * Sets up SSL with a custom write bio which intercepts all writes. + * + * @return true, if succeeds and false if there is an error creating the bio. + */ + bool setupSSLBio(); + /** * A SSL_write wrapper that understand EOR * @@ -699,14 +804,19 @@ class AsyncSSLSocket : public virtual AsyncSocket { // Inherit error handling methods from AsyncSocket, plus the following. void failHandshake(const char* fn, const AsyncSocketException& ex); + void invokeHandshakeErr(const AsyncSocketException& ex); void invokeHandshakeCB(); + void invokeConnectErr(const AsyncSocketException& ex) override; + void invokeConnectSuccess() override; + void scheduleConnectTimeout() override; + + void startSSLConnect(); + static void sslInfoCallback(const SSL *ssl, int type, int val); - static std::mutex mutex_; - static int sslExDataIndex_; - // Whether we've applied the TCP_CORK option to the socket - bool corked_{false}; + // Whether the current write to the socket should use MSG_MORE. + bool corkCurrentWrite_{false}; // SSL related members. bool server_{false}; // Used to prevent client-initiated renegotiation. Note that AsyncSSLSocket @@ -721,7 +831,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { HandshakeCB* handshakeCallback_{nullptr}; SSL* ssl_{nullptr}; SSL_SESSION *sslSession_{nullptr}; - HandshakeTimeout handshakeTimeout_; + Timeout handshakeTimeout_; + Timeout connectionTimeout_; // whether the SSL session was resumed using session ID or not bool sessionIDResumed_{false}; @@ -729,13 +840,21 @@ class AsyncSSLSocket : public virtual AsyncSocket { // Only one app EOR byte can be tracked. size_t appEorByteNo_{0}; + // Try to avoid calling SSL_write() for buffers smaller than this. + // It doesn't take effect when it is 0. + size_t minWriteSize_{1500}; + // When openssl is about to sendmsg() across the minEorRawBytesNo_, // it will pass MSG_EOR to sendmsg(). size_t minEorRawByteNo_{0}; -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI std::shared_ptr handshakeCtx_; std::string tlsextHostname_; #endif + + // a service identity that this socket/connection is associated with + std::string serviceIdentity_; + folly::SSLContext::SSLVerifyPeerEnum verifyPeer_{folly::SSLContext::SSLVerifyPeerEnum::USE_CTX}; @@ -743,7 +862,20 @@ class AsyncSSLSocket : public virtual AsyncSocket { static int sslVerifyCallback(int preverifyOk, X509_STORE_CTX* ctx); bool parseClientHello_{false}; - unique_ptr clientHelloInfo_; + bool cacheAddrOnFailure_{false}; + bool bufferMovableEnabled_{false}; + bool certCacheHit_{false}; + std::unique_ptr clientHelloInfo_; + std::vector> alertsReceived_; + + // Time taken to complete the ssl handshake. + std::chrono::steady_clock::time_point handshakeStartTime_; + std::chrono::steady_clock::time_point handshakeEndTime_; + std::chrono::milliseconds handshakeConnectTimeout_{0}; + bool sessionResumptionAttempted_{false}; + std::chrono::milliseconds totalConnectTimeout_{0}; + + std::string sslVerificationAlert_; }; } // namespace