/*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
#include <signal.h>
#include <pthread.h>
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/portability/Sockets.h>
+#include <folly/portability/Unistd.h>
#include <gtest/gtest.h>
#include <iostream>
#include <list>
-#include <unistd.h>
#include <fcntl.h>
-#include <poll.h>
#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/tcp.h>
namespace folly {
, exception(AsyncSocketException::UNKNOWN, "none") {}
~WriteCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void setSocket(
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
-public:
- explicit ReadCallbackBase(WriteCallbackBase *wcb)
- : wcb_(wcb)
- , state(STATE_WAITING) {}
+ public:
+ explicit ReadCallbackBase(WriteCallbackBase* wcb)
+ : wcb_(wcb), state(STATE_WAITING) {}
~ReadCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void setSocket(
*lenReturn = 0;
}
- void readDataAvailable(size_t len) noexcept override {
+ void readDataAvailable(size_t /* len */) noexcept override {
// This should never to called.
FAIL();
}
}
};
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+ explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+ // Return nullptr buffer to trigger readError()
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = nullptr;
+ *lenReturn = 0;
+ }
+
+ void readDataAvailable(size_t /* len */) noexcept override {
+ // This should never to called.
+ FAIL();
+ }
+
+ void readEOF() noexcept override {
+ ReadCallbackBase::readEOF();
+ setState(STATE_SUCCEEDED);
+ }
+};
+
class WriteErrorCallback : public ReadCallback {
public:
explicit WriteErrorCallback(WriteCallbackBase *wcb)
// Functions inherited from AsyncSSLSocketHandshakeCallback
void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+ std::lock_guard<std::mutex> g(mutex_);
+ cv_.notify_all();
EXPECT_EQ(sock, socket_.get());
std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
rcb_->setSocket(socket_);
sock->setReadCB(rcb_);
state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
}
- void handshakeErr(
- AsyncSSLSocket *sock,
- const AsyncSocketException& ex) noexcept override {
+ void handshakeErr(AsyncSSLSocket* /* sock */,
+ const AsyncSocketException& ex) noexcept override {
+ std::lock_guard<std::mutex> g(mutex_);
+ cv_.notify_all();
std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
if (expect_ == EXPECT_ERROR) {
// rcb will never be invoked
rcb_->setState(STATE_SUCCEEDED);
}
+ errorString_ = ex.what();
+ }
+
+ void waitForHandshake() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ cv_.wait(lock, [this] { return state != STATE_WAITING; });
}
~HandshakeCallback() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void closeSocket() {
state = STATE_SUCCEEDED;
}
+ std::shared_ptr<AsyncSSLSocket> getSocket() {
+ return socket_;
+ }
+
StateEnum state;
std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_;
ExpectType expect_;
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ std::string errorString_;
};
class SSLServerAcceptCallbackBase:
state(STATE_WAITING), hcb_(hcb) {}
~SSLServerAcceptCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void acceptError(const std::exception& ex) noexcept override {
state = STATE_FAILED;
}
- void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
- noexcept override{
+ void connectionAccepted(
+ int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
printf("Connection accepted\n");
std::shared_ptr<AsyncSSLSocket> sslSock;
try {
}
};
+class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
+ public:
+ ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
+ // We don't care if we get invoked or not.
+ // The client may time out and give up before connAccepted() is even
+ // called.
+ state = STATE_SUCCEEDED;
+ }
+
+ // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
+ void connAccepted(
+ const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
+ std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
+
+ // Just wait a while before closing the socket, so the client
+ // will time out waiting for the handshake to complete.
+ s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
+ }
+};
class TestSSLServer {
protected:
EventBase evb_;
std::shared_ptr<folly::SSLContext> ctx_;
SSLServerAcceptCallbackBase *acb_;
- folly::AsyncServerSocket *socket_;
+ std::shared_ptr<folly::AsyncServerSocket> socket_;
folly::SocketAddress address_;
pthread_t thread_;
public:
// Create a TestSSLServer.
// This immediately starts listening on the given port.
- explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
+ explicit TestSSLServer(
+ SSLServerAcceptCallbackBase* acb,
+ bool enableTFO = false);
// Kill the thread.
~TestSSLServer() {
static uint32_t asyncLookups_;
static uint32_t lookupDelay_;
- static SSL_SESSION *getSessionCallback(SSL *ssl,
- unsigned char *sess_id,
- int id_len,
- int *copyflag) {
+ static SSL_SESSION* getSessionCallback(SSL* ssl,
+ unsigned char* /* sess_id */,
+ int /* id_len */,
+ int* copyflag) {
*copyflag = 0;
asyncCallbacks_++;
#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
const unsigned char* nextProto;
unsigned nextProtoLength;
+ SSLContext::NextProtocolType protocolType;
+
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
- socket_->getSelectedNextProtocol(&nextProto,
- &nextProtoLength);
+ socket_->getSelectedNextProtocol(
+ &nextProto, &nextProtoLength, &protocolType);
}
void handshakeErr(
AsyncSSLSocket*,
const unsigned char* nextProto;
unsigned nextProtoLength;
+ SSLContext::NextProtocolType protocolType;
+
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
- socket_->getSelectedNextProtocol(&nextProto,
- &nextProtoLength);
+ socket_->getSelectedNextProtocol(
+ &nextProto, &nextProtoLength, &protocolType);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server handshake error: " << ex.what();
}
- void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
*lenReturn = 0;
}
- void readDataAvailable(size_t len) noexcept override {
- }
+ void readDataAvailable(size_t /* len */) noexcept override {}
void readEOF() noexcept override {
socket_->close();
}
AsyncSSLSocket::UniquePtr socket_;
};
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+ public AsyncTransportWrapper::ReadCallback {
+ public:
+ explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+ : socket_(std::move(socket)) {
+ socket_->sslAccept(this);
+ }
+
+ ~RenegotiatingServer() {
+ socket_->setReadCB(nullptr);
+ }
+
+ void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+ LOG(INFO) << "Renegotiating server handshake success";
+ socket_->setReadCB(this);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *lenReturn = sizeof(buf);
+ *bufReturn = buf;
+ }
+ void readDataAvailable(size_t /* len */) noexcept override {}
+ void readEOF() noexcept override {}
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ LOG(INFO) << "server got read error " << ex.what();
+ auto exPtr = dynamic_cast<const SSLException*>(&ex);
+ ASSERT_NE(nullptr, exPtr);
+ std::string exStr(ex.what());
+ SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+ ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+ renegotiationError_ = true;
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ unsigned char buf[128];
+ bool renegotiationError_{false};
+};
+
#ifndef OPENSSL_NO_TLSEXT
class SNIClient :
private AsyncSSLSocket::HandshakeCB,
bool serverNameMatch;
private:
- void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
+ void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server handshake error: " << ex.what();
}
- void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
*lenReturn = 0;
}
- void readDataAvailable(size_t len) noexcept override {
- }
+ void readDataAvailable(size_t /* len */) noexcept override {}
void readEOF() noexcept override {
socket_->close();
}
std::cerr << "client write success" << std::endl;
}
- void writeErr(
- size_t bytesWritten,
- const AsyncSocketException& ex)
- noexcept override {
+ void writeErr(size_t /* bytesWritten */,
+ const AsyncSocketException& ex) noexcept override {
std::cerr << "client writeError: " << ex.what() << std::endl;
if (!sslSocket_) {
writeAfterConnectErrors_++;
verifyResult_(verifyResult) {
}
+ AsyncSSLSocket::UniquePtr moveSocket() && {
+ return std::move(socket_);
+ }
+
bool handshakeVerify_;
bool handshakeSuccess_;
bool handshakeError_;
+ std::chrono::nanoseconds handshakeTime;
protected:
AsyncSSLSocket::UniquePtr socket_;
bool verifyResult_;
// HandshakeCallback
- bool handshakeVer(
- AsyncSSLSocket* sock,
- bool preverifyOk,
- X509_STORE_CTX* ctx) noexcept override {
+ bool handshakeVer(AsyncSSLSocket* /* sock */,
+ bool preverifyOk,
+ X509_STORE_CTX* /* ctx */) noexcept override {
handshakeVerify_ = true;
EXPECT_EQ(preverifyResult_, preverifyOk);
}
void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ LOG(INFO) << "Handshake success";
handshakeSuccess_ = true;
+ handshakeTime = socket_->getHandshakeTime();
}
void handshakeErr(
- AsyncSSLSocket*,
- const AsyncSocketException& ex) noexcept override {
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true;
+ handshakeTime = socket_->getHandshakeTime();
}
// WriteCallback