/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
+#include <folly/portability/Sockets.h>
#include <boost/noncopyable.hpp>
#include <errno.h>
#include <fcntl.h>
-#include <netinet/in.h>
-#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/asn1.h>
#include <openssl/ssl.h>
#include <sys/types.h>
-#include <sys/socket.h>
-#include <unistd.h>
#include <chrono>
#include <folly/Bits.h>
#include <folly/SpinLock.h>
#include <folly/io/IOBuf.h>
#include <folly/io/Cursor.h>
+#include <folly/portability/Unistd.h>
using folly::SocketAddress;
using folly::SSLContext;
using folly::AsyncSocketException;
using folly::AsyncSSLSocket;
using folly::Optional;
+using folly::SSLContext;
// We have one single dummy SSL context so that we can implement attach
// and detach methods in a thread safe fashion without modifying opnessl.
static SSLContext *dummyCtx = nullptr;
static SpinLock dummyCtxLock;
-// Numbers chosen as to not collide with functions in ssl.h
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
-
// If given min write size is less than this, buffer will be allocated on
// stack, otherwise it is allocated on heap
const size_t MAX_STACK_BUF_SIZE = 2048;
int64_t startTime_;
protected:
- virtual ~AsyncSSLSocketConnector() {
- }
+ ~AsyncSSLSocketConnector() override {}
public:
AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
std::chrono::steady_clock::now().time_since_epoch()).count()) {
}
- virtual void connectSuccess() noexcept {
+ void connectSuccess() noexcept override {
VLOG(7) << "client socket connected";
int64_t timeoutLeft = 0;
sslSocket_->sslConn(this, timeoutLeft);
}
- virtual void connectErr(const AsyncSocketException& ex) noexcept {
+ void connectErr(const AsyncSocketException& ex) noexcept override {
LOG(ERROR) << "TCP connect failed: " << ex.what();
fail(ex);
delete this;
}
- virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
+ void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
VLOG(7) << "client handshake success";
if (callback_) {
callback_->connectSuccess();
delete this;
}
- virtual void handshakeErr(AsyncSSLSocket *socket,
- const AsyncSocketException& ex) noexcept {
+ void handshakeErr(AsyncSSLSocket* /* socket */,
+ const AsyncSocketException& ex) noexcept override {
LOG(ERROR) << "client handshakeErr: " << ex.what();
fail(ex);
delete this;
SSL_MODE_ENABLE_PARTIAL_WRITE
);
#endif
+// SSL_CTX_set_mode is a Macro
+#ifdef SSL_MODE_WRITE_IOVEC
+ SSL_CTX_set_mode(ctx,
+ SSL_CTX_get_mode(ctx)
+ | SSL_MODE_WRITE_IOVEC);
+#endif
+
}
-BIO_METHOD eorAwareBioMethod;
+BIO_METHOD sslWriteBioMethod;
-__attribute__((__constructor__))
-void initEorBioMethod(void) {
- memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
+void* initsslWriteBioMethod(void) {
+ memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
// override the bwrite method for MSG_EOR support
- eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
+ sslWriteBioMethod.bwrite = AsyncSSLSocket::bioWrite;
- // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
+ // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
// set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
- // then have specific handlings. The eorAwareBioWrite should be compatible
+ // then have specific handlings. The sslWriteBioWrite should be compatible
// with the one in openssl.
+
+ // Return something here to enable AsyncSSLSocket to call this method using
+ // a function-scoped static.
+ return nullptr;
}
} // anonymous namespace
namespace folly {
-SSLException::SSLException(int sslError, int errno_copy):
- AsyncSocketException(
- AsyncSocketException::SSL_ERROR,
- ERR_error_string(sslError, msg_),
- sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
-
/**
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
- EventBase* evb) :
+ EventBase* evb, bool deferSecurityNegotiation) :
AsyncSocket(evb),
ctx_(ctx),
handshakeTimeout_(this, evb) {
- setup_SSL_CTX(ctx_->getSSLCtx());
+ init();
+ if (deferSecurityNegotiation) {
+ sslState_ = STATE_UNENCRYPTED;
+ }
}
/**
* Create a server/client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
- EventBase* evb, int fd, bool server) :
+ EventBase* evb, int fd, bool server,
+ bool deferSecurityNegotiation) :
AsyncSocket(evb, fd),
server_(server),
ctx_(ctx),
handshakeTimeout_(this, evb) {
- setup_SSL_CTX(ctx_->getSSLCtx());
+ init();
if (server) {
SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
AsyncSSLSocket::sslInfoCallback);
}
+ if (deferSecurityNegotiation) {
+ sslState_ = STATE_UNENCRYPTED;
+ }
}
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb,
- const std::string& serverName) :
- AsyncSocket(evb),
- ctx_(ctx),
- handshakeTimeout_(this, evb),
- tlsextHostname_(serverName) {
- setup_SSL_CTX(ctx_->getSSLCtx());
+ const std::string& serverName,
+ bool deferSecurityNegotiation) :
+ AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
+ tlsextHostname_ = serverName;
}
/**
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd,
- const std::string& serverName) :
- AsyncSocket(evb, fd),
- ctx_(ctx),
- handshakeTimeout_(this, evb),
- tlsextHostname_(serverName) {
- setup_SSL_CTX(ctx_->getSSLCtx());
+ const std::string& serverName,
+ bool deferSecurityNegotiation) :
+ AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
+ tlsextHostname_ = serverName;
}
#endif
<< sslState_ << ", events=" << eventFlags_ << ")";
}
+void AsyncSSLSocket::init() {
+ // Do this here to ensure we initialize this once before any use of
+ // AsyncSSLSocket instances and not as part of library load.
+ static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
+ (void)sslWriteBioMethodInitializer;
+
+ setup_SSL_CTX(ctx_->getSSLCtx());
+}
+
void AsyncSSLSocket::closeNow() {
// Close the SSL connection.
if (ssl_ != nullptr && fd_ != -1) {
DestructorGuard dg(this);
- if (handshakeCallback_) {
- AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
- "SSL connection closed locally");
- HandshakeCB* callback = handshakeCallback_;
- handshakeCallback_ = nullptr;
- callback->handshakeErr(this, ex);
- }
+ invokeHandshakeErr(
+ AsyncSocketException(
+ AsyncSocketException::END_OF_FILE,
+ "SSL connection closed locally"));
if (ssl_ != nullptr) {
SSL_free(ssl_);
bool AsyncSSLSocket::good() const {
return (AsyncSocket::good() &&
(sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
- sslState_ == STATE_ESTABLISHED));
+ sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
}
// The TAsyncTransport definition of 'good' states that the transport is
sslState_ == STATE_CONNECTING))));
}
+std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
+ const unsigned char* protoName = nullptr;
+ unsigned protoLength;
+ if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
+ return std::string(reinterpret_cast<const char*>(protoName), protoLength);
+ }
+ return "";
+}
+
bool AsyncSSLSocket::isEorTrackingEnabled() const {
- const BIO *wb = SSL_get_wbio(ssl_);
- return wb && wb->method == &eorAwareBioMethod;
+ return trackEor_;
}
void AsyncSSLSocket::setEorTracking(bool track) {
- BIO *wb = SSL_get_wbio(ssl_);
- if (!wb) {
- throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
- "setting EOR tracking without an initialized "
- "BIO");
- }
-
- if (track) {
- if (wb->method != &eorAwareBioMethod) {
- // only do this if we didn't
- wb->method = &eorAwareBioMethod;
- BIO_set_app_data(wb, this);
- appEorByteNo_ = 0;
- minEorRawByteNo_ = 0;
- }
- } else if (wb->method == &eorAwareBioMethod) {
- wb->method = BIO_s_socket();
- BIO_set_app_data(wb, nullptr);
+ if (trackEor_ != track) {
+ trackEor_ = track;
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
- } else {
- CHECK(wb->method == BIO_s_socket());
}
}
AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
"sslAccept() called with socket in invalid state");
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (callback) {
callback->handshakeErr(this, ex);
}
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
- if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+ if (!server_ || (sslState_ != STATE_UNINIT &&
+ sslState_ != STATE_UNENCRYPTED) ||
+ handshakeCallback_ != nullptr) {
return invalidState(callback);
}
+ // Cache local and remote socket addresses to keep them available
+ // after socket file descriptor is closed.
+ if (cacheAddrOnFailure_ && -1 != getFd()) {
+ cacheLocalPeerAddr();
+ }
+
+ handshakeStartTime_ = std::chrono::steady_clock::now();
+ // Make end time at least >= start time.
+ handshakeEndTime_ = handshakeStartTime_;
+
sslState_ = STATE_ACCEPTING;
handshakeCallback_ = callback;
return false;
}
- return (ss->tlsext_hostname ? true : false);
+ if(!ss->tlsext_hostname) {
+ return false;
+ }
+ return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
}
void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
void AsyncSSLSocket::timeoutExpired() noexcept {
if (state_ == StateEnum::ESTABLISHED &&
(sslState_ == STATE_CACHE_LOOKUP ||
- sslState_ == STATE_RSA_ASYNC_PENDING)) {
+ sslState_ == STATE_ASYNC_PENDING)) {
sslState_ = STATE_ERROR;
// We are expecting a callback in restartSSLAccept. The cache lookup
// and rsa-call necessarily have pointers to this ssl socket, so delay
}
}
-int AsyncSSLSocket::sslExDataIndex_ = -1;
-std::mutex AsyncSSLSocket::mutex_;
-
int AsyncSSLSocket::getSSLExDataIndex() {
- if (sslExDataIndex_ < 0) {
- std::lock_guard<std::mutex> g(mutex_);
- if (sslExDataIndex_ < 0) {
- sslExDataIndex_ = SSL_get_ex_new_index(0,
- (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
- }
- }
- return sslExDataIndex_;
+ static auto index = SSL_get_ex_new_index(
+ 0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
+ return index;
}
AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
getSSLExDataIndex()));
}
-void AsyncSSLSocket::failHandshake(const char* fn,
- const AsyncSocketException& ex) {
+void AsyncSSLSocket::failHandshake(const char* /* fn */,
+ const AsyncSocketException& ex) {
startFail();
-
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
+ invokeHandshakeErr(ex);
+ finishFail();
+}
+
+void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeCallback_ != nullptr) {
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
-
- finishFail();
}
void AsyncSSLSocket::invokeHandshakeCB() {
+ handshakeEndTime_ = std::chrono::steady_clock::now();
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
}
}
+void AsyncSSLSocket::cacheLocalPeerAddr() {
+ SocketAddress address;
+ try {
+ getLocalAddress(&address);
+ getPeerAddress(&address);
+ } catch (const std::system_error& e) {
+ // The handle can be still valid while the connection is already closed.
+ if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
+ throw;
+ }
+ }
+}
+
void AsyncSSLSocket::connect(ConnectCallback* callback,
const folly::SocketAddress& address,
int timeout,
}
}
+bool AsyncSSLSocket::setupSSLBio() {
+ auto wb = BIO_new(&sslWriteBioMethod);
+
+ if (!wb) {
+ return false;
+ }
+
+ BIO_set_app_data(wb, this);
+ BIO_set_fd(wb, fd_, BIO_NOCLOSE);
+ SSL_set_bio(ssl_, wb, wb);
+ return true;
+}
+
void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this);
assert(eventBase_->isInEventBaseThread());
+ // Cache local and remote socket addresses to keep them available
+ // after socket file descriptor is closed.
+ if (cacheAddrOnFailure_ && -1 != getFd()) {
+ cacheLocalPeerAddr();
+ }
+
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
- if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
+ if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
+ STATE_UNENCRYPTED) ||
+ handshakeCallback_ != nullptr) {
return invalidState(callback);
}
+ handshakeStartTime_ = std::chrono::steady_clock::now();
+ // Make end time at least >= start time.
+ handshakeEndTime_ = handshakeStartTime_;
+
sslState_ = STATE_CONNECTING;
handshakeCallback_ = callback;
return failHandshake(__func__, ex);
}
+ if (!setupSSLBio()) {
+ sslState_ = STATE_ERROR;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
+ return failHandshake(__func__, ex);
+ }
+
applyVerificationOptions(ssl_);
- SSL_set_fd(ssl_, fd_);
if (sslSession_ != nullptr) {
SSL_set_session(ssl_, sslSession_);
SSL_SESSION_free(sslSession_);
return sslSession_;
}
+const SSL* AsyncSSLSocket::getSSL() const {
+ return ssl_;
+}
+
void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
sslSession_ = session;
if (!takeOwnership && session != nullptr) {
}
}
-void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
- unsigned* protoLen) const {
- if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
+void AsyncSSLSocket::getSelectedNextProtocol(
+ const unsigned char** protoName,
+ unsigned* protoLen,
+ SSLContext::NextProtocolType* protoType) const {
+ if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
"NPN not supported");
}
}
bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
- const unsigned char** protoName,
- unsigned* protoLen) const {
+ const unsigned char** protoName,
+ unsigned* protoLen,
+ SSLContext::NextProtocolType* protoType) const {
*protoName = nullptr;
*protoLen = 0;
+#if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
+ SSL_get0_alpn_selected(ssl_, protoName, protoLen);
+ if (*protoLen > 0) {
+ if (protoType) {
+ *protoType = SSLContext::NextProtocolType::ALPN;
+ }
+ return true;
+ }
+#endif
#ifdef OPENSSL_NPN_NEGOTIATED
SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
+ if (protoType) {
+ *protoType = SSLContext::NextProtocolType::NPN;
+ }
return true;
#else
+ (void)protoType;
return false;
#endif
}
return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
}
+/* static */
+const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
+ if (ssl == nullptr) {
+ return nullptr;
+ }
+#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
+ return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
+#else
+ return nullptr;
+#endif
+}
+
const char *AsyncSSLSocket::getSSLServerName() const {
#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
- return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
- : nullptr;
+ return getSSLServerNameFromSSL(ssl_);
#else
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
- "SNI not supported");
+ "SNI not supported");
#endif
}
const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
- try {
- return getSSLServerName();
- } catch (AsyncSocketException& ex) {
- return nullptr;
- }
+ return getSSLServerNameFromSSL(ssl_);
}
int AsyncSSLSocket::getSSLVersion() const {
return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
}
+const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
+ X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
+ if (cert) {
+ int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
+ return OBJ_nid2ln(nid);
+ }
+ return nullptr;
+}
+
int AsyncSSLSocket::getSSLCertSize() const {
int certSize = 0;
X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
return certSize;
}
-bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
- int error = *errorOut = SSL_get_error(ssl_, ret);
+bool AsyncSSLSocket::willBlock(int ret,
+ int* sslErrorOut,
+ unsigned long* errErrorOut) noexcept {
+ *errErrorOut = 0;
+ int error = *sslErrorOut = SSL_get_error(ssl_, ret);
if (error == SSL_ERROR_WANT_READ) {
// Register for read event if not already.
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
// The timeout (if set) keeps running here
return true;
#endif
+ } else if (0
#ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
- } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
+ || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
+#endif
+#ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
+ || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
+#endif
+ ) {
// Our custom openssl function has kicked off an async request to do
- // modular exponentiation. When that call returns, a callback will
+ // rsa/ecdsa private key operation. When that call returns, a callback will
// be invoked that will re-call handleAccept.
- sslState_ = STATE_RSA_ASYNC_PENDING;
+ sslState_ = STATE_ASYNC_PENDING;
// Unregister for all events while blocked here
updateEventRegistration(
// The timeout (if set) keeps running here
return true;
-#endif
} else {
- // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
- // in the log
- long lastError = ERR_get_error();
+ unsigned long lastError = *errErrorOut = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", "
<< "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
<< "func: " << ERR_func_error_string(lastError) << ", "
<< "reason: " << ERR_reason_error_string(lastError);
- if (error != SSL_ERROR_SYSCALL) {
- if (error == SSL_ERROR_SSL) {
- *errorOut = lastError;
- }
- if ((unsigned long)lastError < 0x8000) {
- errno = ENOSYS;
- } else {
- errno = lastError;
- }
- }
- ERR_clear_error();
return false;
}
}
DestructorGuard dg(this);
assert(
sslState_ == STATE_CACHE_LOOKUP ||
- sslState_ == STATE_RSA_ASYNC_PENDING ||
+ sslState_ == STATE_ASYNC_PENDING ||
sslState_ == STATE_ERROR ||
sslState_ == STATE_CLOSED
);
<< ", fd=" << fd_ << "): " << e.what();
return failHandshake(__func__, ex);
}
- SSL_set_fd(ssl_, fd_);
+
+ if (!setupSSLBio()) {
+ sslState_ = STATE_ERROR;
+ AsyncSocketException ex(
+ AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
+ return failHandshake(__func__, ex);
+ }
+
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
applyVerificationOptions(ssl_);
}
if (server_ && parseClientHello_) {
- SSL_set_msg_callback_arg(ssl_, this);
SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
+ SSL_set_msg_callback_arg(ssl_, this);
}
- errno = 0;
int ret = SSL_accept(ssl_);
if (ret <= 0) {
- int error;
- if (willBlock(ret, &error)) {
+ int sslError;
+ unsigned long errError;
+ int errnoCopy = errno;
+ if (willBlock(ret, &sslError, &errError)) {
return;
} else {
sslState_ = STATE_ERROR;
- SSLException ex(error, errno);
+ SSLException ex(sslError, errError, ret, errnoCopy);
return failHandshake(__func__, ex);
}
}
sslState_ == STATE_CONNECTING);
assert(ssl_);
- errno = 0;
int ret = SSL_connect(ssl_);
if (ret <= 0) {
- int error;
- if (willBlock(ret, &error)) {
+ int sslError;
+ unsigned long errError;
+ int errnoCopy = errno;
+ if (willBlock(ret, &sslError, &errError)) {
return;
} else {
sslState_ = STATE_ERROR;
- SSLException ex(error, errno);
+ SSLException ex(sslError, errError, ret, errnoCopy);
return failHandshake(__func__, ex);
}
}
// STATE_CONNECTING.
sslState_ = STATE_ESTABLISHED;
- VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
+ VLOG(3) << "AsyncSSLSocket " << this << ": "
+ << "fd " << fd_ << " successfully connected; "
<< "state=" << int(state_) << ", sslState=" << sslState_
<< ", events=" << eventFlags_;
AsyncSocket::handleInitialReadWrite();
}
+void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+ // turn on the buffer movable in openssl
+ if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
+ callback != nullptr && callback->isBufferMovable()) {
+ SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
+ isBufferMovable_ = true;
+ }
+#endif
+
+ AsyncSocket::setReadCB(callback);
+}
+
+void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
+ bufferMovableEnabled_ = enabled;
+}
+
+void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+ CHECK(readCallback_);
+ if (isBufferMovable_) {
+ *buf = nullptr;
+ *buflen = 0;
+ } else {
+ // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
+ readCallback_->getReadBuffer(buf, buflen);
+ }
+}
+
void
AsyncSSLSocket::handleRead() noexcept {
VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
AsyncSocket::handleRead();
}
-ssize_t
-AsyncSSLSocket::performRead(void* buf, size_t buflen) {
- errno = 0;
- ssize_t bytes = SSL_read(ssl_, buf, buflen);
+AsyncSocket::ReadResult
+AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
+ VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
+ << ", buflen=" << *buflen;
+
+ if (sslState_ == STATE_UNENCRYPTED) {
+ return AsyncSocket::performRead(buf, buflen, offset);
+ }
+
+ ssize_t bytes = 0;
+ if (!isBufferMovable_) {
+ bytes = SSL_read(ssl_, *buf, *buflen);
+ }
+#ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
+ else {
+ bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
+ }
+#endif
+
if (server_ && renegotiateAttempted_) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
- << ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): "
- << "client intitiated SSL renegotiation not permitted";
- // We pack our own SSLerr here with a dummy function
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
- SSL_CLIENT_RENEGOTIATION_ATTEMPT);
- ERR_clear_error();
- return READ_ERROR;
+ << ", sslstate=" << sslState_ << ", events=" << eventFlags_
+ << "): client intitiated SSL renegotiation not permitted";
+ return ReadResult(
+ READ_ERROR,
+ folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
}
if (bytes <= 0) {
int error = SSL_get_error(ssl_, bytes);
if (error == SSL_ERROR_WANT_READ) {
// The caller will register for read event if not already.
- return READ_BLOCKING;
+ if (errno == EWOULDBLOCK || errno == EAGAIN) {
+ return ReadResult(READ_BLOCKING);
+ } else {
+ return ReadResult(READ_ERROR);
+ }
} else if (error == SSL_ERROR_WANT_WRITE) {
// TODO: Even though we are attempting to read data, SSL_read() may
// need to write data if renegotiation is being performed. We currently
// don't support this and just fail the read.
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
- << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
- << "unsupported SSL renegotiation during read",
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
- SSL_INVALID_RENEGOTIATION);
- ERR_clear_error();
- return READ_ERROR;
+ << ", sslState=" << sslState_ << ", events=" << eventFlags_
+ << "): unsupported SSL renegotiation during read";
+ return ReadResult(
+ READ_ERROR,
+ folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
} else {
- // TODO: Fix this code so that it can return a proper error message
- // to the callback, rather than relying on AsyncSocket code which
- // can't handle SSL errors.
- long lastError = ERR_get_error();
-
+ if (zero_return(error, bytes)) {
+ return ReadResult(bytes);
+ }
+ long errError = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", "
<< "bytes: " << bytes << ", "
<< "error: " << error << ", "
<< "errno: " << errno << ", "
- << "func: " << ERR_func_error_string(lastError) << ", "
- << "reason: " << ERR_reason_error_string(lastError);
- ERR_clear_error();
- if (zero_return(error, bytes)) {
- return bytes;
- }
- if (error != SSL_ERROR_SYSCALL) {
- if ((unsigned long)lastError < 0x8000) {
- errno = ENOSYS;
- } else {
- errno = lastError;
- }
- }
- return READ_ERROR;
+ << "func: " << ERR_func_error_string(errError) << ", "
+ << "reason: " << ERR_reason_error_string(errError);
+ return ReadResult(
+ READ_ERROR,
+ folly::make_unique<SSLException>(error, errError, bytes, errno));
}
} else {
appBytesReceived_ += bytes;
- return bytes;
+ return ReadResult(bytes);
}
}
AsyncSocket::handleWrite();
}
-ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
- uint32_t count,
- WriteFlags flags,
- uint32_t* countWritten,
- uint32_t* partialWritten) {
+AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
+ if (error == SSL_ERROR_WANT_READ) {
+ // Even though we are attempting to write data, SSL_write() may
+ // need to read data if renegotiation is being performed. We currently
+ // don't support this and just fail the write.
+ LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
+ << ", sslState=" << sslState_ << ", events=" << eventFlags_
+ << "): "
+ << "unsupported SSL renegotiation during write";
+ return WriteResult(
+ WRITE_ERROR,
+ folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+ } else {
+ if (zero_return(error, rc)) {
+ return WriteResult(0);
+ }
+ auto errError = ERR_get_error();
+ VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
+ << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
+ << "SSL error: " << error << ", errno: " << errno
+ << ", func: " << ERR_func_error_string(errError)
+ << ", reason: " << ERR_reason_error_string(errError);
+ return WriteResult(
+ WRITE_ERROR,
+ folly::make_unique<SSLException>(error, errError, rc, errno));
+ }
+}
+
+AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
+ const iovec* vec,
+ uint32_t count,
+ WriteFlags flags,
+ uint32_t* countWritten,
+ uint32_t* partialWritten) {
+ if (sslState_ == STATE_UNENCRYPTED) {
+ return AsyncSocket::performWrite(
+ vec, count, flags, countWritten, partialWritten);
+ }
if (sslState_ != STATE_ESTABLISHED) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
- << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
+ << ", sslState=" << sslState_
+ << ", events=" << eventFlags_ << "): "
<< "TODO: AsyncSSLSocket currently does not support calling "
<< "write() before the handshake has fully completed";
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
- SSL_EARLY_WRITE);
- return -1;
+ return WriteResult(
+ WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
}
bool cork = isSet(flags, WriteFlags::CORK);
buf = ((const char*)v->iov_base) + offset;
ssize_t bytes;
- errno = 0;
uint32_t buffersStolen = 0;
if ((len < minWriteSize_) && ((i + 1) < count)) {
// Combine this buffer with part or all of the next buffers in
if (error == SSL_ERROR_WANT_WRITE) {
// The caller will register for write event if not already.
*partialWritten = offset;
- return totalWritten;
- } else if (error == SSL_ERROR_WANT_READ) {
- // TODO: Even though we are attempting to write data, SSL_write() may
- // need to read data if renegotiation is being performed. We currently
- // don't support this and just fail the write.
- LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
- << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
- << "unsupported SSL renegotiation during write",
- errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
- SSL_INVALID_RENEGOTIATION);
- ERR_clear_error();
- return -1;
- } else {
- // TODO: Fix this code so that it can return a proper error message
- // to the callback, rather than relying on AsyncSocket code which
- // can't handle SSL errors.
- long lastError = ERR_get_error();
- VLOG(3) <<
- "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
- << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
- << "SSL error: " << error << ", errno: " << errno
- << ", func: " << ERR_func_error_string(lastError)
- << ", reason: " << ERR_reason_error_string(lastError);
- if (error != SSL_ERROR_SYSCALL) {
- if ((unsigned long)lastError < 0x8000) {
- errno = ENOSYS;
- } else {
- errno = lastError;
- }
- }
- ERR_clear_error();
- if (!zero_return(error, bytes)) {
- return -1;
- } // else fall through to below to correctly record totalWritten
+ return WriteResult(totalWritten);
}
+ auto writeResult = interpretSSLError(bytes, error);
+ if (writeResult.writeReturn < 0) {
+ return writeResult;
+ } // else fall through to below to correctly record totalWritten
}
totalWritten += bytes;
v = &(vec[++i]);
}
*partialWritten = bytes;
- return totalWritten;
+ return WriteResult(totalWritten);
}
}
- return totalWritten;
+ return WriteResult(totalWritten);
}
int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
bool eor) {
- if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
+ if (eor && trackEor_) {
if (appEorByteNo_) {
// cannot track for more than one app byte EOR
CHECK(appEorByteNo_ == appBytesWritten_ + n);
return n;
}
-void
-AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
+void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
sslSocket->renegotiateAttempted_ = true;
}
+ if (where & SSL_CB_READ_ALERT) {
+ const char* type = SSL_alert_type_string(ret);
+ if (type) {
+ const char* desc = SSL_alert_desc_string(ret);
+ sslSocket->alertsReceived_.emplace_back(
+ *type, StringPiece(desc, std::strlen(desc)));
+ }
+ }
}
-int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
+int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
int ret;
struct msghdr msg;
struct iovec iov;
int flags = 0;
- AsyncSSLSocket *tsslSock;
+ AsyncSSLSocket* tsslSock;
- iov.iov_base = const_cast<char *>(in);
+ iov.iov_base = const_cast<char*>(in);
iov.iov_len = inl;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
- tsslSock =
- reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
- if (tsslSock &&
- tsslSock->minEorRawByteNo_ &&
+ auto appData = BIO_get_app_data(b);
+ CHECK(appData);
+
+ tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
+ CHECK(tsslSock);
+
+ if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags = MSG_EOR;
}
- errno = 0;
- ret = sendmsg(b->num, &msg, flags);
+ ret = sendmsg(BIO_get_fd(b, nullptr), &msg, flags);
BIO_clear_retry_flags(b);
if (ret <= 0) {
if (BIO_sock_should_retry(ret))
BIO_set_retry_write(b);
}
- return(ret);
+ return ret;
}
int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true;
- clientHelloInfo_.reset(new ClientHelloInfo());
+ clientHelloInfo_.reset(new ssl::ClientHelloInfo());
}
void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl) {
clientHelloInfo_->clientHelloBuf_.clear();
}
-void
-AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
- int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
-{
+void AsyncSSLSocket::clientHelloParsingCallback(int written,
+ int /* version */,
+ int contentType,
+ const void* buf,
+ size_t len,
+ SSL* ssl,
+ void* arg) {
AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
if (written != 0) {
sock->resetClientHelloParsing(ssl);
return;
}
if (contentType != SSL3_RT_HANDSHAKE) {
- sock->resetClientHelloParsing(ssl);
return;
}
if (len == 0) {
if (cursor.totalLength() > 0) {
uint16_t extensionsLength = cursor.readBE<uint16_t>();
while (extensionsLength) {
+ ssl::TLSExtension extensionType =
+ static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
sock->clientHelloInfo_->
- clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
+ clientHelloExtensions_.push_back(extensionType);
extensionsLength -= 2;
uint16_t extensionDataLength = cursor.readBE<uint16_t>();
extensionsLength -= 2;
- cursor.skip(extensionDataLength);
- extensionsLength -= extensionDataLength;
+
+ if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
+ cursor.skip(2);
+ extensionDataLength -= 2;
+ while (extensionDataLength) {
+ ssl::HashAlgorithm hashAlg =
+ static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
+ ssl::SignatureAlgorithm sigAlg =
+ static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
+ extensionDataLength -= 2;
+ sock->clientHelloInfo_->
+ clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
+ }
+ } else {
+ cursor.skip(extensionDataLength);
+ extensionsLength -= extensionDataLength;
+ }
}
}
} catch (std::out_of_range& e) {