2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
36 #include <sys/types.h>
37 #include <condition_variable>
49 // The destructors of all callback classes assert that the state is
50 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
51 // are responsible for setting the succeeded state properly before the
52 // destructors are called.
54 class WriteCallbackBase :
55 public AsyncTransportWrapper::WriteCallback {
58 : state(STATE_WAITING)
60 , exception(AsyncSocketException::UNKNOWN, "none") {}
62 ~WriteCallbackBase() {
63 EXPECT_EQ(STATE_SUCCEEDED, state);
67 const std::shared_ptr<AsyncSSLSocket> &socket) {
71 void writeSuccess() noexcept override {
72 std::cerr << "writeSuccess" << std::endl;
73 state = STATE_SUCCEEDED;
78 const AsyncSocketException& ex) noexcept override {
79 std::cerr << "writeError: bytesWritten " << nBytesWritten
80 << ", exception " << ex.what() << std::endl;
83 this->bytesWritten = nBytesWritten;
86 socket_->detachEventBase();
89 std::shared_ptr<AsyncSSLSocket> socket_;
92 AsyncSocketException exception;
95 class ReadCallbackBase :
96 public AsyncTransportWrapper::ReadCallback {
98 explicit ReadCallbackBase(WriteCallbackBase* wcb)
99 : wcb_(wcb), state(STATE_WAITING) {}
101 ~ReadCallbackBase() {
102 EXPECT_EQ(STATE_SUCCEEDED, state);
106 const std::shared_ptr<AsyncSSLSocket> &socket) {
110 void setState(StateEnum s) {
118 const AsyncSocketException& ex) noexcept override {
119 std::cerr << "readError " << ex.what() << std::endl;
120 state = STATE_FAILED;
122 socket_->detachEventBase();
125 void readEOF() noexcept override {
126 std::cerr << "readEOF" << std::endl;
129 socket_->detachEventBase();
132 std::shared_ptr<AsyncSSLSocket> socket_;
133 WriteCallbackBase *wcb_;
137 class ReadCallback : public ReadCallbackBase {
139 explicit ReadCallback(WriteCallbackBase *wcb)
140 : ReadCallbackBase(wcb)
144 for (std::vector<Buffer>::iterator it = buffers.begin();
149 currentBuffer.free();
152 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
153 if (!currentBuffer.buffer) {
154 currentBuffer.allocate(4096);
156 *bufReturn = currentBuffer.buffer;
157 *lenReturn = currentBuffer.length;
160 void readDataAvailable(size_t len) noexcept override {
161 std::cerr << "readDataAvailable, len " << len << std::endl;
163 currentBuffer.length = len;
165 wcb_->setSocket(socket_);
167 // Write back the same data.
168 socket_->write(wcb_, currentBuffer.buffer, len);
170 buffers.push_back(currentBuffer);
171 currentBuffer.reset();
172 state = STATE_SUCCEEDED;
177 Buffer() : buffer(nullptr), length(0) {}
178 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
184 void allocate(size_t len) {
185 assert(buffer == nullptr);
186 this->buffer = static_cast<char*>(malloc(len));
198 std::vector<Buffer> buffers;
199 Buffer currentBuffer;
202 class ReadErrorCallback : public ReadCallbackBase {
204 explicit ReadErrorCallback(WriteCallbackBase *wcb)
205 : ReadCallbackBase(wcb) {}
207 // Return nullptr buffer to trigger readError()
208 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
209 *bufReturn = nullptr;
213 void readDataAvailable(size_t /* len */) noexcept override {
214 // This should never to called.
219 const AsyncSocketException& ex) noexcept override {
220 ReadCallbackBase::readErr(ex);
221 std::cerr << "ReadErrorCallback::readError" << std::endl;
222 setState(STATE_SUCCEEDED);
226 class ReadEOFCallback : public ReadCallbackBase {
228 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
230 // Return nullptr buffer to trigger readError()
231 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
232 *bufReturn = nullptr;
236 void readDataAvailable(size_t /* len */) noexcept override {
237 // This should never to called.
241 void readEOF() noexcept override {
242 ReadCallbackBase::readEOF();
243 setState(STATE_SUCCEEDED);
247 class WriteErrorCallback : public ReadCallback {
249 explicit WriteErrorCallback(WriteCallbackBase *wcb)
250 : ReadCallback(wcb) {}
252 void readDataAvailable(size_t len) noexcept override {
253 std::cerr << "readDataAvailable, len " << len << std::endl;
255 currentBuffer.length = len;
257 // close the socket before writing to trigger writeError().
258 ::close(socket_->getFd());
260 wcb_->setSocket(socket_);
262 // Write back the same data.
263 folly::test::msvcSuppressAbortOnInvalidParams([&] {
264 socket_->write(wcb_, currentBuffer.buffer, len);
267 if (wcb_->state == STATE_FAILED) {
268 setState(STATE_SUCCEEDED);
270 state = STATE_FAILED;
273 buffers.push_back(currentBuffer);
274 currentBuffer.reset();
277 void readErr(const AsyncSocketException& ex) noexcept override {
278 std::cerr << "readError " << ex.what() << std::endl;
279 // do nothing since this is expected
283 class EmptyReadCallback : public ReadCallback {
285 explicit EmptyReadCallback()
286 : ReadCallback(nullptr) {}
288 void readErr(const AsyncSocketException& ex) noexcept override {
289 std::cerr << "readError " << ex.what() << std::endl;
290 state = STATE_FAILED;
292 tcpSocket_->detachEventBase();
295 void readEOF() noexcept override {
296 std::cerr << "readEOF" << std::endl;
299 tcpSocket_->detachEventBase();
300 state = STATE_SUCCEEDED;
303 std::shared_ptr<AsyncSocket> tcpSocket_;
306 class HandshakeCallback :
307 public AsyncSSLSocket::HandshakeCB {
314 explicit HandshakeCallback(ReadCallbackBase *rcb,
315 ExpectType expect = EXPECT_SUCCESS):
316 state(STATE_WAITING),
321 const std::shared_ptr<AsyncSSLSocket> &socket) {
325 void setState(StateEnum s) {
330 // Functions inherited from AsyncSSLSocketHandshakeCallback
331 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
332 std::lock_guard<std::mutex> g(mutex_);
334 EXPECT_EQ(sock, socket_.get());
335 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
336 rcb_->setSocket(socket_);
337 sock->setReadCB(rcb_);
338 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
340 void handshakeErr(AsyncSSLSocket* /* sock */,
341 const AsyncSocketException& ex) noexcept override {
342 std::lock_guard<std::mutex> g(mutex_);
344 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
345 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
346 if (expect_ == EXPECT_ERROR) {
347 // rcb will never be invoked
348 rcb_->setState(STATE_SUCCEEDED);
350 errorString_ = ex.what();
353 void waitForHandshake() {
354 std::unique_lock<std::mutex> lock(mutex_);
355 cv_.wait(lock, [this] { return state != STATE_WAITING; });
358 ~HandshakeCallback() {
359 EXPECT_EQ(STATE_SUCCEEDED, state);
364 state = STATE_SUCCEEDED;
367 std::shared_ptr<AsyncSSLSocket> getSocket() {
372 std::shared_ptr<AsyncSSLSocket> socket_;
373 ReadCallbackBase *rcb_;
376 std::condition_variable cv_;
377 std::string errorString_;
380 class SSLServerAcceptCallbackBase:
381 public folly::AsyncServerSocket::AcceptCallback {
383 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
384 state(STATE_WAITING), hcb_(hcb) {}
386 ~SSLServerAcceptCallbackBase() {
387 EXPECT_EQ(STATE_SUCCEEDED, state);
390 void acceptError(const std::exception& ex) noexcept override {
391 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
392 << ex.what() << std::endl;
393 state = STATE_FAILED;
396 void connectionAccepted(
397 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
398 printf("Connection accepted\n");
399 std::shared_ptr<AsyncSSLSocket> sslSock;
401 // Create a AsyncSSLSocket object with the fd. The socket should be
402 // added to the event base and in the state of accepting SSL connection.
403 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
404 } catch (const std::exception &e) {
405 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
406 "object with socket " << e.what() << fd;
412 connAccepted(sslSock);
415 virtual void connAccepted(
416 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
419 HandshakeCallback *hcb_;
420 std::shared_ptr<folly::SSLContext> ctx_;
421 folly::EventBase* base_;
424 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
428 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
429 uint32_t timeout = 0):
430 SSLServerAcceptCallbackBase(hcb),
433 virtual ~SSLServerAcceptCallback() {
435 // if we set a timeout, we expect failure
436 EXPECT_EQ(hcb_->state, STATE_FAILED);
437 hcb_->setState(STATE_SUCCEEDED);
441 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
443 const std::shared_ptr<folly::AsyncSSLSocket> &s)
445 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
446 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
448 hcb_->setSocket(sock);
449 sock->sslAccept(hcb_, timeout_);
450 EXPECT_EQ(sock->getSSLState(),
451 AsyncSSLSocket::STATE_ACCEPTING);
453 state = STATE_SUCCEEDED;
457 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
459 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
460 SSLServerAcceptCallback(hcb) {}
462 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
464 const std::shared_ptr<folly::AsyncSSLSocket> &s)
467 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
469 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
471 int fd = sock->getFd();
475 // The accepted connection should already have TCP_NODELAY set
477 socklen_t valueLength = sizeof(value);
478 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
484 // Unset the TCP_NODELAY option.
486 socklen_t valueLength = sizeof(value);
487 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
490 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
494 SSLServerAcceptCallback::connAccepted(sock);
498 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
500 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
501 uint32_t timeout = 0):
502 SSLServerAcceptCallback(hcb, timeout) {}
504 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
506 const std::shared_ptr<folly::AsyncSSLSocket> &s)
508 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
510 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
512 hcb_->setSocket(sock);
513 sock->sslAccept(hcb_, timeout_);
514 ASSERT_TRUE((sock->getSSLState() ==
515 AsyncSSLSocket::STATE_ACCEPTING) ||
516 (sock->getSSLState() ==
517 AsyncSSLSocket::STATE_CACHE_LOOKUP));
519 state = STATE_SUCCEEDED;
524 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
526 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
527 SSLServerAcceptCallbackBase(hcb) {}
529 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
531 const std::shared_ptr<folly::AsyncSSLSocket> &s)
533 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
535 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
537 // The first call to sslAccept() should succeed.
538 hcb_->setSocket(sock);
539 sock->sslAccept(hcb_);
540 EXPECT_EQ(sock->getSSLState(),
541 AsyncSSLSocket::STATE_ACCEPTING);
543 // The second call to sslAccept() should fail.
544 HandshakeCallback callback2(hcb_->rcb_);
545 callback2.setSocket(sock);
546 sock->sslAccept(&callback2);
547 EXPECT_EQ(sock->getSSLState(),
548 AsyncSSLSocket::STATE_ERROR);
550 // Both callbacks should be in the error state.
551 EXPECT_EQ(hcb_->state, STATE_FAILED);
552 EXPECT_EQ(callback2.state, STATE_FAILED);
554 sock->detachEventBase();
556 state = STATE_SUCCEEDED;
557 hcb_->setState(STATE_SUCCEEDED);
558 callback2.setState(STATE_SUCCEEDED);
562 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
564 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
565 SSLServerAcceptCallbackBase(hcb) {}
567 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
569 const std::shared_ptr<folly::AsyncSSLSocket> &s)
571 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
573 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
575 hcb_->setSocket(sock);
576 sock->getEventBase()->tryRunAfterDelay([=] {
577 std::cerr << "Delayed SSL accept, client will have close by now"
579 // SSL accept will fail
582 AsyncSSLSocket::STATE_UNINIT);
583 hcb_->socket_->sslAccept(hcb_);
584 // This registers for an event
587 AsyncSSLSocket::STATE_ACCEPTING);
589 state = STATE_SUCCEEDED;
594 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
596 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
597 // We don't care if we get invoked or not.
598 // The client may time out and give up before connAccepted() is even
600 state = STATE_SUCCEEDED;
603 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
605 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
606 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
608 // Just wait a while before closing the socket, so the client
609 // will time out waiting for the handshake to complete.
610 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
614 class TestSSLServer {
617 std::shared_ptr<folly::SSLContext> ctx_;
618 SSLServerAcceptCallbackBase *acb_;
619 std::shared_ptr<folly::AsyncServerSocket> socket_;
620 folly::SocketAddress address_;
623 static void *Main(void *ctx) {
624 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
626 std::cerr << "Server thread exited event loop" << std::endl;
631 // Create a TestSSLServer.
632 // This immediately starts listening on the given port.
633 explicit TestSSLServer(
634 SSLServerAcceptCallbackBase* acb,
635 bool enableTFO = false);
639 evb_.runInEventBaseThread([&](){
640 socket_->stopAccepting();
642 std::cerr << "Waiting for server thread to exit" << std::endl;
643 pthread_join(thread_, nullptr);
646 EventBase &getEventBase() { return evb_; }
648 const folly::SocketAddress& getAddress() const {
653 class TestSSLAsyncCacheServer : public TestSSLServer {
655 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
656 int lookupDelay = 100) :
658 SSL_CTX *sslCtx = ctx_->getSSLCtx();
659 SSL_CTX_sess_set_get_cb(sslCtx,
660 TestSSLAsyncCacheServer::getSessionCallback);
661 SSL_CTX_set_session_cache_mode(
662 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
665 lookupDelay_ = lookupDelay;
668 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
669 uint32_t getAsyncLookups() const { return asyncLookups_; }
672 static uint32_t asyncCallbacks_;
673 static uint32_t asyncLookups_;
674 static uint32_t lookupDelay_;
676 static SSL_SESSION* getSessionCallback(SSL* ssl,
677 unsigned char* /* sess_id */,
682 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
683 if (!SSL_want_sess_cache_lookup(ssl)) {
684 // libssl.so mismatch
685 std::cerr << "no async support" << std::endl;
689 AsyncSSLSocket *sslSocket =
690 AsyncSSLSocket::getFromSSL(ssl);
691 assert(sslSocket != nullptr);
692 // Going to simulate an async cache by just running delaying the miss 100ms
693 if (asyncCallbacks_ % 2 == 0) {
694 // This socket is already blocked on lookup, return miss
695 std::cerr << "returning miss" << std::endl;
697 // fresh meat - block it
698 std::cerr << "async lookup" << std::endl;
699 sslSocket->getEventBase()->tryRunAfterDelay(
700 std::bind(&AsyncSSLSocket::restartSSLAccept,
701 sslSocket), lookupDelay_);
702 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
710 void getfds(int fds[2]);
713 std::shared_ptr<folly::SSLContext> clientCtx,
714 std::shared_ptr<folly::SSLContext> serverCtx);
717 EventBase* eventBase,
718 AsyncSSLSocket::UniquePtr* clientSock,
719 AsyncSSLSocket::UniquePtr* serverSock);
721 class BlockingWriteClient :
722 private AsyncSSLSocket::HandshakeCB,
723 private AsyncTransportWrapper::WriteCallback {
725 explicit BlockingWriteClient(
726 AsyncSSLSocket::UniquePtr socket)
727 : socket_(std::move(socket)),
731 buf_.reset(new uint8_t[bufLen_]);
732 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
737 iov_.reset(new struct iovec[iovCount_]);
738 for (uint32_t n = 0; n < iovCount_; ++n) {
739 iov_[n].iov_base = buf_.get() + n;
741 iov_[n].iov_len = n % bufLen_;
743 iov_[n].iov_len = bufLen_ - (n % bufLen_);
747 socket_->sslConn(this, 100);
750 struct iovec* getIovec() const {
753 uint32_t getIovecCount() const {
758 void handshakeSuc(AsyncSSLSocket*) noexcept override {
759 socket_->writev(this, iov_.get(), iovCount_);
763 const AsyncSocketException& ex) noexcept override {
764 ADD_FAILURE() << "client handshake error: " << ex.what();
766 void writeSuccess() noexcept override {
771 const AsyncSocketException& ex) noexcept override {
772 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
776 AsyncSSLSocket::UniquePtr socket_;
779 std::unique_ptr<uint8_t[]> buf_;
780 std::unique_ptr<struct iovec[]> iov_;
783 class BlockingWriteServer :
784 private AsyncSSLSocket::HandshakeCB,
785 private AsyncTransportWrapper::ReadCallback {
787 explicit BlockingWriteServer(
788 AsyncSSLSocket::UniquePtr socket)
789 : socket_(std::move(socket)),
790 bufSize_(2500 * 2000),
792 buf_.reset(new uint8_t[bufSize_]);
793 socket_->sslAccept(this, 100);
796 void checkBuffer(struct iovec* iov, uint32_t count) const {
798 for (uint32_t n = 0; n < count; ++n) {
799 size_t bytesLeft = bytesRead_ - idx;
800 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
801 std::min(iov[n].iov_len, bytesLeft));
803 FAIL() << "buffer mismatch at iovec " << n << "/" << count
807 if (iov[n].iov_len > bytesLeft) {
808 FAIL() << "server did not read enough data: "
809 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
810 << " in iovec " << n << "/" << count;
813 idx += iov[n].iov_len;
815 if (idx != bytesRead_) {
816 ADD_FAILURE() << "server read extra data: " << bytesRead_
817 << " bytes read; expected " << idx;
822 void handshakeSuc(AsyncSSLSocket*) noexcept override {
823 // Wait 10ms before reading, so the client's writes will initially block.
824 socket_->getEventBase()->tryRunAfterDelay(
825 [this] { socket_->setReadCB(this); }, 10);
829 const AsyncSocketException& ex) noexcept override {
830 ADD_FAILURE() << "server handshake error: " << ex.what();
832 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
833 *bufReturn = buf_.get() + bytesRead_;
834 *lenReturn = bufSize_ - bytesRead_;
836 void readDataAvailable(size_t len) noexcept override {
838 socket_->setReadCB(nullptr);
839 socket_->getEventBase()->tryRunAfterDelay(
840 [this] { socket_->setReadCB(this); }, 2);
842 void readEOF() noexcept override {
846 const AsyncSocketException& ex) noexcept override {
847 ADD_FAILURE() << "server read error: " << ex.what();
850 AsyncSSLSocket::UniquePtr socket_;
853 std::unique_ptr<uint8_t[]> buf_;
857 private AsyncSSLSocket::HandshakeCB,
858 private AsyncTransportWrapper::WriteCallback {
861 AsyncSSLSocket::UniquePtr socket)
862 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
863 socket_->sslConn(this);
866 const unsigned char* nextProto;
867 unsigned nextProtoLength;
868 SSLContext::NextProtocolType protocolType;
871 void handshakeSuc(AsyncSSLSocket*) noexcept override {
872 socket_->getSelectedNextProtocol(
873 &nextProto, &nextProtoLength, &protocolType);
877 const AsyncSocketException& ex) noexcept override {
878 ADD_FAILURE() << "client handshake error: " << ex.what();
880 void writeSuccess() noexcept override {
885 const AsyncSocketException& ex) noexcept override {
886 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
890 AsyncSSLSocket::UniquePtr socket_;
894 private AsyncSSLSocket::HandshakeCB,
895 private AsyncTransportWrapper::ReadCallback {
897 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
898 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
899 socket_->sslAccept(this);
902 const unsigned char* nextProto;
903 unsigned nextProtoLength;
904 SSLContext::NextProtocolType protocolType;
907 void handshakeSuc(AsyncSSLSocket*) noexcept override {
908 socket_->getSelectedNextProtocol(
909 &nextProto, &nextProtoLength, &protocolType);
913 const AsyncSocketException& ex) noexcept override {
914 ADD_FAILURE() << "server handshake error: " << ex.what();
916 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
919 void readDataAvailable(size_t /* len */) noexcept override {}
920 void readEOF() noexcept override {
924 const AsyncSocketException& ex) noexcept override {
925 ADD_FAILURE() << "server read error: " << ex.what();
928 AsyncSSLSocket::UniquePtr socket_;
931 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
932 public AsyncTransportWrapper::ReadCallback {
934 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
935 : socket_(std::move(socket)) {
936 socket_->sslAccept(this);
939 ~RenegotiatingServer() {
940 socket_->setReadCB(nullptr);
943 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
944 LOG(INFO) << "Renegotiating server handshake success";
945 socket_->setReadCB(this);
949 const AsyncSocketException& ex) noexcept override {
950 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
952 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
953 *lenReturn = sizeof(buf);
956 void readDataAvailable(size_t /* len */) noexcept override {}
957 void readEOF() noexcept override {}
958 void readErr(const AsyncSocketException& ex) noexcept override {
959 LOG(INFO) << "server got read error " << ex.what();
960 auto exPtr = dynamic_cast<const SSLException*>(&ex);
961 ASSERT_NE(nullptr, exPtr);
962 std::string exStr(ex.what());
963 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
964 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
965 renegotiationError_ = true;
968 AsyncSSLSocket::UniquePtr socket_;
969 unsigned char buf[128];
970 bool renegotiationError_{false};
973 #ifndef OPENSSL_NO_TLSEXT
975 private AsyncSSLSocket::HandshakeCB,
976 private AsyncTransportWrapper::WriteCallback {
979 AsyncSSLSocket::UniquePtr socket)
980 : serverNameMatch(false), socket_(std::move(socket)) {
981 socket_->sslConn(this);
984 bool serverNameMatch;
987 void handshakeSuc(AsyncSSLSocket*) noexcept override {
988 serverNameMatch = socket_->isServerNameMatch();
992 const AsyncSocketException& ex) noexcept override {
993 ADD_FAILURE() << "client handshake error: " << ex.what();
995 void writeSuccess() noexcept override {
1000 const AsyncSocketException& ex) noexcept override {
1001 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1005 AsyncSSLSocket::UniquePtr socket_;
1009 private AsyncSSLSocket::HandshakeCB,
1010 private AsyncTransportWrapper::ReadCallback {
1013 AsyncSSLSocket::UniquePtr socket,
1014 const std::shared_ptr<folly::SSLContext>& ctx,
1015 const std::shared_ptr<folly::SSLContext>& sniCtx,
1016 const std::string& expectedServerName)
1017 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1018 expectedServerName_(expectedServerName) {
1019 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1020 std::placeholders::_1));
1021 socket_->sslAccept(this);
1024 bool serverNameMatch;
1027 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1030 const AsyncSocketException& ex) noexcept override {
1031 ADD_FAILURE() << "server handshake error: " << ex.what();
1033 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1036 void readDataAvailable(size_t /* len */) noexcept override {}
1037 void readEOF() noexcept override {
1041 const AsyncSocketException& ex) noexcept override {
1042 ADD_FAILURE() << "server read error: " << ex.what();
1045 folly::SSLContext::ServerNameCallbackResult
1046 serverNameCallback(SSL *ssl) {
1047 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1050 !strcasecmp(expectedServerName_.c_str(), sn)) {
1051 AsyncSSLSocket *sslSocket =
1052 AsyncSSLSocket::getFromSSL(ssl);
1053 sslSocket->switchServerSSLContext(sniCtx_);
1054 serverNameMatch = true;
1055 return folly::SSLContext::SERVER_NAME_FOUND;
1057 serverNameMatch = false;
1058 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1062 AsyncSSLSocket::UniquePtr socket_;
1063 std::shared_ptr<folly::SSLContext> sniCtx_;
1064 std::string expectedServerName_;
1068 class SSLClient : public AsyncSocket::ConnectCallback,
1069 public AsyncTransportWrapper::WriteCallback,
1070 public AsyncTransportWrapper::ReadCallback
1073 EventBase *eventBase_;
1074 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1075 SSL_SESSION *session_;
1076 std::shared_ptr<folly::SSLContext> ctx_;
1078 folly::SocketAddress address_;
1082 uint32_t bytesRead_;
1086 uint32_t writeAfterConnectErrors_;
1088 // These settings test that we eventually drain the
1089 // socket, even if the maxReadsPerEvent_ is hit during
1090 // a event loop iteration.
1091 static constexpr size_t kMaxReadsPerEvent = 2;
1092 // 2 event loop iterations
1093 static constexpr size_t kMaxReadBufferSz =
1094 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1097 SSLClient(EventBase *eventBase,
1098 const folly::SocketAddress& address,
1100 uint32_t timeout = 0)
1101 : eventBase_(eventBase),
1103 requests_(requests),
1110 writeAfterConnectErrors_(0) {
1111 ctx_.reset(new folly::SSLContext());
1112 ctx_->setOptions(SSL_OP_NO_TICKET);
1113 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1114 memset(buf_, 'a', sizeof(buf_));
1119 SSL_SESSION_free(session_);
1122 EXPECT_EQ(bytesRead_, sizeof(buf_));
1126 uint32_t getHit() const { return hit_; }
1128 uint32_t getMiss() const { return miss_; }
1130 uint32_t getErrors() const { return errors_; }
1132 uint32_t getWriteAfterConnectErrors() const {
1133 return writeAfterConnectErrors_;
1136 void connect(bool writeNow = false) {
1137 sslSocket_ = AsyncSSLSocket::newSocket(
1139 if (session_ != nullptr) {
1140 sslSocket_->setSSLSession(session_);
1143 sslSocket_->connect(this, address_, timeout_);
1144 if (sslSocket_ && writeNow) {
1145 // write some junk, used in an error test
1146 sslSocket_->write(this, buf_, sizeof(buf_));
1150 void connectSuccess() noexcept override {
1151 std::cerr << "client SSL socket connected" << std::endl;
1152 if (sslSocket_->getSSLSessionReused()) {
1156 if (session_ != nullptr) {
1157 SSL_SESSION_free(session_);
1159 session_ = sslSocket_->getSSLSession();
1163 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1164 sslSocket_->write(this, buf_, sizeof(buf_));
1165 sslSocket_->setReadCB(this);
1166 memset(readbuf_, 'b', sizeof(readbuf_));
1171 const AsyncSocketException& ex) noexcept override {
1172 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1177 void writeSuccess() noexcept override {
1178 std::cerr << "client write success" << std::endl;
1181 void writeErr(size_t /* bytesWritten */,
1182 const AsyncSocketException& ex) noexcept override {
1183 std::cerr << "client writeError: " << ex.what() << std::endl;
1185 writeAfterConnectErrors_++;
1189 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1190 *bufReturn = readbuf_ + bytesRead_;
1191 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1194 void readEOF() noexcept override {
1195 std::cerr << "client readEOF" << std::endl;
1199 const AsyncSocketException& ex) noexcept override {
1200 std::cerr << "client readError: " << ex.what() << std::endl;
1203 void readDataAvailable(size_t len) noexcept override {
1204 std::cerr << "client read data: " << len << std::endl;
1206 if (bytesRead_ == sizeof(buf_)) {
1207 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1208 sslSocket_->closeNow();
1210 if (requests_ != 0) {
1218 class SSLHandshakeBase :
1219 public AsyncSSLSocket::HandshakeCB,
1220 private AsyncTransportWrapper::WriteCallback {
1222 explicit SSLHandshakeBase(
1223 AsyncSSLSocket::UniquePtr socket,
1224 bool preverifyResult,
1225 bool verifyResult) :
1226 handshakeVerify_(false),
1227 handshakeSuccess_(false),
1228 handshakeError_(false),
1229 socket_(std::move(socket)),
1230 preverifyResult_(preverifyResult),
1231 verifyResult_(verifyResult) {
1234 AsyncSSLSocket::UniquePtr moveSocket() && {
1235 return std::move(socket_);
1238 bool handshakeVerify_;
1239 bool handshakeSuccess_;
1240 bool handshakeError_;
1241 std::chrono::nanoseconds handshakeTime;
1244 AsyncSSLSocket::UniquePtr socket_;
1245 bool preverifyResult_;
1248 // HandshakeCallback
1249 bool handshakeVer(AsyncSSLSocket* /* sock */,
1251 X509_STORE_CTX* /* ctx */) noexcept override {
1252 handshakeVerify_ = true;
1254 EXPECT_EQ(preverifyResult_, preverifyOk);
1255 return verifyResult_;
1258 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1259 LOG(INFO) << "Handshake success";
1260 handshakeSuccess_ = true;
1261 handshakeTime = socket_->getHandshakeTime();
1266 const AsyncSocketException& ex) noexcept override {
1267 LOG(INFO) << "Handshake error " << ex.what();
1268 handshakeError_ = true;
1269 handshakeTime = socket_->getHandshakeTime();
1273 void writeSuccess() noexcept override {
1278 size_t bytesWritten,
1279 const AsyncSocketException& ex) noexcept override {
1280 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1285 class SSLHandshakeClient : public SSLHandshakeBase {
1288 AsyncSSLSocket::UniquePtr socket,
1289 bool preverifyResult,
1290 bool verifyResult) :
1291 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1292 socket_->sslConn(this, 0);
1296 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1298 SSLHandshakeClientNoVerify(
1299 AsyncSSLSocket::UniquePtr socket,
1300 bool preverifyResult,
1301 bool verifyResult) :
1302 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1303 socket_->sslConn(this, 0,
1304 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1308 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1310 SSLHandshakeClientDoVerify(
1311 AsyncSSLSocket::UniquePtr socket,
1312 bool preverifyResult,
1313 bool verifyResult) :
1314 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1315 socket_->sslConn(this, 0,
1316 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1320 class SSLHandshakeServer : public SSLHandshakeBase {
1323 AsyncSSLSocket::UniquePtr socket,
1324 bool preverifyResult,
1326 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1327 socket_->sslAccept(this, 0);
1331 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1333 SSLHandshakeServerParseClientHello(
1334 AsyncSSLSocket::UniquePtr socket,
1335 bool preverifyResult,
1337 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1338 socket_->enableClientHelloParsing();
1339 socket_->sslAccept(this, 0);
1342 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1345 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1346 handshakeSuccess_ = true;
1347 sock->getSSLSharedCiphers(sharedCiphers_);
1348 sock->getSSLServerCiphers(serverCiphers_);
1349 sock->getSSLClientCiphers(clientCiphers_);
1350 chosenCipher_ = sock->getNegotiatedCipherName();
1355 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1357 SSLHandshakeServerNoVerify(
1358 AsyncSSLSocket::UniquePtr socket,
1359 bool preverifyResult,
1361 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1362 socket_->sslAccept(this, 0,
1363 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1367 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1369 SSLHandshakeServerDoVerify(
1370 AsyncSSLSocket::UniquePtr socket,
1371 bool preverifyResult,
1373 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1374 socket_->sslAccept(this, 0,
1375 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1379 class EventBaseAborter : public AsyncTimeout {
1381 EventBaseAborter(EventBase* eventBase,
1384 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1385 , eventBase_(eventBase) {
1386 scheduleTimeout(timeoutMS);
1389 void timeoutExpired() noexcept override {
1390 FAIL() << "test timed out";
1391 eventBase_->terminateLoopSoon();
1395 EventBase* eventBase_;