X-Git-Url: http://plrg.eecs.uci.edu/git/?p=folly.git;a=blobdiff_plain;f=folly%2Fio%2Fasync%2FAsyncSSLSocket.h;h=e9aca826c379cb788578460c6179f709b827aa6f;hp=47ad97b08160279e9c9419eaa9a97a2a164688a0;hb=5c74326fdc75ccdfc2152c15203625d8588096b6;hpb=12ace86198d5050388099abb0cdb58faa7d4ac74 diff --git a/folly/io/async/AsyncSSLSocket.h b/folly/io/async/AsyncSSLSocket.h index 47ad97b0..e9aca826 100644 --- a/folly/io/async/AsyncSSLSocket.h +++ b/folly/io/async/AsyncSSLSocket.h @@ -1,5 +1,5 @@ /* - * Copyright 2016 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,7 @@ #include #include #include +#include #include namespace folly { @@ -202,7 +203,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { } -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Create a client AsyncSSLSocket with tlsext_servername in * the Client Hello message. @@ -242,7 +243,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation), Destructor()); } -#endif +#endif // FOLLY_OPENSSL_HAS_SNI /** * TODO: implement support for SSL renegotiation. @@ -271,12 +272,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { virtual std::string getSecurityProtocol() const override { return "TLS"; } - bool isEorTrackingEnabled() const override; virtual void setEorTracking(bool track) override; virtual size_t getRawBytesWritten() const override; virtual size_t getRawBytesReceived() const override; void enableClientHelloParsing(); + void setPreReceivedData(std::unique_ptr data); + /** * Accept an SSL connection on the socket. * @@ -292,9 +294,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { * context by default, can be set explcitly to override the * method in the context */ - virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0, + virtual void sslAccept( + HandshakeCB* callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(), const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = - folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); /** * Invoke SSL accept following an asynchronous session cache lookup @@ -332,9 +336,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { * SSL_VERIFY_PEER and invokes * HandshakeCB::handshakeVer(). */ - virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0, - const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = - folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); + virtual void sslConn( + HandshakeCB* callback, + std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(), + const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer = + folly::SSLContext::SSLVerifyPeerEnum::USE_CTX); enum SSLStateEnum { STATE_UNINIT, @@ -479,11 +485,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { virtual void attachEventBase(EventBase* eventBase) override { AsyncSocket::attachEventBase(eventBase); handshakeTimeout_.attachEventBase(eventBase); + connectionTimeout_.attachEventBase(eventBase); } virtual void detachEventBase() override { AsyncSocket::detachEventBase(); handshakeTimeout_.detachEventBase(); + connectionTimeout_.detachEventBase(); } virtual bool isDetachable() const override { @@ -512,7 +520,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { void detachSSLContext(); #endif -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI /** * Switch the SSLContext to continue the SSL handshake. * It can only be used in server mode. @@ -535,7 +543,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { * ClientHello message. */ void setServerName(std::string serverName) noexcept; -#endif +#endif // FOLLY_OPENSSL_HAS_SNI void timeoutExpired() noexcept; @@ -545,133 +553,44 @@ class AsyncSSLSocket : public virtual AsyncSocket { */ void getSSLClientCiphers( std::string& clientCiphers, - bool convertToString = true) const { - std::stringstream ciphersStream; - std::string cipherName; - - if (parseClientHello_ == false - || clientHelloInfo_->clientHelloCipherSuites_.empty()) { - clientCiphers = ""; - return; - } - - for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_) - { - const SSL_CIPHER* cipher = nullptr; - if (convertToString) { - // OpenSSL expects code as a big endian char array - auto cipherCode = htons(originalCipherCode); - -#if defined(SSL_OP_NO_TLSv1_2) - cipher = - TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#elif defined(SSL_OP_NO_TLSv1_1) - cipher = - TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#elif defined(SSL_OP_NO_TLSv1) - cipher = - TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#else - cipher = - SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode); -#endif - } - - if (cipher == nullptr) { - ciphersStream << std::setfill('0') << std::setw(4) << std::hex - << originalCipherCode << ":"; - } else { - ciphersStream << SSL_CIPHER_get_name(cipher) << ":"; - } - } - - clientCiphers = ciphersStream.str(); - clientCiphers.erase(clientCiphers.end() - 1); - } + bool convertToString = true) const; /** * Get the list of compression methods sent by the client in TLS Hello. */ - std::string getSSLClientComprMethods() const { - if (!parseClientHello_) { - return ""; - } - return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_); - } + std::string getSSLClientComprMethods() const; /** * Get the list of TLS extensions sent by the client in the TLS Hello. */ - std::string getSSLClientExts() const { - if (!parseClientHello_) { - return ""; - } - return folly::join(":", clientHelloInfo_->clientHelloExtensions_); - } + std::string getSSLClientExts() const; - std::string getSSLClientSigAlgs() const { - if (!parseClientHello_) { - return ""; - } - - std::string sigAlgs; - sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4); - for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) { - if (i) { - sigAlgs.push_back(':'); - } - sigAlgs.append(folly::to( - clientHelloInfo_->clientHelloSigAlgs_[i].first)); - sigAlgs.push_back(','); - sigAlgs.append(folly::to( - clientHelloInfo_->clientHelloSigAlgs_[i].second)); - } + std::string getSSLClientSigAlgs() const; - return sigAlgs; - } + /** + * Get the list of versions in the supported versions extension (used to + * negotiate TLS 1.3). + */ + std::string getSSLClientSupportedVersions() const; - std::string getSSLAlertsReceived() const { - std::string ret; - - for (const auto& alert : alertsReceived_) { - if (!ret.empty()) { - ret.append(","); - } - ret.append(folly::to(alert.first, ": ", alert.second)); - } - - return ret; - } + std::string getSSLAlertsReceived() const; /** * Get the list of shared ciphers between the server and the client. * Works well for only SSLv2, not so good for SSLv3 or TLSv1. */ - void getSSLSharedCiphers(std::string& sharedCiphers) const { - char ciphersBuffer[1024]; - ciphersBuffer[0] = '\0'; - SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1); - sharedCiphers = ciphersBuffer; - } + void getSSLSharedCiphers(std::string& sharedCiphers) const; /** * Get the list of ciphers supported by the server in the server's * preference order. */ - void getSSLServerCiphers(std::string& serverCiphers) const { - serverCiphers = SSL_get_cipher_list(ssl_, 0); - int i = 1; - const char *cipher; - while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) { - serverCiphers.append(":"); - serverCiphers.append(cipher); - i++; - } - } + void getSSLServerCiphers(std::string& serverCiphers) const; static int getSSLExDataIndex(); static AsyncSSLSocket* getFromSSL(const SSL *ssl); static int bioWrite(BIO* b, const char* in, int inl); + static int bioRead(BIO* b, char* out, int outl); void resetClientHelloParsing(SSL *ssl); static void clientHelloParsingCallback(int write_p, int version, int content_type, const void *buf, size_t len, SSL *ssl, void *arg); @@ -741,9 +660,24 @@ class AsyncSSLSocket : public virtual AsyncSocket { return certCacheHit_; } + bool sessionResumptionAttempted() const { + return sessionResumptionAttempted_; + } + + /** + * Clears the ERR stack before invoking SSL methods. + * This is useful if unrelated code that runs in the same thread + * does not properly handle SSL error conditions, in which case + * it could cause SSL_* methods to fail with incorrect error codes. + */ + void setClearOpenSSLErrors(bool clearErr) { + clearOpenSSLErrors_ = clearErr; + } + private: void init(); + void clearOpenSSLErrors(); protected: @@ -758,7 +692,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { // Inherit event notification methods from AsyncSocket except // the following. - void prepareReadBuffer(void** buf, size_t* buflen) noexcept override; + void prepareReadBuffer(void** buf, size_t* buflen) override; void handleRead() noexcept override; void handleWrite() noexcept override; void handleAccept() noexcept; @@ -788,6 +722,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { // This virtual wrapper around SSL_write exists solely for testing/mockability virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) { + clearOpenSSLErrors(); return SSL_write(ssl, buf, n); } @@ -835,8 +770,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { static void sslInfoCallback(const SSL *ssl, int type, int val); - // Whether we've applied the TCP_CORK option to the socket - bool corked_{false}; + // Whether the current write to the socket should use MSG_MORE. + bool corkCurrentWrite_{false}; // SSL related members. bool server_{false}; // Used to prevent client-initiated renegotiation. Note that AsyncSSLSocket @@ -856,9 +791,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { // whether the SSL session was resumed using session ID or not bool sessionIDResumed_{false}; - // Whether to track EOR or not. - bool trackEor_{false}; - // The app byte num that we are tracking for the MSG_EOR // Only one app EOR byte can be tracked. size_t appEorByteNo_{0}; @@ -870,7 +802,7 @@ class AsyncSSLSocket : public virtual AsyncSocket { // When openssl is about to sendmsg() across the minEorRawBytesNo_, // it will pass MSG_EOR to sendmsg(). size_t minEorRawByteNo_{0}; -#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT) +#if FOLLY_OPENSSL_HAS_SNI std::shared_ptr handshakeCtx_; std::string tlsextHostname_; #endif @@ -894,7 +826,13 @@ class AsyncSSLSocket : public virtual AsyncSocket { // Time taken to complete the ssl handshake. std::chrono::steady_clock::time_point handshakeStartTime_; std::chrono::steady_clock::time_point handshakeEndTime_; - uint64_t handshakeConnectTimeout_{0}; + std::chrono::milliseconds handshakeConnectTimeout_{0}; + bool sessionResumptionAttempted_{false}; + + std::unique_ptr preReceivedData_; + // Whether or not to clear the err stack before invocation of another + // SSL method + bool clearOpenSSLErrors_{false}; }; } // namespace