2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
37 #include <sys/types.h>
38 #include <condition_variable>
// The destructors of all callback classes assert that the state is
// STATE_SUCCEEDED, for both positive and negative tests. The tests
// are responsible for setting the succeeded state properly before the
// destructors are called.
49 class WriteCallbackBase :
50 public AsyncTransportWrapper::WriteCallback {
53 : state(STATE_WAITING)
55 , exception(AsyncSocketException::UNKNOWN, "none") {}
57 ~WriteCallbackBase() {
58 EXPECT_EQ(STATE_SUCCEEDED, state);
62 const std::shared_ptr<AsyncSSLSocket> &socket) {
66 void writeSuccess() noexcept override {
67 std::cerr << "writeSuccess" << std::endl;
68 state = STATE_SUCCEEDED;
73 const AsyncSocketException& ex) noexcept override {
74 std::cerr << "writeError: bytesWritten " << nBytesWritten
75 << ", exception " << ex.what() << std::endl;
78 this->bytesWritten = nBytesWritten;
83 std::shared_ptr<AsyncSSLSocket> socket_;
86 AsyncSocketException exception;
89 class ReadCallbackBase :
90 public AsyncTransportWrapper::ReadCallback {
92 explicit ReadCallbackBase(WriteCallbackBase* wcb)
93 : wcb_(wcb), state(STATE_WAITING) {}
96 EXPECT_EQ(STATE_SUCCEEDED, state);
100 const std::shared_ptr<AsyncSSLSocket> &socket) {
104 void setState(StateEnum s) {
112 const AsyncSocketException& ex) noexcept override {
113 std::cerr << "readError " << ex.what() << std::endl;
114 state = STATE_FAILED;
118 void readEOF() noexcept override {
119 std::cerr << "readEOF" << std::endl;
124 std::shared_ptr<AsyncSSLSocket> socket_;
125 WriteCallbackBase *wcb_;
129 class ReadCallback : public ReadCallbackBase {
131 explicit ReadCallback(WriteCallbackBase *wcb)
132 : ReadCallbackBase(wcb)
136 for (std::vector<Buffer>::iterator it = buffers.begin();
141 currentBuffer.free();
144 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
145 if (!currentBuffer.buffer) {
146 currentBuffer.allocate(4096);
148 *bufReturn = currentBuffer.buffer;
149 *lenReturn = currentBuffer.length;
152 void readDataAvailable(size_t len) noexcept override {
153 std::cerr << "readDataAvailable, len " << len << std::endl;
155 currentBuffer.length = len;
157 wcb_->setSocket(socket_);
159 // Write back the same data.
160 socket_->write(wcb_, currentBuffer.buffer, len);
162 buffers.push_back(currentBuffer);
163 currentBuffer.reset();
164 state = STATE_SUCCEEDED;
169 Buffer() : buffer(nullptr), length(0) {}
170 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
176 void allocate(size_t len) {
177 assert(buffer == nullptr);
178 this->buffer = static_cast<char*>(malloc(len));
190 std::vector<Buffer> buffers;
191 Buffer currentBuffer;
194 class ReadErrorCallback : public ReadCallbackBase {
196 explicit ReadErrorCallback(WriteCallbackBase *wcb)
197 : ReadCallbackBase(wcb) {}
199 // Return nullptr buffer to trigger readError()
200 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
201 *bufReturn = nullptr;
205 void readDataAvailable(size_t /* len */) noexcept override {
206 // This should never to called.
211 const AsyncSocketException& ex) noexcept override {
212 ReadCallbackBase::readErr(ex);
213 std::cerr << "ReadErrorCallback::readError" << std::endl;
214 setState(STATE_SUCCEEDED);
218 class ReadEOFCallback : public ReadCallbackBase {
220 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
222 // Return nullptr buffer to trigger readError()
223 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
224 *bufReturn = nullptr;
228 void readDataAvailable(size_t /* len */) noexcept override {
229 // This should never to called.
233 void readEOF() noexcept override {
234 ReadCallbackBase::readEOF();
235 setState(STATE_SUCCEEDED);
239 class WriteErrorCallback : public ReadCallback {
241 explicit WriteErrorCallback(WriteCallbackBase *wcb)
242 : ReadCallback(wcb) {}
244 void readDataAvailable(size_t len) noexcept override {
245 std::cerr << "readDataAvailable, len " << len << std::endl;
247 currentBuffer.length = len;
249 // close the socket before writing to trigger writeError().
250 ::close(socket_->getFd());
252 wcb_->setSocket(socket_);
254 // Write back the same data.
255 folly::test::msvcSuppressAbortOnInvalidParams([&] {
256 socket_->write(wcb_, currentBuffer.buffer, len);
259 if (wcb_->state == STATE_FAILED) {
260 setState(STATE_SUCCEEDED);
262 state = STATE_FAILED;
265 buffers.push_back(currentBuffer);
266 currentBuffer.reset();
269 void readErr(const AsyncSocketException& ex) noexcept override {
270 std::cerr << "readError " << ex.what() << std::endl;
271 // do nothing since this is expected
275 class EmptyReadCallback : public ReadCallback {
277 explicit EmptyReadCallback()
278 : ReadCallback(nullptr) {}
280 void readErr(const AsyncSocketException& ex) noexcept override {
281 std::cerr << "readError " << ex.what() << std::endl;
282 state = STATE_FAILED;
288 void readEOF() noexcept override {
289 std::cerr << "readEOF" << std::endl;
293 state = STATE_SUCCEEDED;
296 std::shared_ptr<AsyncSocket> tcpSocket_;
299 class HandshakeCallback :
300 public AsyncSSLSocket::HandshakeCB {
307 explicit HandshakeCallback(ReadCallbackBase *rcb,
308 ExpectType expect = EXPECT_SUCCESS):
309 state(STATE_WAITING),
314 const std::shared_ptr<AsyncSSLSocket> &socket) {
318 void setState(StateEnum s) {
323 // Functions inherited from AsyncSSLSocketHandshakeCallback
324 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
325 std::lock_guard<std::mutex> g(mutex_);
327 EXPECT_EQ(sock, socket_.get());
328 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
329 rcb_->setSocket(socket_);
330 sock->setReadCB(rcb_);
331 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
333 void handshakeErr(AsyncSSLSocket* /* sock */,
334 const AsyncSocketException& ex) noexcept override {
335 std::lock_guard<std::mutex> g(mutex_);
337 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
338 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
339 if (expect_ == EXPECT_ERROR) {
340 // rcb will never be invoked
341 rcb_->setState(STATE_SUCCEEDED);
343 errorString_ = ex.what();
346 void waitForHandshake() {
347 std::unique_lock<std::mutex> lock(mutex_);
348 cv_.wait(lock, [this] { return state != STATE_WAITING; });
351 ~HandshakeCallback() {
352 EXPECT_EQ(STATE_SUCCEEDED, state);
357 state = STATE_SUCCEEDED;
360 std::shared_ptr<AsyncSSLSocket> getSocket() {
365 std::shared_ptr<AsyncSSLSocket> socket_;
366 ReadCallbackBase *rcb_;
369 std::condition_variable cv_;
370 std::string errorString_;
373 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
377 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
378 uint32_t timeout = 0):
379 SSLServerAcceptCallbackBase(hcb),
382 virtual ~SSLServerAcceptCallback() {
384 // if we set a timeout, we expect failure
385 EXPECT_EQ(hcb_->state, STATE_FAILED);
386 hcb_->setState(STATE_SUCCEEDED);
390 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
392 const std::shared_ptr<folly::AsyncSSLSocket> &s)
394 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
395 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
397 hcb_->setSocket(sock);
398 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
399 EXPECT_EQ(sock->getSSLState(),
400 AsyncSSLSocket::STATE_ACCEPTING);
402 state = STATE_SUCCEEDED;
406 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
408 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
409 SSLServerAcceptCallback(hcb) {}
411 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
413 const std::shared_ptr<folly::AsyncSSLSocket> &s)
416 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
418 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
420 int fd = sock->getFd();
424 // The accepted connection should already have TCP_NODELAY set
426 socklen_t valueLength = sizeof(value);
427 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
433 // Unset the TCP_NODELAY option.
435 socklen_t valueLength = sizeof(value);
436 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
439 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
443 SSLServerAcceptCallback::connAccepted(sock);
447 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
449 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
450 uint32_t timeout = 0):
451 SSLServerAcceptCallback(hcb, timeout) {}
453 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
455 const std::shared_ptr<folly::AsyncSSLSocket> &s)
457 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
459 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
461 hcb_->setSocket(sock);
462 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
463 ASSERT_TRUE((sock->getSSLState() ==
464 AsyncSSLSocket::STATE_ACCEPTING) ||
465 (sock->getSSLState() ==
466 AsyncSSLSocket::STATE_CACHE_LOOKUP));
468 state = STATE_SUCCEEDED;
473 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
475 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
476 SSLServerAcceptCallbackBase(hcb) {}
478 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
480 const std::shared_ptr<folly::AsyncSSLSocket> &s)
482 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
484 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
486 // The first call to sslAccept() should succeed.
487 hcb_->setSocket(sock);
488 sock->sslAccept(hcb_);
489 EXPECT_EQ(sock->getSSLState(),
490 AsyncSSLSocket::STATE_ACCEPTING);
492 // The second call to sslAccept() should fail.
493 HandshakeCallback callback2(hcb_->rcb_);
494 callback2.setSocket(sock);
495 sock->sslAccept(&callback2);
496 EXPECT_EQ(sock->getSSLState(),
497 AsyncSSLSocket::STATE_ERROR);
499 // Both callbacks should be in the error state.
500 EXPECT_EQ(hcb_->state, STATE_FAILED);
501 EXPECT_EQ(callback2.state, STATE_FAILED);
503 state = STATE_SUCCEEDED;
504 hcb_->setState(STATE_SUCCEEDED);
505 callback2.setState(STATE_SUCCEEDED);
509 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
511 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
512 SSLServerAcceptCallbackBase(hcb) {}
514 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
516 const std::shared_ptr<folly::AsyncSSLSocket> &s)
518 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
520 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
522 hcb_->setSocket(sock);
523 sock->getEventBase()->tryRunAfterDelay([=] {
524 std::cerr << "Delayed SSL accept, client will have close by now"
526 // SSL accept will fail
529 AsyncSSLSocket::STATE_UNINIT);
530 hcb_->socket_->sslAccept(hcb_);
531 // This registers for an event
534 AsyncSSLSocket::STATE_ACCEPTING);
536 state = STATE_SUCCEEDED;
541 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
543 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
544 // We don't care if we get invoked or not.
545 // The client may time out and give up before connAccepted() is even
547 state = STATE_SUCCEEDED;
550 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
552 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
553 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
555 // Just wait a while before closing the socket, so the client
556 // will time out waiting for the handshake to complete.
557 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
561 class TestSSLAsyncCacheServer : public TestSSLServer {
563 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
564 int lookupDelay = 100) :
566 SSL_CTX *sslCtx = ctx_->getSSLCtx();
567 SSL_CTX_sess_set_get_cb(sslCtx,
568 TestSSLAsyncCacheServer::getSessionCallback);
569 SSL_CTX_set_session_cache_mode(
570 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
573 lookupDelay_ = lookupDelay;
576 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
577 uint32_t getAsyncLookups() const { return asyncLookups_; }
580 static uint32_t asyncCallbacks_;
581 static uint32_t asyncLookups_;
582 static uint32_t lookupDelay_;
584 static SSL_SESSION* getSessionCallback(SSL* ssl,
585 unsigned char* /* sess_id */,
591 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
592 if (!SSL_want_sess_cache_lookup(ssl)) {
593 // libssl.so mismatch
594 std::cerr << "no async support" << std::endl;
598 AsyncSSLSocket *sslSocket =
599 AsyncSSLSocket::getFromSSL(ssl);
600 assert(sslSocket != nullptr);
601 // Going to simulate an async cache by just running delaying the miss 100ms
602 if (asyncCallbacks_ % 2 == 0) {
603 // This socket is already blocked on lookup, return miss
604 std::cerr << "returning miss" << std::endl;
606 // fresh meat - block it
607 std::cerr << "async lookup" << std::endl;
608 sslSocket->getEventBase()->tryRunAfterDelay(
609 std::bind(&AsyncSSLSocket::restartSSLAccept,
610 sslSocket), lookupDelay_);
611 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
619 void getfds(int fds[2]);
622 std::shared_ptr<folly::SSLContext> clientCtx,
623 std::shared_ptr<folly::SSLContext> serverCtx);
626 EventBase* eventBase,
627 AsyncSSLSocket::UniquePtr* clientSock,
628 AsyncSSLSocket::UniquePtr* serverSock);
630 class BlockingWriteClient :
631 private AsyncSSLSocket::HandshakeCB,
632 private AsyncTransportWrapper::WriteCallback {
634 explicit BlockingWriteClient(
635 AsyncSSLSocket::UniquePtr socket)
636 : socket_(std::move(socket)),
640 buf_.reset(new uint8_t[bufLen_]);
641 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
646 iov_.reset(new struct iovec[iovCount_]);
647 for (uint32_t n = 0; n < iovCount_; ++n) {
648 iov_[n].iov_base = buf_.get() + n;
650 iov_[n].iov_len = n % bufLen_;
652 iov_[n].iov_len = bufLen_ - (n % bufLen_);
656 socket_->sslConn(this, std::chrono::milliseconds(100));
659 struct iovec* getIovec() const {
662 uint32_t getIovecCount() const {
667 void handshakeSuc(AsyncSSLSocket*) noexcept override {
668 socket_->writev(this, iov_.get(), iovCount_);
672 const AsyncSocketException& ex) noexcept override {
673 ADD_FAILURE() << "client handshake error: " << ex.what();
675 void writeSuccess() noexcept override {
680 const AsyncSocketException& ex) noexcept override {
681 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
685 AsyncSSLSocket::UniquePtr socket_;
688 std::unique_ptr<uint8_t[]> buf_;
689 std::unique_ptr<struct iovec[]> iov_;
692 class BlockingWriteServer :
693 private AsyncSSLSocket::HandshakeCB,
694 private AsyncTransportWrapper::ReadCallback {
696 explicit BlockingWriteServer(
697 AsyncSSLSocket::UniquePtr socket)
698 : socket_(std::move(socket)),
699 bufSize_(2500 * 2000),
701 buf_.reset(new uint8_t[bufSize_]);
702 socket_->sslAccept(this, std::chrono::milliseconds(100));
705 void checkBuffer(struct iovec* iov, uint32_t count) const {
707 for (uint32_t n = 0; n < count; ++n) {
708 size_t bytesLeft = bytesRead_ - idx;
709 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
710 std::min(iov[n].iov_len, bytesLeft));
712 FAIL() << "buffer mismatch at iovec " << n << "/" << count
716 if (iov[n].iov_len > bytesLeft) {
717 FAIL() << "server did not read enough data: "
718 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
719 << " in iovec " << n << "/" << count;
722 idx += iov[n].iov_len;
724 if (idx != bytesRead_) {
725 ADD_FAILURE() << "server read extra data: " << bytesRead_
726 << " bytes read; expected " << idx;
731 void handshakeSuc(AsyncSSLSocket*) noexcept override {
732 // Wait 10ms before reading, so the client's writes will initially block.
733 socket_->getEventBase()->tryRunAfterDelay(
734 [this] { socket_->setReadCB(this); }, 10);
738 const AsyncSocketException& ex) noexcept override {
739 ADD_FAILURE() << "server handshake error: " << ex.what();
741 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
742 *bufReturn = buf_.get() + bytesRead_;
743 *lenReturn = bufSize_ - bytesRead_;
745 void readDataAvailable(size_t len) noexcept override {
747 socket_->setReadCB(nullptr);
748 socket_->getEventBase()->tryRunAfterDelay(
749 [this] { socket_->setReadCB(this); }, 2);
751 void readEOF() noexcept override {
755 const AsyncSocketException& ex) noexcept override {
756 ADD_FAILURE() << "server read error: " << ex.what();
759 AsyncSSLSocket::UniquePtr socket_;
762 std::unique_ptr<uint8_t[]> buf_;
766 private AsyncSSLSocket::HandshakeCB,
767 private AsyncTransportWrapper::WriteCallback {
770 AsyncSSLSocket::UniquePtr socket)
771 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
772 socket_->sslConn(this);
775 const unsigned char* nextProto;
776 unsigned nextProtoLength;
777 SSLContext::NextProtocolType protocolType;
780 void handshakeSuc(AsyncSSLSocket*) noexcept override {
781 socket_->getSelectedNextProtocol(
782 &nextProto, &nextProtoLength, &protocolType);
786 const AsyncSocketException& ex) noexcept override {
787 ADD_FAILURE() << "client handshake error: " << ex.what();
789 void writeSuccess() noexcept override {
794 const AsyncSocketException& ex) noexcept override {
795 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
799 AsyncSSLSocket::UniquePtr socket_;
803 private AsyncSSLSocket::HandshakeCB,
804 private AsyncTransportWrapper::ReadCallback {
806 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
807 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
808 socket_->sslAccept(this);
811 const unsigned char* nextProto;
812 unsigned nextProtoLength;
813 SSLContext::NextProtocolType protocolType;
816 void handshakeSuc(AsyncSSLSocket*) noexcept override {
817 socket_->getSelectedNextProtocol(
818 &nextProto, &nextProtoLength, &protocolType);
822 const AsyncSocketException& ex) noexcept override {
823 ADD_FAILURE() << "server handshake error: " << ex.what();
825 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
828 void readDataAvailable(size_t /* len */) noexcept override {}
829 void readEOF() noexcept override {
833 const AsyncSocketException& ex) noexcept override {
834 ADD_FAILURE() << "server read error: " << ex.what();
837 AsyncSSLSocket::UniquePtr socket_;
840 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
841 public AsyncTransportWrapper::ReadCallback {
843 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
844 : socket_(std::move(socket)) {
845 socket_->sslAccept(this);
848 ~RenegotiatingServer() {
849 socket_->setReadCB(nullptr);
852 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
853 LOG(INFO) << "Renegotiating server handshake success";
854 socket_->setReadCB(this);
858 const AsyncSocketException& ex) noexcept override {
859 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
861 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
862 *lenReturn = sizeof(buf);
865 void readDataAvailable(size_t /* len */) noexcept override {}
866 void readEOF() noexcept override {}
867 void readErr(const AsyncSocketException& ex) noexcept override {
868 LOG(INFO) << "server got read error " << ex.what();
869 auto exPtr = dynamic_cast<const SSLException*>(&ex);
870 ASSERT_NE(nullptr, exPtr);
871 std::string exStr(ex.what());
872 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
873 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
874 renegotiationError_ = true;
877 AsyncSSLSocket::UniquePtr socket_;
878 unsigned char buf[128];
879 bool renegotiationError_{false};
882 #ifndef OPENSSL_NO_TLSEXT
884 private AsyncSSLSocket::HandshakeCB,
885 private AsyncTransportWrapper::WriteCallback {
888 AsyncSSLSocket::UniquePtr socket)
889 : serverNameMatch(false), socket_(std::move(socket)) {
890 socket_->sslConn(this);
893 bool serverNameMatch;
896 void handshakeSuc(AsyncSSLSocket*) noexcept override {
897 serverNameMatch = socket_->isServerNameMatch();
901 const AsyncSocketException& ex) noexcept override {
902 ADD_FAILURE() << "client handshake error: " << ex.what();
904 void writeSuccess() noexcept override {
909 const AsyncSocketException& ex) noexcept override {
910 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
914 AsyncSSLSocket::UniquePtr socket_;
918 private AsyncSSLSocket::HandshakeCB,
919 private AsyncTransportWrapper::ReadCallback {
922 AsyncSSLSocket::UniquePtr socket,
923 const std::shared_ptr<folly::SSLContext>& ctx,
924 const std::shared_ptr<folly::SSLContext>& sniCtx,
925 const std::string& expectedServerName)
926 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
927 expectedServerName_(expectedServerName) {
928 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
929 std::placeholders::_1));
930 socket_->sslAccept(this);
933 bool serverNameMatch;
936 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
939 const AsyncSocketException& ex) noexcept override {
940 ADD_FAILURE() << "server handshake error: " << ex.what();
942 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
945 void readDataAvailable(size_t /* len */) noexcept override {}
946 void readEOF() noexcept override {
950 const AsyncSocketException& ex) noexcept override {
951 ADD_FAILURE() << "server read error: " << ex.what();
954 folly::SSLContext::ServerNameCallbackResult
955 serverNameCallback(SSL *ssl) {
956 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
959 !strcasecmp(expectedServerName_.c_str(), sn)) {
960 AsyncSSLSocket *sslSocket =
961 AsyncSSLSocket::getFromSSL(ssl);
962 sslSocket->switchServerSSLContext(sniCtx_);
963 serverNameMatch = true;
964 return folly::SSLContext::SERVER_NAME_FOUND;
966 serverNameMatch = false;
967 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
971 AsyncSSLSocket::UniquePtr socket_;
972 std::shared_ptr<folly::SSLContext> sniCtx_;
973 std::string expectedServerName_;
977 class SSLClient : public AsyncSocket::ConnectCallback,
978 public AsyncTransportWrapper::WriteCallback,
979 public AsyncTransportWrapper::ReadCallback
982 EventBase *eventBase_;
983 std::shared_ptr<AsyncSSLSocket> sslSocket_;
984 SSL_SESSION *session_;
985 std::shared_ptr<folly::SSLContext> ctx_;
987 folly::SocketAddress address_;
995 uint32_t writeAfterConnectErrors_;
997 // These settings test that we eventually drain the
998 // socket, even if the maxReadsPerEvent_ is hit during
999 // a event loop iteration.
1000 static constexpr size_t kMaxReadsPerEvent = 2;
1001 // 2 event loop iterations
1002 static constexpr size_t kMaxReadBufferSz =
1003 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1006 SSLClient(EventBase *eventBase,
1007 const folly::SocketAddress& address,
1009 uint32_t timeout = 0)
1010 : eventBase_(eventBase),
1012 requests_(requests),
1019 writeAfterConnectErrors_(0) {
1020 ctx_.reset(new folly::SSLContext());
1021 ctx_->setOptions(SSL_OP_NO_TICKET);
1022 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1023 memset(buf_, 'a', sizeof(buf_));
1028 SSL_SESSION_free(session_);
1031 EXPECT_EQ(bytesRead_, sizeof(buf_));
1035 uint32_t getHit() const { return hit_; }
1037 uint32_t getMiss() const { return miss_; }
1039 uint32_t getErrors() const { return errors_; }
1041 uint32_t getWriteAfterConnectErrors() const {
1042 return writeAfterConnectErrors_;
1045 void connect(bool writeNow = false) {
1046 sslSocket_ = AsyncSSLSocket::newSocket(
1048 if (session_ != nullptr) {
1049 sslSocket_->setSSLSession(session_);
1052 sslSocket_->connect(this, address_, timeout_);
1053 if (sslSocket_ && writeNow) {
1054 // write some junk, used in an error test
1055 sslSocket_->write(this, buf_, sizeof(buf_));
1059 void connectSuccess() noexcept override {
1060 std::cerr << "client SSL socket connected" << std::endl;
1061 if (sslSocket_->getSSLSessionReused()) {
1065 if (session_ != nullptr) {
1066 SSL_SESSION_free(session_);
1068 session_ = sslSocket_->getSSLSession();
1072 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1073 sslSocket_->write(this, buf_, sizeof(buf_));
1074 sslSocket_->setReadCB(this);
1075 memset(readbuf_, 'b', sizeof(readbuf_));
1080 const AsyncSocketException& ex) noexcept override {
1081 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1086 void writeSuccess() noexcept override {
1087 std::cerr << "client write success" << std::endl;
1090 void writeErr(size_t /* bytesWritten */,
1091 const AsyncSocketException& ex) noexcept override {
1092 std::cerr << "client writeError: " << ex.what() << std::endl;
1094 writeAfterConnectErrors_++;
1098 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1099 *bufReturn = readbuf_ + bytesRead_;
1100 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1103 void readEOF() noexcept override {
1104 std::cerr << "client readEOF" << std::endl;
1108 const AsyncSocketException& ex) noexcept override {
1109 std::cerr << "client readError: " << ex.what() << std::endl;
1112 void readDataAvailable(size_t len) noexcept override {
1113 std::cerr << "client read data: " << len << std::endl;
1115 if (bytesRead_ == sizeof(buf_)) {
1116 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1117 sslSocket_->closeNow();
1119 if (requests_ != 0) {
1127 class SSLHandshakeBase :
1128 public AsyncSSLSocket::HandshakeCB,
1129 private AsyncTransportWrapper::WriteCallback {
1131 explicit SSLHandshakeBase(
1132 AsyncSSLSocket::UniquePtr socket,
1133 bool preverifyResult,
1134 bool verifyResult) :
1135 handshakeVerify_(false),
1136 handshakeSuccess_(false),
1137 handshakeError_(false),
1138 socket_(std::move(socket)),
1139 preverifyResult_(preverifyResult),
1140 verifyResult_(verifyResult) {
1143 AsyncSSLSocket::UniquePtr moveSocket() && {
1144 return std::move(socket_);
1147 bool handshakeVerify_;
1148 bool handshakeSuccess_;
1149 bool handshakeError_;
1150 std::chrono::nanoseconds handshakeTime;
1153 AsyncSSLSocket::UniquePtr socket_;
1154 bool preverifyResult_;
1157 // HandshakeCallback
1158 bool handshakeVer(AsyncSSLSocket* /* sock */,
1160 X509_STORE_CTX* /* ctx */) noexcept override {
1161 handshakeVerify_ = true;
1163 EXPECT_EQ(preverifyResult_, preverifyOk);
1164 return verifyResult_;
1167 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1168 LOG(INFO) << "Handshake success";
1169 handshakeSuccess_ = true;
1170 handshakeTime = socket_->getHandshakeTime();
1175 const AsyncSocketException& ex) noexcept override {
1176 LOG(INFO) << "Handshake error " << ex.what();
1177 handshakeError_ = true;
1178 handshakeTime = socket_->getHandshakeTime();
1182 void writeSuccess() noexcept override {
1187 size_t bytesWritten,
1188 const AsyncSocketException& ex) noexcept override {
1189 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1194 class SSLHandshakeClient : public SSLHandshakeBase {
1197 AsyncSSLSocket::UniquePtr socket,
1198 bool preverifyResult,
1199 bool verifyResult) :
1200 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1201 socket_->sslConn(this, std::chrono::milliseconds::zero());
1205 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1207 SSLHandshakeClientNoVerify(
1208 AsyncSSLSocket::UniquePtr socket,
1209 bool preverifyResult,
1210 bool verifyResult) :
1211 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1214 std::chrono::milliseconds::zero(),
1215 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1219 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1221 SSLHandshakeClientDoVerify(
1222 AsyncSSLSocket::UniquePtr socket,
1223 bool preverifyResult,
1224 bool verifyResult) :
1225 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1228 std::chrono::milliseconds::zero(),
1229 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1233 class SSLHandshakeServer : public SSLHandshakeBase {
1236 AsyncSSLSocket::UniquePtr socket,
1237 bool preverifyResult,
1239 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1240 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1244 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1246 SSLHandshakeServerParseClientHello(
1247 AsyncSSLSocket::UniquePtr socket,
1248 bool preverifyResult,
1250 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1251 socket_->enableClientHelloParsing();
1252 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1255 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1258 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1259 handshakeSuccess_ = true;
1260 sock->getSSLSharedCiphers(sharedCiphers_);
1261 sock->getSSLServerCiphers(serverCiphers_);
1262 sock->getSSLClientCiphers(clientCiphers_);
1263 chosenCipher_ = sock->getNegotiatedCipherName();
1268 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1270 SSLHandshakeServerNoVerify(
1271 AsyncSSLSocket::UniquePtr socket,
1272 bool preverifyResult,
1274 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1277 std::chrono::milliseconds::zero(),
1278 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1282 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1284 SSLHandshakeServerDoVerify(
1285 AsyncSSLSocket::UniquePtr socket,
1286 bool preverifyResult,
1288 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1291 std::chrono::milliseconds::zero(),
1292 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1296 class EventBaseAborter : public AsyncTimeout {
1298 EventBaseAborter(EventBase* eventBase,
1301 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1302 , eventBase_(eventBase) {
1303 scheduleTimeout(timeoutMS);
1306 void timeoutExpired() noexcept override {
1307 FAIL() << "test timed out";
1308 eventBase_->terminateLoopSoon();
1312 EventBase* eventBase_;