2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
29 #include <gtest/gtest.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both positive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
56 : state(STATE_WAITING)
58 , exception(AsyncSocketException::UNKNOWN, "none") {}
60 ~WriteCallbackBase() {
61 EXPECT_EQ(state, STATE_SUCCEEDED);
65 const std::shared_ptr<AsyncSSLSocket> &socket) {
69 void writeSuccess() noexcept override {
70 std::cerr << "writeSuccess" << std::endl;
71 state = STATE_SUCCEEDED;
76 const AsyncSocketException& ex) noexcept override {
77 std::cerr << "writeError: bytesWritten " << bytesWritten
78 << ", exception " << ex.what() << std::endl;
81 this->bytesWritten = bytesWritten;
84 socket_->detachEventBase();
87 std::shared_ptr<AsyncSSLSocket> socket_;
90 AsyncSocketException exception;
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
96 explicit ReadCallbackBase(WriteCallbackBase *wcb)
98 , state(STATE_WAITING) {}
100 ~ReadCallbackBase() {
101 EXPECT_EQ(state, STATE_SUCCEEDED);
105 const std::shared_ptr<AsyncSSLSocket> &socket) {
109 void setState(StateEnum s) {
117 const AsyncSocketException& ex) noexcept override {
118 std::cerr << "readError " << ex.what() << std::endl;
119 state = STATE_FAILED;
121 socket_->detachEventBase();
124 void readEOF() noexcept override {
125 std::cerr << "readEOF" << std::endl;
128 socket_->detachEventBase();
131 std::shared_ptr<AsyncSSLSocket> socket_;
132 WriteCallbackBase *wcb_;
136 class ReadCallback : public ReadCallbackBase {
138 explicit ReadCallback(WriteCallbackBase *wcb)
139 : ReadCallbackBase(wcb)
143 for (std::vector<Buffer>::iterator it = buffers.begin();
148 currentBuffer.free();
151 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152 if (!currentBuffer.buffer) {
153 currentBuffer.allocate(4096);
155 *bufReturn = currentBuffer.buffer;
156 *lenReturn = currentBuffer.length;
159 void readDataAvailable(size_t len) noexcept override {
160 std::cerr << "readDataAvailable, len " << len << std::endl;
162 currentBuffer.length = len;
164 wcb_->setSocket(socket_);
166 // Write back the same data.
167 socket_->write(wcb_, currentBuffer.buffer, len);
169 buffers.push_back(currentBuffer);
170 currentBuffer.reset();
171 state = STATE_SUCCEEDED;
176 Buffer() : buffer(nullptr), length(0) {}
177 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
183 void allocate(size_t length) {
184 assert(buffer == nullptr);
185 this->buffer = static_cast<char*>(malloc(length));
186 this->length = length;
197 std::vector<Buffer> buffers;
198 Buffer currentBuffer;
201 class ReadErrorCallback : public ReadCallbackBase {
203 explicit ReadErrorCallback(WriteCallbackBase *wcb)
204 : ReadCallbackBase(wcb) {}
206 // Return nullptr buffer to trigger readError()
207 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208 *bufReturn = nullptr;
212 void readDataAvailable(size_t /* len */) noexcept override {
213 // This should never be called.
218 const AsyncSocketException& ex) noexcept override {
219 ReadCallbackBase::readErr(ex);
220 std::cerr << "ReadErrorCallback::readError" << std::endl;
221 setState(STATE_SUCCEEDED);
225 class WriteErrorCallback : public ReadCallback {
227 explicit WriteErrorCallback(WriteCallbackBase *wcb)
228 : ReadCallback(wcb) {}
230 void readDataAvailable(size_t len) noexcept override {
231 std::cerr << "readDataAvailable, len " << len << std::endl;
233 currentBuffer.length = len;
235 // close the socket before writing to trigger writeError().
236 ::close(socket_->getFd());
238 wcb_->setSocket(socket_);
240 // Write back the same data.
241 socket_->write(wcb_, currentBuffer.buffer, len);
243 if (wcb_->state == STATE_FAILED) {
244 setState(STATE_SUCCEEDED);
246 state = STATE_FAILED;
249 buffers.push_back(currentBuffer);
250 currentBuffer.reset();
253 void readErr(const AsyncSocketException& ex) noexcept override {
254 std::cerr << "readError " << ex.what() << std::endl;
255 // do nothing since this is expected
259 class EmptyReadCallback : public ReadCallback {
261 explicit EmptyReadCallback()
262 : ReadCallback(nullptr) {}
264 void readErr(const AsyncSocketException& ex) noexcept override {
265 std::cerr << "readError " << ex.what() << std::endl;
266 state = STATE_FAILED;
268 tcpSocket_->detachEventBase();
271 void readEOF() noexcept override {
272 std::cerr << "readEOF" << std::endl;
275 tcpSocket_->detachEventBase();
276 state = STATE_SUCCEEDED;
279 std::shared_ptr<AsyncSocket> tcpSocket_;
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
290 explicit HandshakeCallback(ReadCallbackBase *rcb,
291 ExpectType expect = EXPECT_SUCCESS):
292 state(STATE_WAITING),
297 const std::shared_ptr<AsyncSSLSocket> &socket) {
301 void setState(StateEnum s) {
306 // Functions inherited from AsyncSSLSocket::HandshakeCB
307 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308 std::lock_guard<std::mutex> g(mutex_);
310 EXPECT_EQ(sock, socket_.get());
311 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
312 rcb_->setSocket(socket_);
313 sock->setReadCB(rcb_);
314 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
316 void handshakeErr(AsyncSSLSocket* /* sock */,
317 const AsyncSocketException& ex) noexcept override {
318 std::lock_guard<std::mutex> g(mutex_);
320 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
321 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
322 if (expect_ == EXPECT_ERROR) {
323 // rcb will never be invoked
324 rcb_->setState(STATE_SUCCEEDED);
326 errorString_ = ex.what();
329 void waitForHandshake() {
330 std::unique_lock<std::mutex> lock(mutex_);
331 cv_.wait(lock, [this] { return state != STATE_WAITING; });
334 ~HandshakeCallback() {
335 EXPECT_EQ(state, STATE_SUCCEEDED);
340 state = STATE_SUCCEEDED;
344 std::shared_ptr<AsyncSSLSocket> socket_;
345 ReadCallbackBase *rcb_;
348 std::condition_variable cv_;
349 std::string errorString_;
352 class SSLServerAcceptCallbackBase:
353 public folly::AsyncServerSocket::AcceptCallback {
355 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
356 state(STATE_WAITING), hcb_(hcb) {}
358 ~SSLServerAcceptCallbackBase() {
359 EXPECT_EQ(state, STATE_SUCCEEDED);
362 void acceptError(const std::exception& ex) noexcept override {
363 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
364 << ex.what() << std::endl;
365 state = STATE_FAILED;
368 void connectionAccepted(
369 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
370 printf("Connection accepted\n");
371 std::shared_ptr<AsyncSSLSocket> sslSock;
373 // Create an AsyncSSLSocket object with the fd. The socket should be
374 // added to the event base and in the state of accepting SSL connection.
375 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
376 } catch (const std::exception &e) {
377 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
378 "object with socket " << e.what() << fd;
384 connAccepted(sslSock);
387 virtual void connAccepted(
388 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
391 HandshakeCallback *hcb_;
392 std::shared_ptr<folly::SSLContext> ctx_;
393 folly::EventBase* base_;
396 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
400 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
401 uint32_t timeout = 0):
402 SSLServerAcceptCallbackBase(hcb),
405 virtual ~SSLServerAcceptCallback() {
407 // if we set a timeout, we expect failure
408 EXPECT_EQ(hcb_->state, STATE_FAILED);
409 hcb_->setState(STATE_SUCCEEDED);
413 // Functions inherited from SSLServerAcceptCallbackBase
415 const std::shared_ptr<folly::AsyncSSLSocket> &s)
417 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
418 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
420 hcb_->setSocket(sock);
421 sock->sslAccept(hcb_, timeout_);
422 EXPECT_EQ(sock->getSSLState(),
423 AsyncSSLSocket::STATE_ACCEPTING);
425 state = STATE_SUCCEEDED;
429 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
431 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
432 SSLServerAcceptCallback(hcb) {}
434 // Functions inherited from SSLServerAcceptCallbackBase
436 const std::shared_ptr<folly::AsyncSSLSocket> &s)
439 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
441 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
443 int fd = sock->getFd();
447 // The accepted connection should already have TCP_NODELAY set
449 socklen_t valueLength = sizeof(value);
450 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
456 // Unset the TCP_NODELAY option.
458 socklen_t valueLength = sizeof(value);
459 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
462 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
466 SSLServerAcceptCallback::connAccepted(sock);
470 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
472 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
473 uint32_t timeout = 0):
474 SSLServerAcceptCallback(hcb, timeout) {}
476 // Functions inherited from SSLServerAcceptCallbackBase
478 const std::shared_ptr<folly::AsyncSSLSocket> &s)
480 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
482 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
484 hcb_->setSocket(sock);
485 sock->sslAccept(hcb_, timeout_);
486 ASSERT_TRUE((sock->getSSLState() ==
487 AsyncSSLSocket::STATE_ACCEPTING) ||
488 (sock->getSSLState() ==
489 AsyncSSLSocket::STATE_CACHE_LOOKUP));
491 state = STATE_SUCCEEDED;
496 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
498 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
499 SSLServerAcceptCallbackBase(hcb) {}
501 // Functions inherited from SSLServerAcceptCallbackBase
503 const std::shared_ptr<folly::AsyncSSLSocket> &s)
505 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
507 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
509 // The first call to sslAccept() should succeed.
510 hcb_->setSocket(sock);
511 sock->sslAccept(hcb_);
512 EXPECT_EQ(sock->getSSLState(),
513 AsyncSSLSocket::STATE_ACCEPTING);
515 // The second call to sslAccept() should fail.
516 HandshakeCallback callback2(hcb_->rcb_);
517 callback2.setSocket(sock);
518 sock->sslAccept(&callback2);
519 EXPECT_EQ(sock->getSSLState(),
520 AsyncSSLSocket::STATE_ERROR);
522 // Both callbacks should be in the error state.
523 EXPECT_EQ(hcb_->state, STATE_FAILED);
524 EXPECT_EQ(callback2.state, STATE_FAILED);
526 sock->detachEventBase();
528 state = STATE_SUCCEEDED;
529 hcb_->setState(STATE_SUCCEEDED);
530 callback2.setState(STATE_SUCCEEDED);
534 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
536 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
537 SSLServerAcceptCallbackBase(hcb) {}
539 // Functions inherited from SSLServerAcceptCallbackBase
541 const std::shared_ptr<folly::AsyncSSLSocket> &s)
543 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl;  // log the class that is actually running
545 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
547 hcb_->setSocket(sock);
548 sock->getEventBase()->tryRunAfterDelay([=] {
549 std::cerr << "Delayed SSL accept, client will have close by now"
551 // SSL accept will fail
554 AsyncSSLSocket::STATE_UNINIT);
555 hcb_->socket_->sslAccept(hcb_);
556 // This registers for an event
559 AsyncSSLSocket::STATE_ACCEPTING);
561 state = STATE_SUCCEEDED;
567 class TestSSLServer {
570 std::shared_ptr<folly::SSLContext> ctx_;
571 SSLServerAcceptCallbackBase *acb_;
572 std::shared_ptr<folly::AsyncServerSocket> socket_;
573 folly::SocketAddress address_;
576 static void *Main(void *ctx) {
577 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
579 std::cerr << "Server thread exited event loop" << std::endl;
584 // Create a TestSSLServer.
585 // This immediately starts listening on the given port.
586 explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
590 evb_.runInEventBaseThread([&](){
591 socket_->stopAccepting();
593 std::cerr << "Waiting for server thread to exit" << std::endl;
594 pthread_join(thread_, nullptr);
597 EventBase &getEventBase() { return evb_; }
599 const folly::SocketAddress& getAddress() const {
604 class TestSSLAsyncCacheServer : public TestSSLServer {
606 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
607 int lookupDelay = 100) :
609 SSL_CTX *sslCtx = ctx_->getSSLCtx();
610 SSL_CTX_sess_set_get_cb(sslCtx,
611 TestSSLAsyncCacheServer::getSessionCallback);
612 SSL_CTX_set_session_cache_mode(
613 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
616 lookupDelay_ = lookupDelay;
619 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
620 uint32_t getAsyncLookups() const { return asyncLookups_; }
623 static uint32_t asyncCallbacks_;
624 static uint32_t asyncLookups_;
625 static uint32_t lookupDelay_;
627 static SSL_SESSION* getSessionCallback(SSL* ssl,
628 unsigned char* /* sess_id */,
633 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
634 if (!SSL_want_sess_cache_lookup(ssl)) {
635 // libssl.so mismatch
636 std::cerr << "no async support" << std::endl;
640 AsyncSSLSocket *sslSocket =
641 AsyncSSLSocket::getFromSSL(ssl);
642 assert(sslSocket != nullptr);
643 // Simulate an async cache by delaying the miss by lookupDelay_ ms (100 by default)
644 if (asyncCallbacks_ % 2 == 0) {
645 // This socket is already blocked on lookup, return miss
646 std::cerr << "returning miss" << std::endl;
648 // fresh meat - block it
649 std::cerr << "async lookup" << std::endl;
650 sslSocket->getEventBase()->tryRunAfterDelay(
651 std::bind(&AsyncSSLSocket::restartSSLAccept,
652 sslSocket), lookupDelay_);
653 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
661 void getfds(int fds[2]);
664 std::shared_ptr<folly::SSLContext> clientCtx,
665 std::shared_ptr<folly::SSLContext> serverCtx);
668 EventBase* eventBase,
669 AsyncSSLSocket::UniquePtr* clientSock,
670 AsyncSSLSocket::UniquePtr* serverSock);
672 class BlockingWriteClient :
673 private AsyncSSLSocket::HandshakeCB,
674 private AsyncTransportWrapper::WriteCallback {
676 explicit BlockingWriteClient(
677 AsyncSSLSocket::UniquePtr socket)
678 : socket_(std::move(socket)),
682 buf_.reset(new uint8_t[bufLen_]);
683 for (uint32_t n = 0; n < bufLen_; ++n) {  // cover the whole bufLen_-byte allocation; sizeof(buf_) is only the size of the unique_ptr
688 iov_.reset(new struct iovec[iovCount_]);
689 for (uint32_t n = 0; n < iovCount_; ++n) {
690 iov_[n].iov_base = buf_.get() + n;
692 iov_[n].iov_len = n % bufLen_;
694 iov_[n].iov_len = bufLen_ - (n % bufLen_);
698 socket_->sslConn(this, 100);
701 struct iovec* getIovec() const {
704 uint32_t getIovecCount() const {
709 void handshakeSuc(AsyncSSLSocket*) noexcept override {
710 socket_->writev(this, iov_.get(), iovCount_);
714 const AsyncSocketException& ex) noexcept override {
715 ADD_FAILURE() << "client handshake error: " << ex.what();
717 void writeSuccess() noexcept override {
722 const AsyncSocketException& ex) noexcept override {
723 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
727 AsyncSSLSocket::UniquePtr socket_;
730 std::unique_ptr<uint8_t[]> buf_;
731 std::unique_ptr<struct iovec[]> iov_;
734 class BlockingWriteServer :
735 private AsyncSSLSocket::HandshakeCB,
736 private AsyncTransportWrapper::ReadCallback {
738 explicit BlockingWriteServer(
739 AsyncSSLSocket::UniquePtr socket)
740 : socket_(std::move(socket)),
741 bufSize_(2500 * 2000),
743 buf_.reset(new uint8_t[bufSize_]);
744 socket_->sslAccept(this, 100);
747 void checkBuffer(struct iovec* iov, uint32_t count) const {
749 for (uint32_t n = 0; n < count; ++n) {
750 size_t bytesLeft = bytesRead_ - idx;
751 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
752 std::min(iov[n].iov_len, bytesLeft));
754 FAIL() << "buffer mismatch at iovec " << n << "/" << count
758 if (iov[n].iov_len > bytesLeft) {
759 FAIL() << "server did not read enough data: "
760 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
761 << " in iovec " << n << "/" << count;
764 idx += iov[n].iov_len;
766 if (idx != bytesRead_) {
767 ADD_FAILURE() << "server read extra data: " << bytesRead_
768 << " bytes read; expected " << idx;
773 void handshakeSuc(AsyncSSLSocket*) noexcept override {
774 // Wait 10ms before reading, so the client's writes will initially block.
775 socket_->getEventBase()->tryRunAfterDelay(
776 [this] { socket_->setReadCB(this); }, 10);
780 const AsyncSocketException& ex) noexcept override {
781 ADD_FAILURE() << "server handshake error: " << ex.what();
783 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
784 *bufReturn = buf_.get() + bytesRead_;
785 *lenReturn = bufSize_ - bytesRead_;
787 void readDataAvailable(size_t len) noexcept override {
789 socket_->setReadCB(nullptr);
790 socket_->getEventBase()->tryRunAfterDelay(
791 [this] { socket_->setReadCB(this); }, 2);
793 void readEOF() noexcept override {
797 const AsyncSocketException& ex) noexcept override {
798 ADD_FAILURE() << "server read error: " << ex.what();
801 AsyncSSLSocket::UniquePtr socket_;
804 std::unique_ptr<uint8_t[]> buf_;
808 private AsyncSSLSocket::HandshakeCB,
809 private AsyncTransportWrapper::WriteCallback {
812 AsyncSSLSocket::UniquePtr socket)
813 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
814 socket_->sslConn(this);
817 const unsigned char* nextProto;
818 unsigned nextProtoLength;
819 SSLContext::NextProtocolType protocolType;
822 void handshakeSuc(AsyncSSLSocket*) noexcept override {
823 socket_->getSelectedNextProtocol(
824 &nextProto, &nextProtoLength, &protocolType);
828 const AsyncSocketException& ex) noexcept override {
829 ADD_FAILURE() << "client handshake error: " << ex.what();
831 void writeSuccess() noexcept override {
836 const AsyncSocketException& ex) noexcept override {
837 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
841 AsyncSSLSocket::UniquePtr socket_;
845 private AsyncSSLSocket::HandshakeCB,
846 private AsyncTransportWrapper::ReadCallback {
848 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
849 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
850 socket_->sslAccept(this);
853 const unsigned char* nextProto;
854 unsigned nextProtoLength;
855 SSLContext::NextProtocolType protocolType;
858 void handshakeSuc(AsyncSSLSocket*) noexcept override {
859 socket_->getSelectedNextProtocol(
860 &nextProto, &nextProtoLength, &protocolType);
864 const AsyncSocketException& ex) noexcept override {
865 ADD_FAILURE() << "server handshake error: " << ex.what();
867 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
870 void readDataAvailable(size_t /* len */) noexcept override {}
871 void readEOF() noexcept override {
875 const AsyncSocketException& ex) noexcept override {
876 ADD_FAILURE() << "server read error: " << ex.what();
879 AsyncSSLSocket::UniquePtr socket_;
882 #ifndef OPENSSL_NO_TLSEXT
884 private AsyncSSLSocket::HandshakeCB,
885 private AsyncTransportWrapper::WriteCallback {
888 AsyncSSLSocket::UniquePtr socket)
889 : serverNameMatch(false), socket_(std::move(socket)) {
890 socket_->sslConn(this);
893 bool serverNameMatch;
896 void handshakeSuc(AsyncSSLSocket*) noexcept override {
897 serverNameMatch = socket_->isServerNameMatch();
901 const AsyncSocketException& ex) noexcept override {
902 ADD_FAILURE() << "client handshake error: " << ex.what();
904 void writeSuccess() noexcept override {
909 const AsyncSocketException& ex) noexcept override {
910 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
914 AsyncSSLSocket::UniquePtr socket_;
918 private AsyncSSLSocket::HandshakeCB,
919 private AsyncTransportWrapper::ReadCallback {
922 AsyncSSLSocket::UniquePtr socket,
923 const std::shared_ptr<folly::SSLContext>& ctx,
924 const std::shared_ptr<folly::SSLContext>& sniCtx,
925 const std::string& expectedServerName)
926 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
927 expectedServerName_(expectedServerName) {
928 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
929 std::placeholders::_1));
930 socket_->sslAccept(this);
933 bool serverNameMatch;
936 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
939 const AsyncSocketException& ex) noexcept override {
940 ADD_FAILURE() << "server handshake error: " << ex.what();
942 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
945 void readDataAvailable(size_t /* len */) noexcept override {}
946 void readEOF() noexcept override {
950 const AsyncSocketException& ex) noexcept override {
951 ADD_FAILURE() << "server read error: " << ex.what();
954 folly::SSLContext::ServerNameCallbackResult
955 serverNameCallback(SSL *ssl) {
956 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
959 !strcasecmp(expectedServerName_.c_str(), sn)) {
960 AsyncSSLSocket *sslSocket =
961 AsyncSSLSocket::getFromSSL(ssl);
962 sslSocket->switchServerSSLContext(sniCtx_);
963 serverNameMatch = true;
964 return folly::SSLContext::SERVER_NAME_FOUND;
966 serverNameMatch = false;
967 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
971 AsyncSSLSocket::UniquePtr socket_;
972 std::shared_ptr<folly::SSLContext> sniCtx_;
973 std::string expectedServerName_;
977 class SSLClient : public AsyncSocket::ConnectCallback,
978 public AsyncTransportWrapper::WriteCallback,
979 public AsyncTransportWrapper::ReadCallback
982 EventBase *eventBase_;
983 std::shared_ptr<AsyncSSLSocket> sslSocket_;
984 SSL_SESSION *session_;
985 std::shared_ptr<folly::SSLContext> ctx_;
987 folly::SocketAddress address_;
995 uint32_t writeAfterConnectErrors_;
997 // These settings test that we eventually drain the
998 // socket, even if the maxReadsPerEvent_ is hit during
999 // an event loop iteration.
1000 static constexpr size_t kMaxReadsPerEvent = 2;
1001 static constexpr size_t kMaxReadBufferSz =
1002 sizeof(readbuf_) / kMaxReadsPerEvent / 2; // 2 event loop iterations
1005 SSLClient(EventBase *eventBase,
1006 const folly::SocketAddress& address,
1008 uint32_t timeout = 0)
1009 : eventBase_(eventBase),
1011 requests_(requests),
1018 writeAfterConnectErrors_(0) {
1019 ctx_.reset(new folly::SSLContext());
1020 ctx_->setOptions(SSL_OP_NO_TICKET);
1021 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1022 memset(buf_, 'a', sizeof(buf_));
1027 SSL_SESSION_free(session_);
1030 EXPECT_EQ(bytesRead_, sizeof(buf_));
1034 uint32_t getHit() const { return hit_; }
1036 uint32_t getMiss() const { return miss_; }
1038 uint32_t getErrors() const { return errors_; }
1040 uint32_t getWriteAfterConnectErrors() const {
1041 return writeAfterConnectErrors_;
1044 void connect(bool writeNow = false) {
1045 sslSocket_ = AsyncSSLSocket::newSocket(
1047 if (session_ != nullptr) {
1048 sslSocket_->setSSLSession(session_);
1051 sslSocket_->connect(this, address_, timeout_);
1052 if (sslSocket_ && writeNow) {
1053 // write some junk, used in an error test
1054 sslSocket_->write(this, buf_, sizeof(buf_));
1058 void connectSuccess() noexcept override {
1059 std::cerr << "client SSL socket connected" << std::endl;
1060 if (sslSocket_->getSSLSessionReused()) {
1064 if (session_ != nullptr) {
1065 SSL_SESSION_free(session_);
1067 session_ = sslSocket_->getSSLSession();
1071 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1072 sslSocket_->write(this, buf_, sizeof(buf_));
1073 sslSocket_->setReadCB(this);
1074 memset(readbuf_, 'b', sizeof(readbuf_));
1079 const AsyncSocketException& ex) noexcept override {
1080 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1085 void writeSuccess() noexcept override {
1086 std::cerr << "client write success" << std::endl;
1089 void writeErr(size_t /* bytesWritten */,
1090 const AsyncSocketException& ex) noexcept override {
1091 std::cerr << "client writeError: " << ex.what() << std::endl;
1093 writeAfterConnectErrors_++;
1097 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1098 *bufReturn = readbuf_ + bytesRead_;
1099 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1102 void readEOF() noexcept override {
1103 std::cerr << "client readEOF" << std::endl;
1107 const AsyncSocketException& ex) noexcept override {
1108 std::cerr << "client readError: " << ex.what() << std::endl;
1111 void readDataAvailable(size_t len) noexcept override {
1112 std::cerr << "client read data: " << len << std::endl;
1114 if (bytesRead_ == sizeof(buf_)) {
1115 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1116 sslSocket_->closeNow();
1118 if (requests_ != 0) {
1126 class SSLHandshakeBase :
1127 public AsyncSSLSocket::HandshakeCB,
1128 private AsyncTransportWrapper::WriteCallback {
1130 explicit SSLHandshakeBase(
1131 AsyncSSLSocket::UniquePtr socket,
1132 bool preverifyResult,
1133 bool verifyResult) :
1134 handshakeVerify_(false),
1135 handshakeSuccess_(false),
1136 handshakeError_(false),
1137 socket_(std::move(socket)),
1138 preverifyResult_(preverifyResult),
1139 verifyResult_(verifyResult) {
1142 bool handshakeVerify_;
1143 bool handshakeSuccess_;
1144 bool handshakeError_;
1145 std::chrono::nanoseconds handshakeTime;
1148 AsyncSSLSocket::UniquePtr socket_;
1149 bool preverifyResult_;
1152 // HandshakeCallback
1153 bool handshakeVer(AsyncSSLSocket* /* sock */,
1155 X509_STORE_CTX* /* ctx */) noexcept override {
1156 handshakeVerify_ = true;
1158 EXPECT_EQ(preverifyResult_, preverifyOk);
1159 return verifyResult_;
1162 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1163 handshakeSuccess_ = true;
1164 handshakeTime = socket_->getHandshakeTime();
1167 void handshakeErr(AsyncSSLSocket*,
1168 const AsyncSocketException& /* ex */) noexcept override {
1169 handshakeError_ = true;
1170 handshakeTime = socket_->getHandshakeTime();
1174 void writeSuccess() noexcept override {
1179 size_t bytesWritten,
1180 const AsyncSocketException& ex) noexcept override {
1181 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1186 class SSLHandshakeClient : public SSLHandshakeBase {
1189 AsyncSSLSocket::UniquePtr socket,
1190 bool preverifyResult,
1191 bool verifyResult) :
1192 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1193 socket_->sslConn(this, 0);
1197 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1199 SSLHandshakeClientNoVerify(
1200 AsyncSSLSocket::UniquePtr socket,
1201 bool preverifyResult,
1202 bool verifyResult) :
1203 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1204 socket_->sslConn(this, 0,
1205 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1209 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1211 SSLHandshakeClientDoVerify(
1212 AsyncSSLSocket::UniquePtr socket,
1213 bool preverifyResult,
1214 bool verifyResult) :
1215 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1216 socket_->sslConn(this, 0,
1217 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1221 class SSLHandshakeServer : public SSLHandshakeBase {
1224 AsyncSSLSocket::UniquePtr socket,
1225 bool preverifyResult,
1227 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1228 socket_->sslAccept(this, 0);
1232 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1234 SSLHandshakeServerParseClientHello(
1235 AsyncSSLSocket::UniquePtr socket,
1236 bool preverifyResult,
1238 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1239 socket_->enableClientHelloParsing();
1240 socket_->sslAccept(this, 0);
1243 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1246 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1247 handshakeSuccess_ = true;
1248 sock->getSSLSharedCiphers(sharedCiphers_);
1249 sock->getSSLServerCiphers(serverCiphers_);
1250 sock->getSSLClientCiphers(clientCiphers_);
1251 chosenCipher_ = sock->getNegotiatedCipherName();
1256 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1258 SSLHandshakeServerNoVerify(
1259 AsyncSSLSocket::UniquePtr socket,
1260 bool preverifyResult,
1262 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1263 socket_->sslAccept(this, 0,
1264 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1268 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1270 SSLHandshakeServerDoVerify(
1271 AsyncSSLSocket::UniquePtr socket,
1272 bool preverifyResult,
1274 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1275 socket_->sslAccept(this, 0,
1276 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1280 class EventBaseAborter : public AsyncTimeout {
1282 EventBaseAborter(EventBase* eventBase,
1285 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1286 , eventBase_(eventBase) {
1287 scheduleTimeout(timeoutMS);
1290 void timeoutExpired() noexcept override {
1291 ADD_FAILURE() << "test timed out";  // non-fatal on purpose: FAIL() returns immediately, which would skip terminateLoopSoon() below
1292 eventBase_->terminateLoopSoon();
1296 EventBase* eventBase_;