/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2014-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
+#include <folly/portability/Sockets.h>
#include <boost/noncopyable.hpp>
#include <errno.h>
#include <fcntl.h>
-#include <netinet/in.h>
-#include <netinet/tcp.h>
-#include <openssl/err.h>
-#include <openssl/asn1.h>
-#include <openssl/ssl.h>
#include <sys/types.h>
-#include <sys/socket.h>
-#include <unistd.h>
#include <chrono>
+#include <memory>
-#include <folly/Bits.h>
+#include <folly/Format.h>
#include <folly/SocketAddress.h>
#include <folly/SpinLock.h>
-#include <folly/io/IOBuf.h>
#include <folly/io/Cursor.h>
+#include <folly/io/IOBuf.h>
+#include <folly/lang/Bits.h>
+#include <folly/portability/OpenSSL.h>
using folly::SocketAddress;
using folly::SSLContext;
using folly::AsyncSocketException;
using folly::AsyncSSLSocket;
using folly::Optional;
+using folly::SSLContext;
+// For OpenSSL portability API
+using namespace folly::ssl;
+using folly::ssl::OpenSSLUtils;
+
// We have one single dummy SSL context so that we can implement attach
// and detach methods in a thread safe fashion without modifying opnessl.
static SSLContext *dummyCtx = nullptr;
static SpinLock dummyCtxLock;
-// Numbers chosen as to not collide with functions in ssl.h
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
-
// If given min write size is less than this, buffer will be allocated on
// stack, otherwise it is allocated on heap
const size_t MAX_STACK_BUF_SIZE = 2048;
timeoutLeft = timeout_ - (curTime - startTime_);
if (timeoutLeft <= 0) {
- AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
- "SSL connect timed out");
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT,
+ folly::sformat("SSL connect timed out after {}ms", timeout_));
fail(ex);
delete this;
return;
}
}
- sslSocket_->sslConn(this, timeoutLeft);
+ sslSocket_->sslConn(this, std::chrono::milliseconds(timeoutLeft));
}
void connectErr(const AsyncSocketException& ex) noexcept override {
- LOG(ERROR) << "TCP connect failed: " << ex.what();
+ VLOG(1) << "TCP connect failed: " << ex.what();
fail(ex);
delete this;
}
- void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
+ void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
VLOG(7) << "client handshake success";
if (callback_) {
callback_->connectSuccess();
delete this;
}
- void handshakeErr(AsyncSSLSocket* socket,
+ void handshakeErr(AsyncSSLSocket* /* socket */,
const AsyncSocketException& ex) noexcept override {
- LOG(ERROR) << "client handshakeErr: " << ex.what();
+ VLOG(1) << "client handshakeErr: " << ex.what();
fail(ex);
delete this;
}
}
};
-// XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
-#ifdef TCP_CORK // Linux-only
-/**
- * Utility class that corks a TCP socket upon construction or uncorks
- * the socket upon destruction
- */
-class CorkGuard : private boost::noncopyable {
- public:
- CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
- fd_(fd), haveMore_(haveMore), corked_(corked) {
- if (*corked_) {
- // socket is already corked; nothing to do
- return;
- }
- if (multipleWrites || haveMore) {
- // We are performing multiple writes in this performWrite() call,
- // and/or there are more calls to performWrite() that will be invoked
- // later, so enable corking
- int flag = 1;
- setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
- *corked_ = true;
- }
- }
-
- ~CorkGuard() {
- if (haveMore_) {
- // more data to come; don't uncork yet
- return;
- }
- if (!*corked_) {
- // socket isn't corked; nothing to do
- return;
- }
-
- int flag = 0;
- setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
- *corked_ = false;
- }
-
- private:
- int fd_;
- bool haveMore_;
- bool* corked_;
-};
-#else
-class CorkGuard : private boost::noncopyable {
- public:
- CorkGuard(int, bool, bool, bool*) {}
-};
-#endif
-
void setup_SSL_CTX(SSL_CTX *ctx) {
#ifdef SSL_MODE_RELEASE_BUFFERS
SSL_CTX_set_mode(ctx,
}
-BIO_METHOD eorAwareBioMethod;
+// Note: This is a Leaky Meyer's Singleton. The reason we can't use a non-leaky
+// thing is because we will be setting this BIO_METHOD* inside BIOs owned by
+// various SSL objects which may get callbacks even during teardown. We may
+// eventually try to fix this
+static BIO_METHOD* getSSLBioMethod() {
+ static auto const instance = OpenSSLUtils::newSocketBioMethod().release();
+ return instance;
+}
-void* initEorBioMethod(void) {
- memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
+void* initsslBioMethod() {
+ auto sslBioMethod = getSSLBioMethod();
// override the bwrite method for MSG_EOR support
- eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
+ OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite);
+ OpenSSLUtils::setCustomBioReadMethod(sslBioMethod, AsyncSSLSocket::bioRead);
- // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
+ // Note that the sslBioMethod.type and sslBioMethod.name are not
// set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
- // then have specific handlings. The eorAwareBioWrite should be compatible
+ // then have specific handlings. The sslWriteBioWrite should be compatible
// with the one in openssl.
// Return something here to enable AsyncSSLSocket to call this method using
return nullptr;
}
-} // anonymous namespace
+} // namespace
namespace folly {
-SSLException::SSLException(int sslError, int errno_copy):
- AsyncSocketException(
- AsyncSocketException::SSL_ERROR,
- ERR_error_string(sslError, msg_),
- sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
-
/**
* Create a client AsyncSSLSocket
*/
EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb),
ctx_(ctx),
- handshakeTimeout_(this, evb) {
+ handshakeTimeout_(this, evb),
+ connectionTimeout_(this, evb) {
init();
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
/**
* Create a server/client AsyncSSLSocket
*/
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
- EventBase* evb, int fd, bool server,
- bool deferSecurityNegotiation) :
- AsyncSocket(evb, fd),
- server_(server),
- ctx_(ctx),
- handshakeTimeout_(this, evb) {
+AsyncSSLSocket::AsyncSSLSocket(
+ const shared_ptr<SSLContext>& ctx,
+ EventBase* evb,
+ int fd,
+ bool server,
+ bool deferSecurityNegotiation)
+ : AsyncSocket(evb, fd),
+ server_(server),
+ ctx_(ctx),
+ handshakeTimeout_(this, evb),
+ connectionTimeout_(this, evb) {
+ noTransparentTls_ = true;
+ init();
+ if (server) {
+ SSL_CTX_set_info_callback(
+ ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
+ }
+ if (deferSecurityNegotiation) {
+ sslState_ = STATE_UNENCRYPTED;
+ }
+}
+
+AsyncSSLSocket::AsyncSSLSocket(
+ const shared_ptr<SSLContext>& ctx,
+ AsyncSocket::UniquePtr oldAsyncSocket,
+ bool server,
+ bool deferSecurityNegotiation)
+ : AsyncSocket(std::move(oldAsyncSocket)),
+ server_(server),
+ ctx_(ctx),
+ handshakeTimeout_(this, AsyncSocket::getEventBase()),
+ connectionTimeout_(this, AsyncSocket::getEventBase()) {
+ noTransparentTls_ = true;
init();
if (server) {
SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
}
}
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
/**
* Create a client AsyncSSLSocket and allow tlsext_hostname
* to be sent in Client Hello.
* Create a client AsyncSSLSocket from an already connected fd
* and allow tlsext_hostname to be sent in Client Hello.
*/
-AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
- EventBase* evb, int fd,
- const std::string& serverName,
- bool deferSecurityNegotiation) :
- AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
+AsyncSSLSocket::AsyncSSLSocket(
+ const shared_ptr<SSLContext>& ctx,
+ EventBase* evb,
+ int fd,
+ const std::string& serverName,
+ bool deferSecurityNegotiation)
+ : AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
tlsextHostname_ = serverName;
}
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
AsyncSSLSocket::~AsyncSSLSocket() {
VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
void AsyncSSLSocket::init() {
// Do this here to ensure we initialize this once before any use of
// AsyncSSLSocket instances and not as part of library load.
- static const auto eorAwareBioMethodInitializer = initEorBioMethod();
+ static const auto sslBioMethodInitializer = initsslBioMethod();
+ (void)sslBioMethodInitializer;
+
setup_SSL_CTX(ctx_->getSSLCtx());
}
DestructorGuard dg(this);
- if (handshakeCallback_) {
- AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
- "SSL connection closed locally");
- HandshakeCB* callback = handshakeCallback_;
- handshakeCallback_ = nullptr;
- callback->handshakeErr(this, ex);
- }
+ invokeHandshakeErr(
+ AsyncSocketException(
+ AsyncSocketException::END_OF_FILE,
+ "SSL connection closed locally"));
if (ssl_ != nullptr) {
SSL_free(ssl_);
sslState_ == STATE_CONNECTING))));
}
-bool AsyncSSLSocket::isEorTrackingEnabled() const {
- const BIO *wb = SSL_get_wbio(ssl_);
- return wb && wb->method == &eorAwareBioMethod;
+std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
+ const unsigned char* protoName = nullptr;
+ unsigned protoLength;
+ if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
+ return std::string(reinterpret_cast<const char*>(protoName), protoLength);
+ }
+ return "";
}
void AsyncSSLSocket::setEorTracking(bool track) {
- BIO *wb = SSL_get_wbio(ssl_);
- if (!wb) {
- throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
- "setting EOR tracking without an initialized "
- "BIO");
- }
-
- if (track) {
- if (wb->method != &eorAwareBioMethod) {
- // only do this if we didn't
- wb->method = &eorAwareBioMethod;
- BIO_set_app_data(wb, this);
- appEorByteNo_ = 0;
- minEorRawByteNo_ = 0;
- }
- } else if (wb->method == &eorAwareBioMethod) {
- wb->method = BIO_s_socket();
- BIO_set_app_data(wb, nullptr);
+ if (isEorTrackingEnabled() != track) {
+ AsyncSocket::setEorTracking(track);
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
- } else {
- CHECK(wb->method == BIO_s_socket());
}
}
size_t AsyncSSLSocket::getRawBytesWritten() const {
+ // The bio(s) in the write path are in a chain
+ // each bio flushes to the next and finally written into the socket
+ // to get the rawBytesWritten on the socket,
+ // get the write bytes of the last bio
BIO *b;
if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
return 0;
}
+ BIO* next = BIO_next(b);
+ while (next != nullptr) {
+ b = next;
+ next = BIO_next(b);
+ }
return BIO_number_written(b);
}
void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
<< ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
- << "events=" << eventFlags_ << ", server=" << short(server_) << "): "
- << "sslAccept/Connect() called in invalid "
- << "state, handshake callback " << handshakeCallback_ << ", new callback "
- << callback;
+ << "events=" << eventFlags_ << ", server=" << short(server_)
+ << "): " << "sslAccept/Connect() called in invalid "
+ << "state, handshake callback " << handshakeCallback_
+ << ", new callback " << callback;
assert(!handshakeTimeout_.isScheduled());
sslState_ = STATE_ERROR;
AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
"sslAccept() called with socket in invalid state");
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (callback) {
callback->handshakeErr(this, ex);
}
- // Check the socket state not the ssl state here.
- if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
- failHandshake(__func__, ex);
- }
+ failHandshake(__func__, ex);
}
-void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
- const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
+void AsyncSSLSocket::sslAccept(
+ HandshakeCB* callback,
+ std::chrono::milliseconds timeout,
+ const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
- if (!server_ || (sslState_ != STATE_UNINIT &&
- sslState_ != STATE_UNENCRYPTED) ||
+ if (!server_ ||
+ (sslState_ != STATE_UNINIT && sslState_ != STATE_UNENCRYPTED) ||
handshakeCallback_ != nullptr) {
return invalidState(callback);
}
+ // Cache local and remote socket addresses to keep them available
+ // after socket file descriptor is closed.
+ if (cacheAddrOnFailure_) {
+ cacheAddresses();
+ }
+
+ handshakeStartTime_ = std::chrono::steady_clock::now();
+ // Make end time at least >= start time.
+ handshakeEndTime_ = handshakeStartTime_;
+
sslState_ = STATE_ACCEPTING;
handshakeCallback_ = callback;
- if (timeout > 0) {
+ if (timeout > std::chrono::milliseconds::zero()) {
handshakeTimeout_.scheduleTimeout(timeout);
}
/* register for a read operation (waiting for CLIENT HELLO) */
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
+
+ checkForImmediateRead();
}
-#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
void AsyncSSLSocket::attachSSLContext(
const std::shared_ptr<SSLContext>& ctx) {
DCHECK(ctx->getSSLCtx());
ctx_ = ctx;
+ // It's possible this could be attached before ssl_ is set up
+ if (!ssl_) {
+ return;
+ }
+
// In order to call attachSSLContext, detachSSLContext must have been
- // previously called which sets the socket's context to the dummy
- // context. Thus we must acquire this lock.
+ // previously called.
+ // We need to update the initial_ctx if necessary
+ // The 'initial_ctx' inside an SSL* points to the context that it was created
+ // with, which is also where session callbacks and servername callbacks
+ // happen.
+ // When we switch to a different SSL_CTX, we want to update the initial_ctx as
+ // well so that any callbacks don't go to a different object
+ // NOTE: this will only work if we have access to ssl_ internals, so it may
+ // not work on
+ // OpenSSL version >= 1.1.0
+ auto sslCtx = ctx->getSSLCtx();
+ OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx);
+ // Detach sets the socket's context to the dummy context. Thus we must acquire
+ // this lock.
SpinLockGuard guard(dummyCtxLock);
- SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
+ SSL_set_SSL_CTX(ssl_, sslCtx);
}
void AsyncSSLSocket::detachSSLContext() {
DCHECK(ctx_);
ctx_.reset();
- // We aren't using the initial_ctx for now, and it can introduce race
- // conditions in the destructor of the SSL object.
-#ifndef OPENSSL_NO_TLSEXT
- if (ssl_->initial_ctx) {
- SSL_CTX_free(ssl_->initial_ctx);
- ssl_->initial_ctx = nullptr;
+ // It's possible for this to be called before ssl_ has been
+ // set up
+ if (!ssl_) {
+ return;
}
-#endif
+ // The 'initial_ctx' inside an SSL* points to the context that it was created
+ // with, which is also where session callbacks and servername callbacks
+ // happen.
+ // Detach the initial_ctx as well. It will be reattached in attachSSLContext
+ // it is used for session info.
+ // NOTE: this will only work if we have access to ssl_ internals, so it may
+ // not work on
+ // OpenSSL version >= 1.1.0
+ SSL_CTX* initialCtx = OpenSSLUtils::getSSLInitialCtx(ssl_);
+ if (initialCtx) {
+ SSL_CTX_free(initialCtx);
+ OpenSSLUtils::setSSLInitialCtx(ssl_, nullptr);
+ }
+
SpinLockGuard guard(dummyCtxLock);
if (nullptr == dummyCtx) {
// We need to lazily initialize the dummy context so we don't
// would not be thread safe.
SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
}
-#endif
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
void AsyncSSLSocket::switchServerSSLContext(
const std::shared_ptr<SSLContext>& handshakeCtx) {
CHECK(server_);
return false;
}
- if(!ss->tlsext_hostname) {
- return false;
- }
- return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
+ auto tlsextHostname = SSL_SESSION_get0_hostname(ss);
+ return (tlsextHostname && !tlsextHostname_.compare(tlsextHostname));
}
void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
tlsextHostname_ = std::move(serverName);
}
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
-void AsyncSSLSocket::timeoutExpired() noexcept {
+void AsyncSSLSocket::timeoutExpired(
+ std::chrono::milliseconds timeout) noexcept {
if (state_ == StateEnum::ESTABLISHED &&
- (sslState_ == STATE_CACHE_LOOKUP ||
- sslState_ == STATE_RSA_ASYNC_PENDING)) {
+ (sslState_ == STATE_CACHE_LOOKUP || sslState_ == STATE_ASYNC_PENDING)) {
sslState_ = STATE_ERROR;
// We are expecting a callback in restartSSLAccept. The cache lookup
// and rsa-call necessarily have pointers to this ssl socket, so delay
// the cleanup until he calls us back.
+ } else if (state_ == StateEnum::CONNECTING) {
+ assert(sslState_ == STATE_CONNECTING);
+ DestructorGuard dg(this);
+ AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
+ "Fallback connect timed out during TFO");
+ failHandshake(__func__, ex);
} else {
assert(state_ == StateEnum::ESTABLISHED &&
(sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
DestructorGuard dg(this);
- AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
- (sslState_ == STATE_CONNECTING) ?
- "SSL connect timed out" : "SSL accept timed out");
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT,
+ folly::sformat(
+ "SSL {} timed out after {}ms",
+ (sslState_ == STATE_CONNECTING) ? "connect" : "accept",
+ timeout.count()));
failHandshake(__func__, ex);
}
}
-int AsyncSSLSocket::sslExDataIndex_ = -1;
-std::mutex AsyncSSLSocket::mutex_;
-
int AsyncSSLSocket::getSSLExDataIndex() {
- if (sslExDataIndex_ < 0) {
- std::lock_guard<std::mutex> g(mutex_);
- if (sslExDataIndex_ < 0) {
- sslExDataIndex_ = SSL_get_ex_new_index(0,
- (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
- }
- }
- return sslExDataIndex_;
+ static auto index = SSL_get_ex_new_index(
+ 0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
+ return index;
}
AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
getSSLExDataIndex()));
}
-void AsyncSSLSocket::failHandshake(const char* fn,
- const AsyncSocketException& ex) {
+void AsyncSSLSocket::failHandshake(const char* /* fn */,
+ const AsyncSocketException& ex) {
startFail();
-
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
+ invokeHandshakeErr(ex);
+ finishFail();
+}
+
+void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeCallback_ != nullptr) {
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
-
- finishFail();
}
void AsyncSSLSocket::invokeHandshakeCB() {
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
}
}
-void AsyncSSLSocket::connect(ConnectCallback* callback,
- const folly::SocketAddress& address,
- int timeout,
- const OptionMap &options,
- const folly::SocketAddress& bindAddr)
- noexcept {
+void AsyncSSLSocket::connect(
+ ConnectCallback* callback,
+ const folly::SocketAddress& address,
+ int timeout,
+ const OptionMap& options,
+ const folly::SocketAddress& bindAddr) noexcept {
+ auto timeoutChrono = std::chrono::milliseconds(timeout);
+ connect(callback, address, timeoutChrono, timeoutChrono, options, bindAddr);
+}
+
+void AsyncSSLSocket::connect(
+ ConnectCallback* callback,
+ const folly::SocketAddress& address,
+ std::chrono::milliseconds connectTimeout,
+ std::chrono::milliseconds totalConnectTimeout,
+ const OptionMap& options,
+ const folly::SocketAddress& bindAddr) noexcept {
assert(!server_);
assert(state_ == StateEnum::UNINIT);
- assert(sslState_ == STATE_UNINIT);
- AsyncSSLSocketConnector *connector =
- new AsyncSSLSocketConnector(this, callback, timeout);
- AsyncSocket::connect(connector, address, timeout, options, bindAddr);
+ assert(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
+ noTransparentTls_ = true;
+ totalConnectTimeout_ = totalConnectTimeout;
+ if (sslState_ != STATE_UNENCRYPTED) {
+ callback = new AsyncSSLSocketConnector(
+ this, callback, int(totalConnectTimeout.count()));
+ }
+ AsyncSocket::connect(
+ callback, address, int(connectTimeout.count()), options, bindAddr);
+}
+
+bool AsyncSSLSocket::needsPeerVerification() const {
+ if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
+ return ctx_->needsPeerVerification();
+ }
+ return (
+ verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
+ verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
}
void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
}
}
-void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
- const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
+bool AsyncSSLSocket::setupSSLBio() {
+ auto sslBio = BIO_new(getSSLBioMethod());
+
+ if (!sslBio) {
+ return false;
+ }
+
+ OpenSSLUtils::setBioAppData(sslBio, this);
+ OpenSSLUtils::setBioFd(sslBio, fd_, BIO_NOCLOSE);
+ SSL_set_bio(ssl_, sslBio, sslBio);
+ return true;
+}
+
+void AsyncSSLSocket::sslConn(
+ HandshakeCB* callback,
+ std::chrono::milliseconds timeout,
+ const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this);
- assert(eventBase_->isInEventBaseThread());
+ eventBase_->dcheckIsInEventBaseThread();
+
+ // Cache local and remote socket addresses to keep them available
+ // after socket file descriptor is closed.
+ if (cacheAddrOnFailure_) {
+ cacheAddresses();
+ }
verifyPeer_ = verifyPeer;
return failHandshake(__func__, ex);
}
+ if (!setupSSLBio()) {
+ sslState_ = STATE_ERROR;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
+ return failHandshake(__func__, ex);
+ }
+
applyVerificationOptions(ssl_);
- SSL_set_fd(ssl_, fd_);
if (sslSession_ != nullptr) {
+ sessionResumptionAttempted_ = true;
SSL_set_session(ssl_, sslSession_);
SSL_SESSION_free(sslSession_);
sslSession_ = nullptr;
}
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
if (tlsextHostname_.size()) {
SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
}
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
- if (timeout > 0) {
- handshakeTimeout_.scheduleTimeout(timeout);
- }
+ handshakeConnectTimeout_ = timeout;
+ startSSLConnect();
+}
+// This could be called multiple times, during normal ssl connections
+// and after TFO fallback.
+void AsyncSSLSocket::startSSLConnect() {
+ handshakeStartTime_ = std::chrono::steady_clock::now();
+ // Make end time at least >= start time.
+ handshakeEndTime_ = handshakeStartTime_;
+ if (handshakeConnectTimeout_ > std::chrono::milliseconds::zero()) {
+ handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
+ }
handleConnect();
}
return sslSession_;
}
+const SSL* AsyncSSLSocket::getSSL() const {
+ return ssl_;
+}
+
void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
sslSession_ = session;
if (!takeOwnership && session != nullptr) {
// Increment the reference count
- CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
+ // This API exists in BoringSSL and OpenSSL 1.1.0
+ SSL_SESSION_up_ref(session);
}
}
-void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
- unsigned* protoLen) const {
- if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
+void AsyncSSLSocket::getSelectedNextProtocol(
+ const unsigned char** protoName,
+ unsigned* protoLen,
+ SSLContext::NextProtocolType* protoType) const {
+ if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
"NPN not supported");
}
}
bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
- const unsigned char** protoName,
- unsigned* protoLen) const {
+ const unsigned char** protoName,
+ unsigned* protoLen,
+ SSLContext::NextProtocolType* protoType) const {
*protoName = nullptr;
*protoLen = 0;
+#if FOLLY_OPENSSL_HAS_ALPN
+ SSL_get0_alpn_selected(ssl_, protoName, protoLen);
+ if (*protoLen > 0) {
+ if (protoType) {
+ *protoType = SSLContext::NextProtocolType::ALPN;
+ }
+ return true;
+ }
+#endif
#ifdef OPENSSL_NPN_NEGOTIATED
SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
+ if (protoType) {
+ *protoType = SSLContext::NextProtocolType::NPN;
+ }
return true;
#else
+ (void)protoType;
return false;
#endif
}
return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
}
+/* static */
+const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
+ if (ssl == nullptr) {
+ return nullptr;
+ }
+#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
+ return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+#else
+ return nullptr;
+#endif
+}
+
const char *AsyncSSLSocket::getSSLServerName() const {
#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
- return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
- : nullptr;
+ return getSSLServerNameFromSSL(ssl_);
#else
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
- "SNI not supported");
+ "SNI not supported");
#endif
}
const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
- try {
- return getSSLServerName();
- } catch (AsyncSocketException& ex) {
- return nullptr;
- }
+ return getSSLServerNameFromSSL(ssl_);
}
int AsyncSSLSocket::getSSLVersion() const {
return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
}
+const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
+ X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
+ if (cert) {
+ int nid = X509_get_signature_nid(cert);
+ return OBJ_nid2ln(nid);
+ }
+ return nullptr;
+}
+
int AsyncSSLSocket::getSSLCertSize() const {
int certSize = 0;
X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
return certSize;
}
-bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
- int error = *errorOut = SSL_get_error(ssl_, ret);
+const X509* AsyncSSLSocket::getSelfCert() const {
+ return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
+}
+
+bool AsyncSSLSocket::willBlock(int ret,
+ int* sslErrorOut,
+ unsigned long* errErrorOut) noexcept {
+ *errErrorOut = 0;
+ int error = *sslErrorOut = SSL_get_error(ssl_, ret);
if (error == SSL_ERROR_WANT_READ) {
// Register for read event if not already.
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
// The timeout (if set) keeps running here
return true;
#endif
+ } else if ((false
#ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
- } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
+ || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
+#endif
+#ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
+ || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
+#endif
+ )) {
// Our custom openssl function has kicked off an async request to do
- // modular exponentiation. When that call returns, a callback will
+ // rsa/ecdsa private key operation. When that call returns, a callback will
// be invoked that will re-call handleAccept.
- sslState_ = STATE_RSA_ASYNC_PENDING;
+ sslState_ = STATE_ASYNC_PENDING;
// Unregister for all events while blocked here
updateEventRegistration(
// The timeout (if set) keeps running here
return true;
-#endif
} else {
- // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
- // in the log
- long lastError = ERR_get_error();
+ unsigned long lastError = *errErrorOut = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", "
<< "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
<< "func: " << ERR_func_error_string(lastError) << ", "
<< "reason: " << ERR_reason_error_string(lastError);
- if (error != SSL_ERROR_SYSCALL) {
- if (error == SSL_ERROR_SSL) {
- *errorOut = lastError;
- }
- if ((unsigned long)lastError < 0x8000) {
- errno = ENOSYS;
- } else {
- errno = lastError;
- }
- }
- ERR_clear_error();
return false;
}
}
// the socket to become readable again.
if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
AsyncSocket::handleRead();
+ } else {
+ AsyncSocket::checkForImmediateRead();
}
}
void
AsyncSSLSocket::restartSSLAccept()
{
- VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
- << ", state=" << int(state_) << ", "
+ VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
+ << ", fd=" << fd_ << ", state=" << int(state_) << ", "
<< "sslState=" << sslState_ << ", events=" << eventFlags_;
DestructorGuard dg(this);
assert(
sslState_ == STATE_CACHE_LOOKUP ||
- sslState_ == STATE_RSA_ASYNC_PENDING ||
+ sslState_ == STATE_ASYNC_PENDING ||
sslState_ == STATE_ERROR ||
- sslState_ == STATE_CLOSED
- );
+ sslState_ == STATE_CLOSED);
if (sslState_ == STATE_CLOSED) {
// I sure hope whoever closed this socket didn't delete it already,
// but this is not strictly speaking an error
}
if (sslState_ == STATE_ERROR) {
// go straight to fail if timeout expired during lookup
- AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
- "SSL accept timed out");
+ AsyncSocketException ex(
+ AsyncSocketException::TIMED_OUT, "SSL accept timed out");
failHandshake(__func__, ex);
return;
}
<< ", fd=" << fd_ << "): " << e.what();
return failHandshake(__func__, ex);
}
- SSL_set_fd(ssl_, fd_);
+
+ if (!setupSSLBio()) {
+ sslState_ = STATE_ERROR;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
+ return failHandshake(__func__, ex);
+ }
+
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
applyVerificationOptions(ssl_);
}
if (server_ && parseClientHello_) {
- SSL_set_msg_callback_arg(ssl_, this);
SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
+ SSL_set_msg_callback_arg(ssl_, this);
}
- errno = 0;
int ret = SSL_accept(ssl_);
if (ret <= 0) {
- int error;
- if (willBlock(ret, &error)) {
+ int sslError;
+ unsigned long errError;
+ int errnoCopy = errno;
+ if (willBlock(ret, &sslError, &errError)) {
return;
} else {
sslState_ = STATE_ERROR;
- SSLException ex(error, errno);
+ SSLException ex(sslError, errError, ret, errnoCopy);
return failHandshake(__func__, ex);
}
}
return AsyncSocket::handleConnect();
}
- assert(state_ == StateEnum::ESTABLISHED &&
- sslState_ == STATE_CONNECTING);
+ assert(
+ (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
+ sslState_ == STATE_CONNECTING);
assert(ssl_);
- errno = 0;
+ auto originalState = state_;
int ret = SSL_connect(ssl_);
if (ret <= 0) {
- int error;
- if (willBlock(ret, &error)) {
+ int sslError;
+ unsigned long errError;
+ int errnoCopy = errno;
+ if (willBlock(ret, &sslError, &errError)) {
+ // We fell back to connecting state due to TFO
+ if (state_ == StateEnum::CONNECTING) {
+ DCHECK_EQ(StateEnum::FAST_OPEN, originalState);
+ if (handshakeTimeout_.isScheduled()) {
+ handshakeTimeout_.cancelTimeout();
+ }
+ }
return;
} else {
sslState_ = STATE_ERROR;
- SSLException ex(error, errno);
+ SSLException ex(sslError, errError, ret, errnoCopy);
return failHandshake(__func__, ex);
}
}
// STATE_CONNECTING.
sslState_ = STATE_ESTABLISHED;
- VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
+ VLOG(3) << "AsyncSSLSocket " << this << ": "
+ << "fd " << fd_ << " successfully connected; "
<< "state=" << int(state_) << ", sslState=" << sslState_
<< ", events=" << eventFlags_;
AsyncSocket::handleInitialReadWrite();
}
+void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
+ connectionTimeout_.cancelTimeout();
+ AsyncSocket::invokeConnectErr(ex);
+ if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+ if (handshakeTimeout_.isScheduled()) {
+ handshakeTimeout_.cancelTimeout();
+ }
+ // If we fell back to connecting state during TFO and the connection
+ // failed, it would be an SSL failure as well.
+ invokeHandshakeErr(ex);
+ }
+}
+
+void AsyncSSLSocket::invokeConnectSuccess() {
+ connectionTimeout_.cancelTimeout();
+ if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+ assert(tfoAttempted_);
+ // If we failed TFO, we'd fall back to trying to connect the socket,
+ // to setup things like timeouts.
+ startSSLConnect();
+ }
+ // still invoke the base class since it re-sets the connect time.
+ AsyncSocket::invokeConnectSuccess();
+}
+
+void AsyncSSLSocket::scheduleConnectTimeout() {
+ if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+ // We fell back from TFO, and need to set the timeouts.
+ // We will not have a connect callback in this case, thus if the timer
+ // expires we would have no-one to notify.
+ // Thus we should reset even the connect timers to point to the handshake
+ // timeouts.
+ assert(connectCallback_ == nullptr);
+ // We use a different connect timeout here than the handshake timeout, so
+ // that we can disambiguate the 2 timers.
+ if (connectTimeout_.count() > 0) {
+ if (!connectionTimeout_.scheduleTimeout(connectTimeout_)) {
+ throw AsyncSocketException(
+ AsyncSocketException::INTERNAL_ERROR,
+ withAddr("failed to schedule AsyncSSLSocket connect timeout"));
+ }
+ }
+ return;
+ }
+ AsyncSocket::scheduleConnectTimeout();
+}
+
void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
// turn on the buffer movable in openssl
- if (!isBufferMovable_ && callback != nullptr && callback->isBufferMovable()) {
+ if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
+ callback != nullptr && callback->isBufferMovable()) {
SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
isBufferMovable_ = true;
}
AsyncSocket::setReadCB(callback);
}
-void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
+ bufferMovableEnabled_ = enabled;
+}
+
+void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) {
CHECK(readCallback_);
if (isBufferMovable_) {
*buf = nullptr;
AsyncSocket::handleRead();
}
-ssize_t
+AsyncSocket::ReadResult
AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
- VLOG(4) << "AsyncSSLSocket::performRead() this=" << this
- << ", buf=" << *buf << ", buflen=" << *buflen;
+ VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
+ << ", buflen=" << *buflen;
if (sslState_ == STATE_UNENCRYPTED) {
return AsyncSocket::performRead(buf, buflen, offset);
}
- errno = 0;
- ssize_t bytes = 0;
+ int bytes = 0;
if (!isBufferMovable_) {
- bytes = SSL_read(ssl_, *buf, *buflen);
+ bytes = SSL_read(ssl_, *buf, int(*buflen));
}
#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
else {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslstate=" << sslState_ << ", events=" << eventFlags_
<< "): client intitiated SSL renegotiation not permitted";
- // We pack our own SSLerr here with a dummy function
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
- SSL_CLIENT_RENEGOTIATION_ATTEMPT);
- ERR_clear_error();
- return READ_ERROR;
+ return ReadResult(
+ READ_ERROR,
+ std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
}
if (bytes <= 0) {
int error = SSL_get_error(ssl_, bytes);
if (error == SSL_ERROR_WANT_READ) {
// The caller will register for read event if not already.
- return READ_BLOCKING;
+ if (errno == EWOULDBLOCK || errno == EAGAIN) {
+ return ReadResult(READ_BLOCKING);
+ } else {
+ return ReadResult(READ_ERROR);
+ }
} else if (error == SSL_ERROR_WANT_WRITE) {
// TODO: Even though we are attempting to read data, SSL_read() may
// need to write data if renegotiation is being performed. We currently
// don't support this and just fail the read.
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_
- << "): unsupported SSL renegotiation during read",
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
- SSL_INVALID_RENEGOTIATION);
- ERR_clear_error();
- return READ_ERROR;
+ << "): unsupported SSL renegotiation during read";
+ return ReadResult(
+ READ_ERROR,
+ std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
} else {
- // TODO: Fix this code so that it can return a proper error message
- // to the callback, rather than relying on AsyncSocket code which
- // can't handle SSL errors.
- long lastError = ERR_get_error();
-
+ if (zero_return(error, bytes)) {
+ return ReadResult(bytes);
+ }
+ auto errError = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", "
<< "bytes: " << bytes << ", "
<< "error: " << error << ", "
<< "errno: " << errno << ", "
- << "func: " << ERR_func_error_string(lastError) << ", "
- << "reason: " << ERR_reason_error_string(lastError);
- ERR_clear_error();
- if (zero_return(error, bytes)) {
- return bytes;
- }
- if (error != SSL_ERROR_SYSCALL) {
- if ((unsigned long)lastError < 0x8000) {
- errno = ENOSYS;
- } else {
- errno = lastError;
- }
- }
- return READ_ERROR;
+ << "func: " << ERR_func_error_string(errError) << ", "
+ << "reason: " << ERR_reason_error_string(errError);
+ return ReadResult(
+ READ_ERROR,
+ std::make_unique<SSLException>(error, errError, bytes, errno));
}
} else {
appBytesReceived_ += bytes;
- return bytes;
+ return ReadResult(bytes);
}
}
AsyncSocket::handleWrite();
}
-int AsyncSSLSocket::interpretSSLError(int rc, int error) {
+AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
if (error == SSL_ERROR_WANT_READ) {
- // TODO: Even though we are attempting to write data, SSL_write() may
+ // Even though we are attempting to write data, SSL_write() may
// need to read data if renegotiation is being performed. We currently
// don't support this and just fail the write.
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_
- << "): " << "unsupported SSL renegotiation during write",
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
- SSL_INVALID_RENEGOTIATION);
- ERR_clear_error();
- return -1;
+ << "): "
+ << "unsupported SSL renegotiation during write";
+ return WriteResult(
+ WRITE_ERROR,
+ std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
} else {
- // TODO: Fix this code so that it can return a proper error message
- // to the callback, rather than relying on AsyncSocket code which
- // can't handle SSL errors.
- long lastError = ERR_get_error();
+ if (zero_return(error, rc)) {
+ return WriteResult(0);
+ }
+ auto errError = ERR_get_error();
VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
<< "SSL error: " << error << ", errno: " << errno
- << ", func: " << ERR_func_error_string(lastError)
- << ", reason: " << ERR_reason_error_string(lastError);
- if (error != SSL_ERROR_SYSCALL) {
- if ((unsigned long)lastError < 0x8000) {
- errno = ENOSYS;
- } else {
- errno = lastError;
- }
- }
- ERR_clear_error();
- if (!zero_return(error, rc)) {
- return -1;
- } else {
- return 0;
- }
+ << ", func: " << ERR_func_error_string(errError)
+ << ", reason: " << ERR_reason_error_string(errError);
+ return WriteResult(
+ WRITE_ERROR,
+ std::make_unique<SSLException>(error, errError, rc, errno));
}
}
-ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
- uint32_t count,
- WriteFlags flags,
- uint32_t* countWritten,
- uint32_t* partialWritten) {
+AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
+ const iovec* vec,
+ uint32_t count,
+ WriteFlags flags,
+ uint32_t* countWritten,
+ uint32_t* partialWritten) {
if (sslState_ == STATE_UNENCRYPTED) {
return AsyncSocket::performWrite(
vec, count, flags, countWritten, partialWritten);
<< ", events=" << eventFlags_ << "): "
<< "TODO: AsyncSSLSocket currently does not support calling "
<< "write() before the handshake has fully completed";
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
- SSL_EARLY_WRITE);
- return -1;
- }
-
- bool cork = isSet(flags, WriteFlags::CORK);
- CorkGuard guard(fd_, count > 1, cork, &corked_);
-
-#if 0
-//#ifdef SSL_MODE_WRITE_IOVEC
- if (ssl_->expand == nullptr &&
- ssl_->compress == nullptr &&
- (ssl_->mode & SSL_MODE_WRITE_IOVEC)) {
- return performWriteIovec(vec, count, flags, countWritten, partialWritten);
+ return WriteResult(
+ WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
}
-#endif
// Declare a buffer used to hold small write requests. It could point to a
// memory block either on stack or on heap. If it is on heap, we release it
buf = ((const char*)v->iov_base) + offset;
ssize_t bytes;
- errno = 0;
uint32_t buffersStolen = 0;
+ auto sslWriteBuf = buf;
if ((len < minWriteSize_) && ((i + 1) < count)) {
// Combine this buffer with part or all of the next buffers in
// order to avoid really small-grained calls to SSL_write().
}
}
assert(combinedBuf != nullptr);
+ sslWriteBuf = combinedBuf;
memcpy(combinedBuf, buf, len);
do {
uint32_t nextIndex = i + buffersStolen + 1;
bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
minWriteSize_ - len);
- memcpy(combinedBuf + len, vec[nextIndex].iov_base,
- bytesStolenFromNextBuffer);
+ if (bytesStolenFromNextBuffer > 0) {
+ assert(vec[nextIndex].iov_base != nullptr);
+ ::memcpy(
+ combinedBuf + len,
+ vec[nextIndex].iov_base,
+ bytesStolenFromNextBuffer);
+ }
len += bytesStolenFromNextBuffer;
if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
// couldn't steal the whole buffer
buffersStolen++;
}
} while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
- bytes = eorAwareSSLWrite(
- ssl_, combinedBuf, len,
- (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
+ }
- } else {
- bytes = eorAwareSSLWrite(ssl_, buf, len,
- (isSet(flags, WriteFlags::EOR) && i + 1 == count));
+ // Advance any empty buffers immediately after.
+ if (bytesStolenFromNextBuffer == 0) {
+ while ((i + buffersStolen + 1) < count &&
+ vec[i + buffersStolen + 1].iov_len == 0) {
+ buffersStolen++;
+ }
}
+ corkCurrentWrite_ =
+ isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count);
+ bytes = eorAwareSSLWrite(
+ ssl_,
+ sslWriteBuf,
+ int(len),
+ (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
+
if (bytes <= 0) {
- int error = SSL_get_error(ssl_, bytes);
+ int error = SSL_get_error(ssl_, int(bytes));
if (error == SSL_ERROR_WANT_WRITE) {
// The caller will register for write event if not already.
- *partialWritten = offset;
- return totalWritten;
+ *partialWritten = uint32_t(offset);
+ return WriteResult(totalWritten);
}
- int rc = interpretSSLError(bytes, error);
- if (rc < 0) {
- return rc;
+ auto writeResult = interpretSSLError(int(bytes), error);
+ if (writeResult.writeReturn < 0) {
+ return writeResult;
} // else fall through to below to correctly record totalWritten
}
(*countWritten)++;
v = &(vec[++i]);
}
- *partialWritten = bytes;
- return totalWritten;
+ *partialWritten = uint32_t(bytes);
+ return WriteResult(totalWritten);
}
}
- return totalWritten;
-}
-
-#if 0
-//#ifdef SSL_MODE_WRITE_IOVEC
-ssize_t AsyncSSLSocket::performWriteIovec(const iovec* vec,
- uint32_t count,
- WriteFlags flags,
- uint32_t* countWritten,
- uint32_t* partialWritten) {
- size_t tot = 0;
- for (uint32_t j = 0; j < count; j++) {
- tot += vec[j].iov_len;
- }
-
- ssize_t totalWritten = SSL_write_iovec(ssl_, vec, count);
-
- *countWritten = 0;
- *partialWritten = 0;
- if (totalWritten <= 0) {
- return interpretSSLError(totalWritten, SSL_get_error(ssl_, totalWritten));
- } else {
- ssize_t bytes = totalWritten, i = 0;
- while (i < count && bytes >= (ssize_t)vec[i].iov_len) {
- // we managed to write all of this buf
- bytes -= vec[i].iov_len;
- (*countWritten)++;
- i++;
- }
- *partialWritten = bytes;
-
- VLOG(4) << "SSL_write_iovec() writes " << tot
- << ", returns " << totalWritten << " bytes"
- << ", max_send_fragment=" << ssl_->max_send_fragment
- << ", count=" << count << ", countWritten=" << *countWritten;
-
- return totalWritten;
- }
+ return WriteResult(totalWritten);
}
-#endif
int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
bool eor) {
- if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
+ if (eor && isEorTrackingEnabled()) {
if (appEorByteNo_) {
// cannot track for more than one app byte EOR
CHECK(appEorByteNo_ == appBytesWritten_ + n);
return n;
}
-void
-AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
+void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
sslSocket->renegotiateAttempted_ = true;
}
+ if (where & SSL_CB_READ_ALERT) {
+ const char* type = SSL_alert_type_string(ret);
+ if (type) {
+ const char* desc = SSL_alert_desc_string(ret);
+ sslSocket->alertsReceived_.emplace_back(
+ *type, StringPiece(desc, std::strlen(desc)));
+ }
+ }
}
-int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
- int ret;
+int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
struct msghdr msg;
struct iovec iov;
- int flags = 0;
- AsyncSSLSocket *tsslSock;
+ AsyncSSLSocket* tsslSock;
- iov.iov_base = const_cast<char *>(in);
- iov.iov_len = inl;
+ iov.iov_base = const_cast<char*>(in);
+ iov.iov_len = size_t(inl);
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
- tsslSock =
- reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
- if (tsslSock &&
- tsslSock->minEorRawByteNo_ &&
+ auto appData = OpenSSLUtils::getBioAppData(b);
+ CHECK(appData);
+
+ tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+ CHECK(tsslSock);
+
+ WriteFlags flags = WriteFlags::NONE;
+ if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
- flags = MSG_EOR;
+ flags |= WriteFlags::EOR;
}
- errno = 0;
- ret = sendmsg(b->num, &msg, flags);
+ if (tsslSock->corkCurrentWrite_) {
+ flags |= WriteFlags::CORK;
+ }
+
+ int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(
+ flags, false /*zeroCopyEnabled*/);
+ msg.msg_controllen =
+ tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
+ CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
+ msg.msg_controllen);
+ if (msg.msg_controllen != 0) {
+ msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
+ tsslSock->getSendMsgParamsCB()->getAncillaryData(flags, msg.msg_control);
+ }
+
+ auto result = tsslSock->sendSocketMessage(
+ OpenSSLUtils::getBioFd(b, nullptr), &msg, msg_flags);
BIO_clear_retry_flags(b);
- if (ret <= 0) {
- if (BIO_sock_should_retry(ret))
+ if (!result.exception && result.writeReturn <= 0) {
+ if (OpenSSLUtils::getBioShouldRetryWrite(int(result.writeReturn))) {
BIO_set_retry_write(b);
+ }
}
- return(ret);
+ return int(result.writeReturn);
}
-int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
- X509_STORE_CTX* x509Ctx) {
+int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
+ if (!out) {
+ return 0;
+ }
+ BIO_clear_retry_flags(b);
+
+ auto appData = OpenSSLUtils::getBioAppData(b);
+ CHECK(appData);
+ auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+
+ if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
+ VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
+ << ", reading pre-received data";
+
+ Cursor cursor(sslSock->preReceivedData_.get());
+ auto len = cursor.pullAtMost(out, outl);
+
+ IOBufQueue queue;
+ queue.append(std::move(sslSock->preReceivedData_));
+ queue.trimStart(len);
+ sslSock->preReceivedData_ = queue.move();
+ return static_cast<int>(len);
+ } else {
+ auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
+ if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
+ BIO_set_retry_read(b);
+ }
+ return result;
+ }
+}
+
+int AsyncSSLSocket::sslVerifyCallback(
+ int preverifyOk,
+ X509_STORE_CTX* x509Ctx) {
SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true;
- clientHelloInfo_.reset(new ClientHelloInfo());
+ clientHelloInfo_ = std::make_unique<ssl::ClientHelloInfo>();
}
void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl) {
clientHelloInfo_->clientHelloBuf_.clear();
}
-void
-AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
- int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
-{
+void AsyncSSLSocket::clientHelloParsingCallback(int written,
+ int /* version */,
+ int contentType,
+ const void* buf,
+ size_t len,
+ SSL* ssl,
+ void* arg) {
AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
if (written != 0) {
sock->resetClientHelloParsing(ssl);
return;
}
if (contentType != SSL3_RT_HANDSHAKE) {
- sock->resetClientHelloParsing(ssl);
return;
}
if (len == 0) {
if (cursor.totalLength() > 0) {
uint16_t extensionsLength = cursor.readBE<uint16_t>();
while (extensionsLength) {
+ ssl::TLSExtension extensionType =
+ static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
sock->clientHelloInfo_->
- clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
+ clientHelloExtensions_.push_back(extensionType);
extensionsLength -= 2;
uint16_t extensionDataLength = cursor.readBE<uint16_t>();
extensionsLength -= 2;
- cursor.skip(extensionDataLength);
extensionsLength -= extensionDataLength;
+
+ if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
+ cursor.skip(2);
+ extensionDataLength -= 2;
+ while (extensionDataLength) {
+ ssl::HashAlgorithm hashAlg =
+ static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
+ ssl::SignatureAlgorithm sigAlg =
+ static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
+ extensionDataLength -= 2;
+ sock->clientHelloInfo_->
+ clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
+ }
+ } else if (extensionType == ssl::TLSExtension::SUPPORTED_VERSIONS) {
+ cursor.skip(1);
+ extensionDataLength -= 1;
+ while (extensionDataLength) {
+ sock->clientHelloInfo_->clientHelloSupportedVersions_.push_back(
+ cursor.readBE<uint16_t>());
+ extensionDataLength -= 2;
+ }
+ } else {
+ cursor.skip(extensionDataLength);
+ }
}
}
- } catch (std::out_of_range& e) {
+ } catch (std::out_of_range&) {
// we'll use what we found and cleanup below.
VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
<< "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
sock->resetClientHelloParsing(ssl);
}
-} // namespace
+void AsyncSSLSocket::getSSLClientCiphers(
+ std::string& clientCiphers,
+ bool convertToString) const {
+ std::string ciphers;
+
+ if (parseClientHello_ == false
+ || clientHelloInfo_->clientHelloCipherSuites_.empty()) {
+ clientCiphers = "";
+ return;
+ }
+
+ bool first = true;
+ for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_)
+ {
+ if (first) {
+ first = false;
+ } else {
+ ciphers += ":";
+ }
+
+ bool nameFound = convertToString;
+
+ if (convertToString) {
+ const auto& name = OpenSSLUtils::getCipherName(originalCipherCode);
+ if (name.empty()) {
+ nameFound = false;
+ } else {
+ ciphers += name;
+ }
+ }
+
+ if (!nameFound) {
+ folly::hexlify(
+ std::array<uint8_t, 2>{{
+ static_cast<uint8_t>((originalCipherCode >> 8) & 0xffL),
+ static_cast<uint8_t>(originalCipherCode & 0x00ffL) }},
+ ciphers,
+ /* append to ciphers = */ true);
+ }
+ }
+
+ clientCiphers = std::move(ciphers);
+}
+
+std::string AsyncSSLSocket::getSSLClientComprMethods() const {
+ if (!parseClientHello_) {
+ return "";
+ }
+ return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_);
+}
+
+std::string AsyncSSLSocket::getSSLClientExts() const {
+ if (!parseClientHello_) {
+ return "";
+ }
+ return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
+}
+
+std::string AsyncSSLSocket::getSSLClientSigAlgs() const {
+ if (!parseClientHello_) {
+ return "";
+ }
+
+ std::string sigAlgs;
+ sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4);
+ for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) {
+ if (i) {
+ sigAlgs.push_back(':');
+ }
+ sigAlgs.append(folly::to<std::string>(
+ clientHelloInfo_->clientHelloSigAlgs_[i].first));
+ sigAlgs.push_back(',');
+ sigAlgs.append(folly::to<std::string>(
+ clientHelloInfo_->clientHelloSigAlgs_[i].second));
+ }
+
+ return sigAlgs;
+}
+
+std::string AsyncSSLSocket::getSSLClientSupportedVersions() const {
+ if (!parseClientHello_) {
+ return "";
+ }
+ return folly::join(":", clientHelloInfo_->clientHelloSupportedVersions_);
+}
+
+std::string AsyncSSLSocket::getSSLAlertsReceived() const {
+ std::string ret;
+
+ for (const auto& alert : alertsReceived_) {
+ if (!ret.empty()) {
+ ret.append(",");
+ }
+ ret.append(folly::to<std::string>(alert.first, ": ", alert.second));
+ }
+
+ return ret;
+}
+
+void AsyncSSLSocket::setSSLCertVerificationAlert(std::string alert) {
+ sslVerificationAlert_ = std::move(alert);
+}
+
+std::string AsyncSSLSocket::getSSLCertVerificationAlert() const {
+ return sslVerificationAlert_;
+}
+
+void AsyncSSLSocket::getSSLSharedCiphers(std::string& sharedCiphers) const {
+ char ciphersBuffer[1024];
+ ciphersBuffer[0] = '\0';
+ SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1);
+ sharedCiphers = ciphersBuffer;
+}
+
+void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const {
+ serverCiphers = SSL_get_cipher_list(ssl_, 0);
+ int i = 1;
+ const char *cipher;
+ while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) {
+ serverCiphers.append(":");
+ serverCiphers.append(cipher);
+ i++;
+ }
+}
+
+} // namespace folly