#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/OpenSSLPtrTypes.h>
#include <folly/io/async/TimeoutManager.h>
#include <folly/Bits.h>
#include <folly/io/IOBuf.h>
#include <folly/io/Cursor.h>
-using folly::io::Cursor;
-using std::unique_ptr;
-
namespace folly {
class SSLException: public folly::AsyncSocketException {
class AsyncSSLSocket : public virtual AsyncSocket {
public:
typedef std::unique_ptr<AsyncSSLSocket, Destructor> UniquePtr;
+ using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
class HandshakeCB {
public:
- virtual ~HandshakeCB() {}
+ virtual ~HandshakeCB() = default;
/**
* handshakeVer() is invoked during handshaking to give the
* See the passages on verify_callback in SSL_CTX_set_verify(3)
* for more details.
*/
- virtual bool handshakeVer(AsyncSSLSocket* sock,
+ virtual bool handshakeVer(AsyncSSLSocket* /*sock*/,
bool preverifyOk,
- X509_STORE_CTX* ctx) noexcept {
+ X509_STORE_CTX* /*ctx*/) noexcept {
return preverifyOk;
}
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
- EventBase* evb);
+ EventBase* evb, bool deferSecurityNegotiation = false);
/**
* Create a server/client AsyncSSLSocket from an already connected
* @param evb EventBase that will manage this socket.
* @param fd File descriptor to take over (should be a connected socket).
* @param server Is socket in server mode?
+ * @param deferSecurityNegotiation
+ * unencrypted data can be sent before sslConn/Accept
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb, int fd, bool server = true);
+ EventBase* evb, int fd,
+ bool server = true, bool deferSecurityNegotiation = false);
/**
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb, int fd, bool server=true) {
+ EventBase* evb, int fd, bool server=true,
+ bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
- new AsyncSSLSocket(ctx, evb, fd, server),
+ new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation),
Destructor());
}
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
- EventBase* evb) {
+ EventBase* evb, bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
- new AsyncSSLSocket(ctx, evb),
+ new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation),
Destructor());
}
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb,
- const std::string& serverName);
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false);
/**
* Create a client AsyncSSLSocket from an already connected
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
int fd,
- const std::string& serverName);
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false);
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
- const std::string& serverName) {
+ const std::string& serverName,
+ bool deferSecurityNegotiation = false) {
return std::shared_ptr<AsyncSSLSocket>(
- new AsyncSSLSocket(ctx, evb, serverName),
+ new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
Destructor());
}
#endif
virtual void shutdownWriteNow() override;
virtual bool good() const override;
virtual bool connecting() const override;
+ virtual std::string getApplicationProtocol() noexcept override;
+
+ virtual std::string getSecurityProtocol() const override { return "TLS"; }
bool isEorTrackingEnabled() const override;
virtual void setEorTracking(bool track) override;
/**
* Initiate an SSL connection on the socket
- * THe callback will be invoked and uninstalled when an SSL connection
+ * The callback will be invoked and uninstalled when an SSL connection
* has been establshed on the underlying socket.
- * The verification option verifyPeer is applied if its passed explicitly.
- * If its not, the options in SSLContext set on the underying SSLContext
+ * The verification option verifyPeer is applied if it's passed explicitly.
+ * If it's not, the options in SSLContext set on the underlying SSLContext
* are applied.
*
* @param callback callback object to invoke on success/failure
enum SSLStateEnum {
STATE_UNINIT,
+ STATE_UNENCRYPTED,
STATE_ACCEPTING,
STATE_CACHE_LOOKUP,
STATE_RSA_ASYNC_PENDING,
/**
* Get the name of the protocol selected by the client during
- * Next Protocol Negotiation (NPN)
+ * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+ * (ALPN)
*
* Throw an exception if openssl does not support NPN
*
* Note: the AsyncSSLSocket retains ownership
* of this string.
* @param protoNameLen Length of the name.
+ * @param protoType Whether this was an NPN or ALPN negotiation
*/
- virtual void getSelectedNextProtocol(const unsigned char** protoName,
- unsigned* protoLen) const;
+ virtual void getSelectedNextProtocol(
+ const unsigned char** protoName,
+ unsigned* protoLen,
+ SSLContext::NextProtocolType* protoType = nullptr) const;
/**
* Get the name of the protocol selected by the client during
- * Next Protocol Negotiation (NPN)
+ * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+ * (ALPN)
*
* @param protoName Name of the protocol (not guaranteed to be
* null terminated); will be set to nullptr if
* Note: the AsyncSSLSocket retains ownership
* of this string.
* @param protoNameLen Length of the name.
+ * @param protoType Whether this was an NPN or ALPN negotiation
* @return false if openssl does not support NPN
*/
- virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
- unsigned* protoLen) const;
+ virtual bool getSelectedNextProtocolNoThrow(
+ const unsigned char** protoName,
+ unsigned* protoLen,
+ SSLContext::NextProtocolType* protoType = nullptr) const;
/**
* Determine if the session specified during setSSLSession was reused
* or if the server rejected it and issued a new session.
*/
- bool getSSLSessionReused() const;
+ virtual bool getSSLSessionReused() const;
/**
* true if the session was resumed using session ID
* Returns the cipher used or the constant value "NONE" when no SSL session
* has been established.
*/
- const char *getNegotiatedCipherName() const;
+ virtual const char* getNegotiatedCipherName() const;
/**
* Get the server name for this SSL connection.
int getSSLVersion() const;
/**
- * Get the certificate size used for this SSL connection.
+ * Get the signature algorithm used in the cert that is used for this
+ * connection.
*/
- int getSSLCertSize() const;
+ const char *getSSLCertSigAlgName() const;
- /* Get the number of bytes read from the wire (including protocol
- * overhead). Returns 0 once the connection has been closed.
- */
- unsigned long getBytesRead() const {
- if (ssl_ != nullptr) {
- return BIO_number_read(SSL_get_rbio(ssl_));
- }
- return 0;
- }
-
- /* Get the number of bytes written to the wire (including protocol
- * overhead). Returns 0 once the connection has been closed.
+ /**
+ * Get the certificate size used for this SSL connection.
*/
- unsigned long getBytesWritten() const {
- if (ssl_ != nullptr) {
- return BIO_number_written(SSL_get_wbio(ssl_));
- }
- return 0;
- }
+ int getSSLCertSize() const;
virtual void attachEventBase(EventBase* eventBase) override {
AsyncSocket::attachEventBase(eventBase);
handshakeTimeout_.detachEventBase();
}
+ virtual bool isDetachable() const override {
+ return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled();
+ }
+
virtual void attachTimeoutManager(TimeoutManager* manager) {
handshakeTimeout_.attachTimeoutManager(manager);
}
* Get the list of supported ciphers sent by the client in the client's
* preference order.
*/
- void getSSLClientCiphers(std::string& clientCiphers) {
+ void getSSLClientCiphers(std::string& clientCiphers) const {
std::stringstream ciphersStream;
std::string cipherName;
/**
* Get the list of compression methods sent by the client in TLS Hello.
*/
- std::string getSSLClientComprMethods() {
+ std::string getSSLClientComprMethods() const {
if (!parseClientHello_) {
return "";
}
/**
* Get the list of TLS extensions sent by the client in the TLS Hello.
*/
- std::string getSSLClientExts() {
+ std::string getSSLClientExts() const {
if (!parseClientHello_) {
return "";
}
return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
}
+ std::string getSSLClientSigAlgs() const {
+ if (!parseClientHello_) {
+ return "";
+ }
+
+ std::string sigAlgs;
+ sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4);
+ for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) {
+ if (i) {
+ sigAlgs.push_back(':');
+ }
+ sigAlgs.append(folly::to<std::string>(
+ clientHelloInfo_->clientHelloSigAlgs_[i].first));
+ sigAlgs.push_back(',');
+ sigAlgs.append(folly::to<std::string>(
+ clientHelloInfo_->clientHelloSigAlgs_[i].second));
+ }
+
+ return sigAlgs;
+ }
+
/**
* Get the list of shared ciphers between the server and the client.
* Works well for only SSLv2, not so good for SSLv3 or TLSv1.
*/
- void getSSLSharedCiphers(std::string& sharedCiphers) {
+ void getSSLSharedCiphers(std::string& sharedCiphers) const {
char ciphersBuffer[1024];
ciphersBuffer[0] = '\0';
SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1);
* Get the list of ciphers supported by the server in the server's
* preference order.
*/
- void getSSLServerCiphers(std::string& serverCiphers) {
+ void getSSLServerCiphers(std::string& serverCiphers) const {
serverCiphers = SSL_get_cipher_list(ssl_, 0);
int i = 1;
const char *cipher;
static void clientHelloParsingCallback(int write_p, int version,
int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
+ // http://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
+ enum class TLSExtension: uint16_t {
+ SERVER_NAME = 0,
+ MAX_FRAGMENT_LENGTH = 1,
+ CLIENT_CERTIFICATE_URL = 2,
+ TRUSTED_CA_KEYS = 3,
+ TRUNCATED_HMAC = 4,
+ STATUS_REQUEST = 5,
+ USER_MAPPING = 6,
+ CLIENT_AUTHZ = 7,
+ SERVER_AUTHZ = 8,
+ CERT_TYPE = 9,
+ SUPPORTED_GROUPS = 10,
+ EC_POINT_FORMATS = 11,
+ SRP = 12,
+ SIGNATURE_ALGORITHMS = 13,
+ USE_SRTP = 14,
+ HEARTBEAT = 15,
+ APPLICATION_LAYER_PROTOCOL_NEGOTIATION = 16,
+ STATUS_REQUEST_V2 = 17,
+ SIGNED_CERTIFICATE_TIMESTAMP = 18,
+ CLIENT_CERTIFICATE_TYPE = 19,
+ SERVER_CERTIFICATE_TYPE = 20,
+ PADDING = 21,
+ ENCRYPT_THEN_MAC = 22,
+ EXTENDED_MASTER_SECRET = 23,
+ SESSION_TICKET = 35,
+ RENEGOTIATION_INFO = 65281
+ };
+
+ // http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18
+ enum class HashAlgorithm: uint8_t {
+ NONE = 0,
+ MD5 = 1,
+ SHA1 = 2,
+ SHA224 = 3,
+ SHA256 = 4,
+ SHA384 = 5,
+ SHA512 = 6
+ };
+
+ // http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16
+ enum class SignatureAlgorithm: uint8_t {
+ ANONYMOUS = 0,
+ RSA = 1,
+ DSA = 2,
+ ECDSA = 3
+ };
+
struct ClientHelloInfo {
folly::IOBufQueue clientHelloBuf_;
uint8_t clientHelloMajorVersion_;
uint8_t clientHelloMinorVersion_;
std::vector<uint16_t> clientHelloCipherSuites_;
std::vector<uint8_t> clientHelloCompressionMethods_;
- std::vector<uint16_t> clientHelloExtensions_;
+ std::vector<TLSExtension> clientHelloExtensions_;
+ std::vector<
+ std::pair<HashAlgorithm, SignatureAlgorithm>> clientHelloSigAlgs_;
};
// For unit-tests
- ClientHelloInfo* getClientHelloInfo() {
+ ClientHelloInfo* getClientHelloInfo() const {
return clientHelloInfo_.get();
}
+ /**
+ * Returns the time taken to complete a handshake.
+ */
+ virtual std::chrono::nanoseconds getHandshakeTime() const {
+ return handshakeEndTime_ - handshakeStartTime_;
+ }
+
+ void setMinWriteSize(size_t minWriteSize) {
+ minWriteSize_ = minWriteSize;
+ }
+
+ size_t getMinWriteSize() const {
+ return minWriteSize_;
+ }
+
+ void setReadCB(ReadCallback* callback) override;
+
+ /**
+ * Returns the peer certificate, or nullptr if no peer certificate received.
+ */
+ virtual X509_UniquePtr getPeerCert() const {
+ if (!ssl_) {
+ return nullptr;
+ }
+
+ X509* cert = SSL_get_peer_certificate(ssl_);
+ return X509_UniquePtr(cert);
+ }
+
+ private:
+
+ void init();
+
protected:
/**
* Protected destructor.
*
* Users of AsyncSSLSocket must never delete it directly. Instead, invoke
- * destroy() instead. (See the documentation in TDelayedDestruction.h for
+ * destroy() instead. (See the documentation in DelayedDestruction.h for
* more details.)
*/
~AsyncSSLSocket();
// Inherit event notification methods from AsyncSocket except
// the following.
-
+ void prepareReadBuffer(void** buf, size_t* buflen) noexcept override;
void handleRead() noexcept override;
void handleWrite() noexcept override;
void handleAccept() noexcept;
// AsyncSocket calls this at the wrong time for SSL
void handleInitialReadWrite() noexcept override {}
- ssize_t performRead(void* buf, size_t buflen) override;
+ int interpretSSLError(int rc, int error);
+ ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override;
ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags,
uint32_t* countWritten, uint32_t* partialWritten)
override;
+ ssize_t performWriteIovec(const iovec* vec, uint32_t count,
+ WriteFlags flags, uint32_t* countWritten,
+ uint32_t* partialWritten);
+
// This virtual wrapper around SSL_write exists solely for testing/mockability
virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) {
return SSL_write(ssl, buf, n);
// Inherit error handling methods from AsyncSocket, plus the following.
void failHandshake(const char* fn, const AsyncSocketException& ex);
+ void invokeHandshakeErr(const AsyncSocketException& ex);
void invokeHandshakeCB();
static void sslInfoCallback(const SSL *ssl, int type, int val);
- static std::mutex mutex_;
- static int sslExDataIndex_;
- // Whether we've applied the TCP_CORK option to the socket
- bool corked_{false};
// SSL related members.
bool server_{false};
// Used to prevent client-initiated renegotiation. Note that AsyncSSLSocket
// Only one app EOR byte can be tracked.
size_t appEorByteNo_{0};
+ // Try to avoid calling SSL_write() for buffers smaller than this.
+ // It doesn't take effect when it is 0.
+ size_t minWriteSize_{1500};
+
// When openssl is about to sendmsg() across the minEorRawBytesNo_,
// it will pass MSG_EOR to sendmsg().
size_t minEorRawByteNo_{0};
static int sslVerifyCallback(int preverifyOk, X509_STORE_CTX* ctx);
bool parseClientHello_{false};
- unique_ptr<ClientHelloInfo> clientHelloInfo_;
+ std::unique_ptr<ClientHelloInfo> clientHelloInfo_;
+
+ // Time taken to complete the ssl handshake.
+ std::chrono::steady_clock::time_point handshakeStartTime_;
+ std::chrono::steady_clock::time_point handshakeEndTime_;
};
} // namespace