#include <signal.h>
#include <pthread.h>
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
#include <gtest/gtest.h>
#include <iostream>
, exception(AsyncSocketException::UNKNOWN, "none") {}
~WriteCallbackBase() {
- EXPECT_EQ(state, STATE_SUCCEEDED);
+ EXPECT_EQ(STATE_SUCCEEDED, state);
}
void setSocket(
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
-public:
- explicit ReadCallbackBase(WriteCallbackBase *wcb)
- : wcb_(wcb)
- , state(STATE_WAITING) {}
+ public:
+ explicit ReadCallbackBase(WriteCallbackBase* wcb)
+ : wcb_(wcb), state(STATE_WAITING) {}
~ReadCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED);
}
};
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+ explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+ // Return nullptr buffer to trigger readError()
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *bufReturn = nullptr;
+ *lenReturn = 0;
+ }
+
+ void readDataAvailable(size_t /* len */) noexcept override {
+ // This should never to called.
+ FAIL();
+ }
+
+ void readEOF() noexcept override {
+ ReadCallbackBase::readEOF();
+ setState(STATE_SUCCEEDED);
+ }
+};
+
class WriteErrorCallback : public ReadCallback {
public:
explicit WriteErrorCallback(WriteCallbackBase *wcb)
state = STATE_SUCCEEDED;
}
+ std::shared_ptr<AsyncSSLSocket> getSocket() {
+ return socket_;
+ }
+
StateEnum state;
std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_;
AsyncSSLSocket::UniquePtr socket_;
};
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+ public AsyncTransportWrapper::ReadCallback {
+ public:
+ explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+ : socket_(std::move(socket)) {
+ socket_->sslAccept(this);
+ }
+
+ ~RenegotiatingServer() {
+ socket_->setReadCB(nullptr);
+ }
+
+ void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+ LOG(INFO) << "Renegotiating server handshake success";
+ socket_->setReadCB(this);
+ }
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+ }
+ void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+ *lenReturn = sizeof(buf);
+ *bufReturn = buf;
+ }
+ void readDataAvailable(size_t /* len */) noexcept override {}
+ void readEOF() noexcept override {}
+ void readErr(const AsyncSocketException& ex) noexcept override {
+ LOG(INFO) << "server got read error " << ex.what();
+ auto exPtr = dynamic_cast<const SSLException*>(&ex);
+ ASSERT_NE(nullptr, exPtr);
+ std::string exStr(ex.what());
+ SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+ ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+ renegotiationError_ = true;
+ }
+
+ AsyncSSLSocket::UniquePtr socket_;
+ unsigned char buf[128];
+ bool renegotiationError_{false};
+};
+
#ifndef OPENSSL_NO_TLSEXT
class SNIClient :
private AsyncSSLSocket::HandshakeCB,
verifyResult_(verifyResult) {
}
+ AsyncSSLSocket::UniquePtr moveSocket() && {
+ return std::move(socket_);
+ }
+
bool handshakeVerify_;
bool handshakeSuccess_;
bool handshakeError_;
}
void handshakeSuc(AsyncSSLSocket*) noexcept override {
+ LOG(INFO) << "Handshake success";
handshakeSuccess_ = true;
handshakeTime = socket_->getHandshakeTime();
}
- void handshakeErr(AsyncSSLSocket*,
- const AsyncSocketException& /* ex */) noexcept override {
+ void handshakeErr(
+ AsyncSSLSocket*,
+ const AsyncSocketException& ex) noexcept override {
+ LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true;
handshakeTime = socket_->getHandshakeTime();
}