2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
37 #include <sys/types.h>
38 #include <condition_variable>
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both positive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
// Base SendMsgParamsCallback for the tests. setSocket() remembers the
// socket's current SendMsgParamsCallback in oldCallback_ and installs this
// object in its place; each override below forwards to oldCallback_.
// NOTE(review): oldCallback_ starts as nullptr and is dereferenced without a
// check, so setSocket() must run before any forwarding override is invoked.
49 class SendMsgParamsCallbackBase :
50 public folly::AsyncSocket::SendMsgParamsCallback {
52 SendMsgParamsCallbackBase() {}
// setSocket(): hooks this callback into `socket`, saving the previously
// installed callback so the overrides can delegate to it.
55 const std::shared_ptr<AsyncSSLSocket> &socket) {
57 oldCallback_ = socket_->getSendMsgParamsCB();
58 socket_->setSendMsgParamCB(this);
// Ignores defaultFlags and returns what the previous callback reports for
// `flags`.
61 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
63 return oldCallback_->getFlags(flags);
// Delegates filling the ancillary data buffer to the previous callback.
66 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67 oldCallback_->getAncillaryData(flags, data);
// Delegates the ancillary data size query to the previous callback.
70 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71 return oldCallback_->getAncillaryDataSize(flags);
// Socket this callback is installed on, and the callback it replaced there.
74 std::shared_ptr<AsyncSSLSocket> socket_;
75 folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
// Send-params callback whose flags a test can override through resetFlags().
// NOTE(review): most of resetFlags() and getFlagsImpl() is not visible in this
// view; only the fallback to the previous callback's flags can be seen.
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
80 SendMsgFlagsCallback() {}
// Presumably stores `flags` for later getFlagsImpl() calls — TODO confirm.
82 void resetFlags(int flags) {
86 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
// Fallback: defer to the callback that was installed before this one.
91 return oldCallback_->getFlags(flags);
// Send-params callback that can inject caller-supplied ancillary data into
// writes. When ancillaryData_ is non-empty it replaces the previous
// callback's ancillary data and size.
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
100 SendMsgDataCallback() {}
// Replaces the stored ancillary data. swap() leaves the old contents in
// `data`.
102 void resetData(std::vector<char>&& data) {
103 ancillaryData_.swap(data);
// Copies ancillaryData_ into `data` when it is non-empty; otherwise
// (presumably the else branch, whose line is not visible here) defers to the
// previous callback.
// NOTE(review): assumes `data` has room for getAncillaryDataSize() bytes —
// confirm against the caller.
106 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107 if (ancillaryData_.size()) {
108 std::cerr << "getAncillaryData: copying data" << std::endl;
109 memcpy(data, ancillaryData_.data(), ancillaryData_.size());
111 oldCallback_->getAncillaryData(flags, data);
// Reports the injected data's size when present, otherwise the previous
// callback's size.
115 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116 if (ancillaryData_.size()) {
117 std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118 return ancillaryData_.size();
120 return oldCallback_->getAncillaryDataSize(flags);
// Bytes to inject; an empty vector means "use the previous callback".
124 std::vector<char> ancillaryData_;
// Base write callback for the tests. It starts in STATE_WAITING and records
// the outcome of a write in `state`; the destructor expects STATE_SUCCEEDED.
// An optional SendMsgParamsCallbackBase (mcb) is attached to the socket in
// setSocket().
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
130 explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131 : state(STATE_WAITING)
// Placeholder exception until a real write error is reported.
133 , exception(AsyncSocketException::UNKNOWN, "none")
136 ~WriteCallbackBase() {
137 EXPECT_EQ(STATE_SUCCEEDED, state);
// Records the socket and installs mcb_ on it.
// NOTE(review): the socket_ assignment and any null check on mcb_ are on
// lines not visible in this view.
140 virtual void setSocket(
141 const std::shared_ptr<AsyncSSLSocket> &socket) {
144 mcb_->setSocket(socket);
148 virtual void writeSuccess() noexcept override {
149 std::cerr << "writeSuccess" << std::endl;
150 state = STATE_SUCCEEDED;
// writeErr(): logs the error, marks the write failed and records how many
// bytes were written before it.
154 size_t nBytesWritten,
155 const AsyncSocketException& ex) noexcept override {
156 std::cerr << "writeError: bytesWritten " << nBytesWritten
157 << ", exception " << ex.what() << std::endl;
159 state = STATE_FAILED;
160 this->bytesWritten = nBytesWritten;
165 std::shared_ptr<AsyncSSLSocket> socket_;
// Error reported to writeErr(); its assignment is not visible here (TODO
// confirm). ExpectWriteErrorCallback inspects its type_ and errno_.
168 AsyncSocketException exception;
169 SendMsgParamsCallbackBase* mcb_;
// Write callback for negative tests. The destructor requires that the write
// failed with a NETWORK_ERROR whose errno_ is 22. It then resets `state` so
// the base-class destructor check passes.
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
175 explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176 : WriteCallbackBase(mcb) {}
178 ~ExpectWriteErrorCallback() {
179 EXPECT_EQ(STATE_FAILED, state);
180 EXPECT_EQ(exception.type_,
181 AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
// 22 is EINVAL on Linux.
182 EXPECT_EQ(exception.errno_, 22);
183 // Suppress the assert in ~WriteCallbackBase()
184 state = STATE_SUCCEEDED;
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
// Only a subset of the kernel flags is mirrored here. The bit positions must
// stay in sync with the kernel header.
191 enum SOF_TIMESTAMPING {
192 SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193 SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194 SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195 SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196 SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197 SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
// Write callback that checks for kernel timestamp notifications.
// setSocket() enables SO_TIMESTAMPING on the socket fd.
// checkForTimestampNotifications() reads the socket error queue
// (MSG_ERRQUEUE) for SCM_TIMESTAMPING and IP/IPV6 RECVERR control messages.
// The destructor requires both a timestamp and a byte-sequence record to have
// been seen.
200 class WriteCheckTimestampCallback :
201 public WriteCallbackBase {
203 explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204 : WriteCallbackBase(mcb) {}
206 ~WriteCheckTimestampCallback() {
207 EXPECT_EQ(STATE_SUCCEEDED, state);
208 EXPECT_TRUE(gotTimestamp_);
209 EXPECT_TRUE(gotByteSeq_);
// Enables software timestamping with OPT_ID and OPT_TSONLY on the socket fd.
// NOTE(review): the check of `ret` is on a line not visible in this view.
213 const std::shared_ptr<AsyncSSLSocket> &socket) override {
214 WriteCallbackBase::setSocket(socket);
216 EXPECT_NE(socket_->getFd(), 0);
217 int flags = SOF_TIMESTAMPING_OPT_ID
218 | SOF_TIMESTAMPING_OPT_TSONLY
219 | SOF_TIMESTAMPING_SOFTWARE;
220 AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221 int ret = tstampingOpt.apply(socket_->getFd(), flags);
// Reads pending notifications from the socket's error queue. When recvmsg
// fails with EAGAIN this is not treated as an error. Any other errno is
// logged and wrapped in an INTERNAL_ERROR AsyncSocketException.
225 void checkForTimestampNotifications() noexcept {
226 int fd = socket_->getFd();
// Receives the control messages that carry the notifications.
227 std::vector<char> ctrl(1024, 0);
232 memset(&msg, 0, sizeof(msg));
233 entry.iov_base = &data;
234 entry.iov_len = sizeof(data);
235 msg.msg_iov = &entry;
237 msg.msg_control = ctrl.data();
238 msg.msg_controllen = ctrl.size();
242 ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
244 if (errno != EAGAIN) {
245 auto errnoCopy = errno;
246 std::cerr << "::recvmsg exited with code " << ret
247 << ", errno: " << errnoCopy << std::endl;
248 AsyncSocketException ex(
249 AsyncSocketException::INTERNAL_ERROR,
// Walk every control message returned by recvmsg().
257 for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258 cmsg != nullptr && cmsg->cmsg_len != 0;
259 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260 if (cmsg->cmsg_level == SOL_SOCKET &&
261 cmsg->cmsg_type == SCM_TIMESTAMPING) {
262 gotTimestamp_ = true;
// IPv4/IPv6 extended-error records. These presumably carry the byte-sequence
// id requested with SOF_TIMESTAMPING_OPT_ID; the handling lines are not
// visible here.
266 if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267 (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
275 bool gotTimestamp_{false};
276 bool gotByteSeq_{false};
278 #endif // MSG_ERRQUEUE
// Base read callback for the tests. It holds the socket and the write
// callback used to echo data back. The destructor expects STATE_SUCCEEDED.
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
283 explicit ReadCallbackBase(WriteCallbackBase* wcb)
284 : wcb_(wcb), state(STATE_WAITING) {}
286 ~ReadCallbackBase() {
287 EXPECT_EQ(STATE_SUCCEEDED, state);
// setSocket(): presumably stores `socket` in socket_ (body not visible here).
291 const std::shared_ptr<AsyncSSLSocket> &socket) {
// Lets other callbacks (e.g. a handshake callback expecting an error) resolve
// this callback's state directly.
295 void setState(StateEnum s) {
// readErr(): logs the error and marks the read as failed.
303 const AsyncSocketException& ex) noexcept override {
304 std::cerr << "readError " << ex.what() << std::endl;
305 state = STATE_FAILED;
309 void readEOF() noexcept override {
310 std::cerr << "readEOF" << std::endl;
315 std::shared_ptr<AsyncSSLSocket> socket_;
316 WriteCallbackBase *wcb_;
// Echo read callback. Each chunk that is read is written straight back on
// the same socket through wcb_. Filled buffers are kept in `buffers` until
// destruction.
320 class ReadCallback : public ReadCallbackBase {
322 explicit ReadCallback(WriteCallbackBase *wcb)
323 : ReadCallbackBase(wcb)
// Destructor: walks `buffers` (presumably freeing each entry) and frees the
// buffer currently being filled.
327 for (std::vector<Buffer>::iterator it = buffers.begin();
332 currentBuffer.free();
// Lazily allocates a 4096-byte buffer and keeps offering it until data
// arrives.
335 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336 if (!currentBuffer.buffer) {
337 currentBuffer.allocate(4096);
339 *bufReturn = currentBuffer.buffer;
340 *lenReturn = currentBuffer.length;
// Echoes `len` bytes back, then moves the buffer into `buffers` so it stays
// allocated (presumably because the asynchronous write may still reference
// it). The next read gets a fresh buffer.
343 void readDataAvailable(size_t len) noexcept override {
344 std::cerr << "readDataAvailable, len " << len << std::endl;
346 currentBuffer.length = len;
348 wcb_->setSocket(socket_);
350 // Write back the same data.
351 socket_->write(wcb_, currentBuffer.buffer, len);
353 buffers.push_back(currentBuffer);
354 currentBuffer.reset();
355 state = STATE_SUCCEEDED;
// Buffer: a malloc'd byte buffer plus its length. Memory is released
// explicitly via free(); reset() is used after ownership moves into `buffers`.
360 Buffer() : buffer(nullptr), length(0) {}
361 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
367 void allocate(size_t len) {
368 assert(buffer == nullptr);
369 this->buffer = static_cast<char*>(malloc(len));
381 std::vector<Buffer> buffers;
382 Buffer currentBuffer;
// Read callback for negative tests. Handing out a nullptr buffer makes the
// socket report a read error, and that error is the expected (successful)
// outcome.
385 class ReadErrorCallback : public ReadCallbackBase {
387 explicit ReadErrorCallback(WriteCallbackBase *wcb)
388 : ReadCallbackBase(wcb) {}
390 // Return nullptr buffer to trigger readError()
391 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392 *bufReturn = nullptr;
396 void readDataAvailable(size_t /* len */) noexcept override {
397 // This should never be called.
// readErr(): the base implementation marks the state failed; it is then
// flipped to success because the error is what this test expects.
402 const AsyncSocketException& ex) noexcept override {
403 ReadCallbackBase::readErr(ex);
404 std::cerr << "ReadErrorCallback::readError" << std::endl;
405 setState(STATE_SUCCEEDED);
// Read callback that succeeds only when the peer closes the connection
// (readEOF()). It hands out a nullptr buffer, so any data that did arrive
// would presumably surface as a read error rather than as data.
409 class ReadEOFCallback : public ReadCallbackBase {
411 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
413 // Return nullptr buffer to trigger readError()
414 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415 *bufReturn = nullptr;
419 void readDataAvailable(size_t /* len */) noexcept override {
420 // This should never be called.
424 void readEOF() noexcept override {
425 ReadCallbackBase::readEOF();
426 setState(STATE_SUCCEEDED);
// Read callback for write-error tests. On a read it closes the underlying fd
// and then echoes the data, so the write is expected to fail. This
// callback's state becomes SUCCEEDED only if wcb_ is in STATE_FAILED when
// write() returns.
430 class WriteErrorCallback : public ReadCallback {
432 explicit WriteErrorCallback(WriteCallbackBase *wcb)
433 : ReadCallback(wcb) {}
435 void readDataAvailable(size_t len) noexcept override {
436 std::cerr << "readDataAvailable, len " << len << std::endl;
438 currentBuffer.length = len;
440 // close the socket before writing to trigger writeError().
// NOTE(review): the fd is closed behind the AsyncSocket's back, so the socket
// object still holds an fd number the OS may reuse.
441 ::close(socket_->getFd());
443 wcb_->setSocket(socket_);
445 // Write back the same data.
// Writing to a closed fd is an invalid parameter for the MSVC CRT. This
// helper presumably keeps that from aborting the test on Windows.
446 folly::test::msvcSuppressAbortOnInvalidParams([&] {
447 socket_->write(wcb_, currentBuffer.buffer, len);
450 if (wcb_->state == STATE_FAILED) {
451 setState(STATE_SUCCEEDED);
453 state = STATE_FAILED;
456 buffers.push_back(currentBuffer);
457 currentBuffer.reset();
460 void readErr(const AsyncSocketException& ex) noexcept override {
461 std::cerr << "readError " << ex.what() << std::endl;
462 // do nothing since this is expected
// Read callback that expects no data: a read error fails it, and EOF marks
// it as succeeded. tcpSocket_ holds a plain (non-SSL) AsyncSocket; how it is
// used is not visible in this view.
// NOTE(review): the constructor passes nullptr as the write callback. The
// inherited ReadCallback::readDataAvailable() calls wcb_->setSocket(), so any
// data that arrives would dereference a null pointer.
466 class EmptyReadCallback : public ReadCallback {
468 explicit EmptyReadCallback()
469 : ReadCallback(nullptr) {}
471 void readErr(const AsyncSocketException& ex) noexcept override {
472 std::cerr << "readError " << ex.what() << std::endl;
473 state = STATE_FAILED;
479 void readEOF() noexcept override {
480 std::cerr << "readEOF" << std::endl;
484 state = STATE_SUCCEEDED;
487 std::shared_ptr<AsyncSocket> tcpSocket_;
// Server-side handshake callback. On success it attaches the read callback
// (rcb_) to the socket; on failure it records the error string. `expect_`
// decides which outcome counts as STATE_SUCCEEDED. The mutex and condition
// variable let another thread block in waitForHandshake() until the outcome
// is known.
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
498 explicit HandshakeCallback(ReadCallbackBase *rcb,
499 ExpectType expect = EXPECT_SUCCESS):
500 state(STATE_WAITING),
505 const std::shared_ptr<AsyncSSLSocket> &socket) {
509 void setState(StateEnum s) {
514 // Functions inherited from AsyncSSLSocketHandshakeCallback
// NOTE(review): the log label below says "connectionAccepted" even though
// this is the handshake-success path.
515 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516 std::lock_guard<std::mutex> g(mutex_);
518 EXPECT_EQ(sock, socket_.get());
519 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520 rcb_->setSocket(socket_);
521 sock->setReadCB(rcb_);
522 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
524 void handshakeErr(AsyncSSLSocket* /* sock */,
525 const AsyncSocketException& ex) noexcept override {
526 std::lock_guard<std::mutex> g(mutex_);
528 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530 if (expect_ == EXPECT_ERROR) {
531 // rcb will never be invoked
532 rcb_->setState(STATE_SUCCEEDED);
534 errorString_ = ex.what();
// Blocks until handshakeSuc() or handshakeErr() moves `state` off
// STATE_WAITING.
// NOTE(review): the cv_ notifications that wake this wait are on lines not
// visible here; confirm both handshake paths signal cv_.
537 void waitForHandshake() {
538 std::unique_lock<std::mutex> lock(mutex_);
539 cv_.wait(lock, [this] { return state != STATE_WAITING; });
542 ~HandshakeCallback() {
543 EXPECT_EQ(STATE_SUCCEEDED, state);
548 state = STATE_SUCCEEDED;
551 std::shared_ptr<AsyncSSLSocket> getSocket() {
556 std::shared_ptr<AsyncSSLSocket> socket_;
557 ReadCallbackBase *rcb_;
560 std::condition_variable cv_;
561 std::string errorString_;
// Accept callback that starts the server-side SSL handshake on each accepted
// socket, using the configured timeout in milliseconds. When a timeout is
// configured, the destructor expects the handshake to have failed. It then
// resets the handshake callback's state so that callback's destructor check
// passes.
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
568 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569 uint32_t timeout = 0):
570 SSLServerAcceptCallbackBase(hcb),
573 virtual ~SSLServerAcceptCallback() {
575 // if we set a timeout, we expect failure
576 EXPECT_EQ(hcb_->state, STATE_FAILED);
577 hcb_->setState(STATE_SUCCEEDED);
581 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
583 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// `s` is already a shared_ptr<AsyncSSLSocket>, so this cast only copies it.
585 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
588 hcb_->setSocket(sock);
589 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590 EXPECT_EQ(sock->getSSLState(),
591 AsyncSSLSocket::STATE_ACCEPTING);
593 state = STATE_SUCCEEDED;
// Accept callback that first checks the accepted socket already has
// TCP_NODELAY enabled. It then turns the option off, verifies the change, and
// hands off to the normal SSLServerAcceptCallback logic.
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
599 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600 SSLServerAcceptCallback(hcb) {}
602 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
604 const std::shared_ptr<folly::AsyncSSLSocket> &s)
607 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
609 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
611 int fd = sock->getFd();
615 // The accepted connection should already have TCP_NODELAY set
617 socklen_t valueLength = sizeof(value);
618 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
624 // Unset the TCP_NODELAY option.
626 socklen_t valueLength = sizeof(value);
627 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
// Read the option back to confirm it was cleared.
630 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
634 SSLServerAcceptCallback::connAccepted(sock);
// Accept callback for async session-cache tests. It matches
// SSLServerAcceptCallback, except that right after sslAccept() the socket may
// also be in STATE_CACHE_LOOKUP (waiting on an asynchronous session lookup).
638 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
640 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
641 uint32_t timeout = 0):
642 SSLServerAcceptCallback(hcb, timeout) {}
644 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
646 const std::shared_ptr<folly::AsyncSSLSocket> &s)
648 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
650 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
652 hcb_->setSocket(sock);
653 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
// NOTE(review): ASSERT_TRUE is fatal and returns from connAccepted() on
// failure, in which case `state` is never set to STATE_SUCCEEDED.
654 ASSERT_TRUE((sock->getSSLState() ==
655 AsyncSSLSocket::STATE_ACCEPTING) ||
656 (sock->getSSLState() ==
657 AsyncSSLSocket::STATE_CACHE_LOOKUP));
659 state = STATE_SUCCEEDED;
// Accept callback verifying that a second sslAccept() on the same socket is
// rejected. Both the original and the second handshake callback must end in
// STATE_FAILED. All states are then forced to STATE_SUCCEEDED so the
// destructor checks pass.
664 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
666 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
667 SSLServerAcceptCallbackBase(hcb) {}
669 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
671 const std::shared_ptr<folly::AsyncSSLSocket> &s)
673 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
675 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
677 // The first call to sslAccept() should succeed.
678 hcb_->setSocket(sock);
679 sock->sslAccept(hcb_);
680 EXPECT_EQ(sock->getSSLState(),
681 AsyncSSLSocket::STATE_ACCEPTING);
683 // The second call to sslAccept() should fail.
// callback2 lives on this stack frame; its state is reset below before it is
// destroyed.
684 HandshakeCallback callback2(hcb_->rcb_);
685 callback2.setSocket(sock);
686 sock->sslAccept(&callback2);
687 EXPECT_EQ(sock->getSSLState(),
688 AsyncSSLSocket::STATE_ERROR);
690 // Both callbacks should be in the error state.
691 EXPECT_EQ(hcb_->state, STATE_FAILED);
692 EXPECT_EQ(callback2.state, STATE_FAILED);
694 state = STATE_SUCCEEDED;
695 hcb_->setState(STATE_SUCCEEDED);
696 callback2.setState(STATE_SUCCEEDED);
// Accept callback that postpones sslAccept() on the event base so the client
// has likely closed by the time the server handshake starts. The delayed
// handshake is expected to fail.
// NOTE(review): the log label says "HandshakeErrorCallback", copied from that
// class. The delayed lambda uses hcb_, which captures `this` implicitly
// through [=], so this object must outlive the delayed task.
700 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
702 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
703 SSLServerAcceptCallbackBase(hcb) {}
705 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
707 const std::shared_ptr<folly::AsyncSSLSocket> &s)
709 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
711 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
713 hcb_->setSocket(sock);
714 sock->getEventBase()->tryRunAfterDelay([=] {
715 std::cerr << "Delayed SSL accept, client will have close by now"
717 // SSL accept will fail
720 AsyncSSLSocket::STATE_UNINIT);
721 hcb_->socket_->sslAccept(hcb_);
722 // This registers for an event
725 AsyncSSLSocket::STATE_ACCEPTING);
727 state = STATE_SUCCEEDED;
// Accept callback for client connect-timeout tests. It never starts a
// handshake and closes the accepted socket after 100 ms, so the client times
// out first. Its state is STATE_SUCCEEDED from construction because it may
// never be invoked.
732 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
734 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
735 // We don't care if we get invoked or not.
736 // The client may time out and give up before connAccepted() is even
738 state = STATE_SUCCEEDED;
741 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
743 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
744 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
746 // Just wait a while before closing the socket, so the client
747 // will time out waiting for the handshake to complete.
// [=] copies the shared_ptr `s`, keeping the socket alive until the delayed
// close runs.
748 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
// TestSSLServer variant that simulates an asynchronous external session
// cache. The SSL_CTX uses server-side session caching with the internal
// cache disabled, and getSessionCallback() is installed as the lookup hook.
// NOTE(review): asyncCallbacks_, asyncLookups_ and lookupDelay_ are static.
// They are shared by every instance, and lookupDelay_ is overwritten by the
// most recently constructed server.
752 class TestSSLAsyncCacheServer : public TestSSLServer {
754 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
755 int lookupDelay = 100) :
757 SSL_CTX *sslCtx = ctx_->getSSLCtx();
758 SSL_CTX_sess_set_get_cb(sslCtx,
759 TestSSLAsyncCacheServer::getSessionCallback);
760 SSL_CTX_set_session_cache_mode(
761 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
764 lookupDelay_ = lookupDelay;
767 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
768 uint32_t getAsyncLookups() const { return asyncLookups_; }
771 static uint32_t asyncCallbacks_;
772 static uint32_t asyncLookups_;
773 static uint32_t lookupDelay_;
// Session lookup hook. Depending on the parity of asyncCallbacks_ (the
// increment is not visible here), it either reports a miss or schedules
// restartSSLAccept() after lookupDelay_ ms and sets *copyflag to
// SSL_SESSION_CB_WOULD_BLOCK.
775 static SSL_SESSION* getSessionCallback(SSL* ssl,
776 unsigned char* /* sess_id */,
782 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
783 if (!SSL_want_sess_cache_lookup(ssl)) {
784 // libssl.so mismatch
785 std::cerr << "no async support" << std::endl;
789 AsyncSSLSocket *sslSocket =
790 AsyncSSLSocket::getFromSSL(ssl);
791 assert(sslSocket != nullptr);
792 // Simulate an async cache by delaying the miss by lookupDelay_ ms (100 by default)
793 if (asyncCallbacks_ % 2 == 0) {
794 // This socket is already blocked on lookup, return miss
795 std::cerr << "returning miss" << std::endl;
797 // fresh meat - block it
798 std::cerr << "async lookup" << std::endl;
// NOTE(review): restartSSLAccept is bound to a raw AsyncSSLSocket*, so the
// socket must still exist when the delayed task fires.
799 sslSocket->getEventBase()->tryRunAfterDelay(
800 std::bind(&AsyncSSLSocket::restartSSLAccept,
801 sslSocket), lookupDelay_);
802 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
// Out-of-line test helpers; their definitions are not in this view. The
// signatures show that getfds() fills a two-element fd array, the next helper
// takes client and server SSLContexts, and the last one fills a client and a
// server AsyncSSLSocket::UniquePtr on `eventBase`. Presumably this yields a
// connected socket pair — TODO confirm against the definitions.
810 void getfds(int fds[2]);
813 std::shared_ptr<folly::SSLContext> clientCtx,
814 std::shared_ptr<folly::SSLContext> serverCtx);
817 EventBase* eventBase,
818 AsyncSSLSocket::UniquePtr* clientSock,
819 AsyncSSLSocket::UniquePtr* serverSock);
// Client half of the blocking-write test. After the SSL handshake it issues
// a single writev() over iovCount_ iovecs. Each iovec points into buf_ at a
// different offset with a different length. getIovec() and getIovecCount()
// expose that layout, presumably so the server side can verify the received
// bytes.
821 class BlockingWriteClient :
822 private AsyncSSLSocket::HandshakeCB,
823 private AsyncTransportWrapper::WriteCallback {
825 explicit BlockingWriteClient(
826 AsyncSSLSocket::UniquePtr socket)
827 : socket_(std::move(socket)),
831 buf_.reset(new uint8_t[bufLen_]);
// NOTE(review): sizeof(buf_) is the size of the std::unique_ptr<uint8_t[]>
// member, not bufLen_, so this loop only visits the first few bytes of the
// buffer. bufLen_ is almost certainly the intended bound.
832 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
// iovec n starts n bytes into buf_. Its length is either n % bufLen_ or
// bufLen_ - (n % bufLen_); the condition that chooses between them is not
// visible here.
837 iov_.reset(new struct iovec[iovCount_]);
838 for (uint32_t n = 0; n < iovCount_; ++n) {
839 iov_[n].iov_base = buf_.get() + n;
841 iov_[n].iov_len = n % bufLen_;
843 iov_[n].iov_len = bufLen_ - (n % bufLen_);
847 socket_->sslConn(this, std::chrono::milliseconds(100));
850 struct iovec* getIovec() const {
853 uint32_t getIovecCount() const {
// Sends every iovec in a single writev() once the handshake is done.
858 void handshakeSuc(AsyncSSLSocket*) noexcept override {
859 socket_->writev(this, iov_.get(), iovCount_);
863 const AsyncSocketException& ex) noexcept override {
864 ADD_FAILURE() << "client handshake error: " << ex.what();
866 void writeSuccess() noexcept override {
871 const AsyncSocketException& ex) noexcept override {
872 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
876 AsyncSSLSocket::UniquePtr socket_;
879 std::unique_ptr<uint8_t[]> buf_;
880 std::unique_ptr<struct iovec[]> iov_;
// Server half of the blocking-write test. It accepts the handshake, then
// reads slowly into a 2500 * 2000 byte buffer: it waits 10 ms before the
// first read and pauses 2 ms after every read, so the client's writes have
// to block. checkBuffer() compares the received bytes with a client's iovec
// layout.
883 class BlockingWriteServer :
884 private AsyncSSLSocket::HandshakeCB,
885 private AsyncTransportWrapper::ReadCallback {
887 explicit BlockingWriteServer(
888 AsyncSSLSocket::UniquePtr socket)
889 : socket_(std::move(socket)),
890 bufSize_(2500 * 2000),
892 buf_.reset(new uint8_t[bufSize_]);
893 socket_->sslAccept(this, std::chrono::milliseconds(100));
// Walks the iovecs in order against buf_. FAIL() returns at the first
// mismatch or short read; ADD_FAILURE() reports bytes left over past the last
// iovec.
896 void checkBuffer(struct iovec* iov, uint32_t count) const {
898 for (uint32_t n = 0; n < count; ++n) {
899 size_t bytesLeft = bytesRead_ - idx;
900 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
901 std::min(iov[n].iov_len, bytesLeft));
903 FAIL() << "buffer mismatch at iovec " << n << "/" << count
907 if (iov[n].iov_len > bytesLeft) {
908 FAIL() << "server did not read enough data: "
909 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
910 << " in iovec " << n << "/" << count;
913 idx += iov[n].iov_len;
915 if (idx != bytesRead_) {
916 ADD_FAILURE() << "server read extra data: " << bytesRead_
917 << " bytes read; expected " << idx;
// The delayed lambdas capture `this`, so this object must outlive them.
922 void handshakeSuc(AsyncSSLSocket*) noexcept override {
923 // Wait 10ms before reading, so the client's writes will initially block.
924 socket_->getEventBase()->tryRunAfterDelay(
925 [this] { socket_->setReadCB(this); }, 10);
929 const AsyncSocketException& ex) noexcept override {
930 ADD_FAILURE() << "server handshake error: " << ex.what();
// Reads append to buf_ after the bytes already received.
932 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
933 *bufReturn = buf_.get() + bytesRead_;
934 *lenReturn = bufSize_ - bytesRead_;
// Pauses reading for 2 ms after each chunk to keep back-pressure on the
// client. The bytesRead_ update is on a line not visible here.
936 void readDataAvailable(size_t len) noexcept override {
938 socket_->setReadCB(nullptr);
939 socket_->getEventBase()->tryRunAfterDelay(
940 [this] { socket_->setReadCB(this); }, 2);
942 void readEOF() noexcept override {
946 const AsyncSocketException& ex) noexcept override {
947 ADD_FAILURE() << "server read error: " << ex.what();
950 AsyncSSLSocket::UniquePtr socket_;
953 std::unique_ptr<uint8_t[]> buf_;
// Next-protocol negotiation client (the class header is not visible in this
// view). It starts the client handshake on construction. When the handshake
// succeeds, it records the selected protocol in nextProto, nextProtoLength
// and protocolType.
957 private AsyncSSLSocket::HandshakeCB,
958 private AsyncTransportWrapper::WriteCallback {
961 AsyncSSLSocket::UniquePtr socket)
962 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
963 socket_->sslConn(this);
966 const unsigned char* nextProto;
967 unsigned nextProtoLength;
968 SSLContext::NextProtocolType protocolType;
971 void handshakeSuc(AsyncSSLSocket*) noexcept override {
972 socket_->getSelectedNextProtocol(
973 &nextProto, &nextProtoLength, &protocolType);
977 const AsyncSocketException& ex) noexcept override {
978 ADD_FAILURE() << "client handshake error: " << ex.what();
980 void writeSuccess() noexcept override {
985 const AsyncSocketException& ex) noexcept override {
986 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
990 AsyncSSLSocket::UniquePtr socket_;
// Next-protocol negotiation server. It starts the server handshake on
// construction. When the handshake succeeds, it records the selected protocol
// in nextProto, nextProtoLength and protocolType.
994 private AsyncSSLSocket::HandshakeCB,
995 private AsyncTransportWrapper::ReadCallback {
997 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
998 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
999 socket_->sslAccept(this);
1002 const unsigned char* nextProto;
1003 unsigned nextProtoLength;
1004 SSLContext::NextProtocolType protocolType;
1007 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1008 socket_->getSelectedNextProtocol(
1009 &nextProto, &nextProtoLength, &protocolType);
1013 const AsyncSocketException& ex) noexcept override {
1014 ADD_FAILURE() << "server handshake error: " << ex.what();
// bufReturn is deliberately left unset; the lenReturn assignment is on a
// line not visible here.
1016 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1019 void readDataAvailable(size_t /* len */) noexcept override {}
1020 void readEOF() noexcept override {
1024 const AsyncSocketException& ex) noexcept override {
1025 ADD_FAILURE() << "server read error: " << ex.what();
1028 AsyncSSLSocket::UniquePtr socket_;
// Server that expects the client to attempt a renegotiation. After the
// handshake it starts reading. The resulting read error must be an
// SSLException whose message contains the CLIENT_RENEGOTIATION error text.
// renegotiationError_ is set only if both ASSERTs pass, because a failed
// ASSERT returns from readErr() early.
1031 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1032 public AsyncTransportWrapper::ReadCallback {
1034 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1035 : socket_(std::move(socket)) {
1036 socket_->sslAccept(this);
// Detach as read callback so the socket stops calling into this object.
1039 ~RenegotiatingServer() {
1040 socket_->setReadCB(nullptr);
1043 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1044 LOG(INFO) << "Renegotiating server handshake success";
1045 socket_->setReadCB(this);
1049 const AsyncSocketException& ex) noexcept override {
1050 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1052 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1053 *lenReturn = sizeof(buf);
1056 void readDataAvailable(size_t /* len */) noexcept override {}
1057 void readEOF() noexcept override {}
1058 void readErr(const AsyncSocketException& ex) noexcept override {
1059 LOG(INFO) << "server got read error " << ex.what();
1060 auto exPtr = dynamic_cast<const SSLException*>(&ex);
1061 ASSERT_NE(nullptr, exPtr);
1062 std::string exStr(ex.what());
1063 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1064 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1065 renegotiationError_ = true;
1068 AsyncSSLSocket::UniquePtr socket_;
1069 unsigned char buf[128];
1070 bool renegotiationError_{false};
1073 #ifndef OPENSSL_NO_TLSEXT
// SNI client (the class header is not visible in this view). It starts the
// client handshake on construction and records
// socket_->isServerNameMatch() when the handshake succeeds.
1075 private AsyncSSLSocket::HandshakeCB,
1076 private AsyncTransportWrapper::WriteCallback {
1079 AsyncSSLSocket::UniquePtr socket)
1080 : serverNameMatch(false), socket_(std::move(socket)) {
1081 socket_->sslConn(this);
1084 bool serverNameMatch;
1087 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1088 serverNameMatch = socket_->isServerNameMatch();
1092 const AsyncSocketException& ex) noexcept override {
1093 ADD_FAILURE() << "client handshake error: " << ex.what();
1095 void writeSuccess() noexcept override {
1099 size_t bytesWritten,
1100 const AsyncSocketException& ex) noexcept override {
1101 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1105 AsyncSSLSocket::UniquePtr socket_;
// SNI server (the class header is not visible here). It registers
// serverNameCallback() on `ctx`. That callback switches the connection to
// sniCtx_ when the client's server name matches expectedServerName_
// (case-insensitive).
// NOTE(review): the callback stored in `ctx` is bound to `this`, so `ctx` must
// not invoke it after this object is destroyed.
1109 private AsyncSSLSocket::HandshakeCB,
1110 private AsyncTransportWrapper::ReadCallback {
1113 AsyncSSLSocket::UniquePtr socket,
1114 const std::shared_ptr<folly::SSLContext>& ctx,
1115 const std::shared_ptr<folly::SSLContext>& sniCtx,
1116 const std::string& expectedServerName)
1117 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1118 expectedServerName_(expectedServerName) {
1119 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1120 std::placeholders::_1));
1121 socket_->sslAccept(this);
1124 bool serverNameMatch;
1127 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1130 const AsyncSocketException& ex) noexcept override {
1131 ADD_FAILURE() << "server handshake error: " << ex.what();
1133 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1136 void readDataAvailable(size_t /* len */) noexcept override {}
1137 void readEOF() noexcept override {
1141 const AsyncSocketException& ex) noexcept override {
1142 ADD_FAILURE() << "server read error: " << ex.what();
// Server-name hook. On a match it swaps in sniCtx_ and returns
// SERVER_NAME_FOUND; otherwise it returns SERVER_NAME_NOT_FOUND.
// NOTE(review): any guard against a null `sn` (client sent no SNI) is on a
// line not visible here; strcasecmp must not be passed nullptr.
1145 folly::SSLContext::ServerNameCallbackResult
1146 serverNameCallback(SSL *ssl) {
1147 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1150 !strcasecmp(expectedServerName_.c_str(), sn)) {
1151 AsyncSSLSocket *sslSocket =
1152 AsyncSSLSocket::getFromSSL(ssl);
1153 sslSocket->switchServerSSLContext(sniCtx_);
1154 serverNameMatch = true;
1155 return folly::SSLContext::SERVER_NAME_FOUND;
1157 serverNameMatch = false;
1158 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1162 AsyncSSLSocket::UniquePtr socket_;
1163 std::shared_ptr<folly::SSLContext> sniCtx_;
1164 std::string expectedServerName_;
// Echo test client. It connects, writes buf_ and reads until the same bytes
// come back. It caches the SSL session between connections and counts
// session hits and misses, errors, and writes that failed after connect.
// NOTE(review): several members and method bodies are not visible in this
// view. Reconnection driven by requests_ is inferred from the visible code.
1168 class SSLClient : public AsyncSocket::ConnectCallback,
1169 public AsyncTransportWrapper::WriteCallback,
1170 public AsyncTransportWrapper::ReadCallback
1173 EventBase *eventBase_;
1174 std::shared_ptr<AsyncSSLSocket> sslSocket_;
// Session from the previous connection, offered for resumption on the next.
1175 SSL_SESSION *session_;
1176 std::shared_ptr<folly::SSLContext> ctx_;
1178 folly::SocketAddress address_;
1182 uint32_t bytesRead_;
1186 uint32_t writeAfterConnectErrors_;
1188 // These settings test that we eventually drain the
1189 // socket, even if the maxReadsPerEvent_ is hit during
1190 // a event loop iteration.
1191 static constexpr size_t kMaxReadsPerEvent = 2;
1192 // 2 event loop iterations
1193 static constexpr size_t kMaxReadBufferSz =
1194 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
// Session tickets are disabled, so resumption goes through the session
// (ID-based) path. buf_ is filled with 'a' as the payload to echo.
1197 SSLClient(EventBase *eventBase,
1198 const folly::SocketAddress& address,
1200 uint32_t timeout = 0)
1201 : eventBase_(eventBase),
1203 requests_(requests),
1210 writeAfterConnectErrors_(0) {
1211 ctx_.reset(new folly::SSLContext());
1212 ctx_->setOptions(SSL_OP_NO_TICKET);
1213 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1214 memset(buf_, 'a', sizeof(buf_));
1219 SSL_SESSION_free(session_);
1222 EXPECT_EQ(bytesRead_, sizeof(buf_));
1226 uint32_t getHit() const { return hit_; }
1228 uint32_t getMiss() const { return miss_; }
1230 uint32_t getErrors() const { return errors_; }
1232 uint32_t getWriteAfterConnectErrors() const {
1233 return writeAfterConnectErrors_;
// Creates a new AsyncSSLSocket, offers the cached session if there is one,
// and connects. With writeNow, junk is written immediately after connect()
// is issued (used by an error test).
1236 void connect(bool writeNow = false) {
1237 sslSocket_ = AsyncSSLSocket::newSocket(
1239 if (session_ != nullptr) {
1240 sslSocket_->setSSLSession(session_);
1243 sslSocket_->connect(this, address_, timeout_);
1244 if (sslSocket_ && writeNow) {
1245 // write some junk, used in an error test
1246 sslSocket_->write(this, buf_, sizeof(buf_));
// On connect, replaces the cached session with the current one, caps reads
// per event, sends buf_ and starts reading the echo. The session-reuse branch
// body (presumably hit counting) is not visible here.
1250 void connectSuccess() noexcept override {
1251 std::cerr << "client SSL socket connected" << std::endl;
1252 if (sslSocket_->getSSLSessionReused()) {
1256 if (session_ != nullptr) {
1257 SSL_SESSION_free(session_);
1259 session_ = sslSocket_->getSSLSession();
1263 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1264 sslSocket_->write(this, buf_, sizeof(buf_));
1265 sslSocket_->setReadCB(this);
1266 memset(readbuf_, 'b', sizeof(readbuf_));
1271 const AsyncSocketException& ex) noexcept override {
1272 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1277 void writeSuccess() noexcept override {
1278 std::cerr << "client write success" << std::endl;
1281 void writeErr(size_t /* bytesWritten */,
1282 const AsyncSocketException& ex) noexcept override {
1283 std::cerr << "client writeError: " << ex.what() << std::endl;
1285 writeAfterConnectErrors_++;
// Each read is capped at kMaxReadBufferSz so draining the echo takes several
// reads.
1289 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1290 *bufReturn = readbuf_ + bytesRead_;
1291 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1294 void readEOF() noexcept override {
1295 std::cerr << "client readEOF" << std::endl;
1299 const AsyncSocketException& ex) noexcept override {
1300 std::cerr << "client readError: " << ex.what() << std::endl;
// Once the whole payload has been echoed, verifies it and closes the socket.
// The requests_ branch body is not visible here.
1303 void readDataAvailable(size_t len) noexcept override {
1304 std::cerr << "client read data: " << len << std::endl;
1306 if (bytesRead_ == sizeof(buf_)) {
1307 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1308 sslSocket_->closeNow();
1310 if (requests_ != 0) {
// Common base for the handshake verification tests. It records whether the
// verify, success and error callbacks ran, and how long the handshake took.
// handshakeVer() checks that the preverify result passed in matches
// preverifyResult_ and returns verifyResult_ as the final verdict.
1318 class SSLHandshakeBase :
1319 public AsyncSSLSocket::HandshakeCB,
1320 private AsyncTransportWrapper::WriteCallback {
1322 explicit SSLHandshakeBase(
1323 AsyncSSLSocket::UniquePtr socket,
1324 bool preverifyResult,
1325 bool verifyResult) :
1326 handshakeVerify_(false),
1327 handshakeSuccess_(false),
1328 handshakeError_(false),
1329 socket_(std::move(socket)),
1330 preverifyResult_(preverifyResult),
1331 verifyResult_(verifyResult) {
// Hands the socket over to the caller; callable only on an rvalue.
1334 AsyncSSLSocket::UniquePtr moveSocket() && {
1335 return std::move(socket_);
1338 bool handshakeVerify_;
1339 bool handshakeSuccess_;
1340 bool handshakeError_;
1341 std::chrono::nanoseconds handshakeTime;
1344 AsyncSSLSocket::UniquePtr socket_;
1345 bool preverifyResult_;
1348 // HandshakeCallback
1349 bool handshakeVer(AsyncSSLSocket* /* sock */,
1351 X509_STORE_CTX* /* ctx */) noexcept override {
1352 handshakeVerify_ = true;
1354 EXPECT_EQ(preverifyResult_, preverifyOk);
1355 return verifyResult_;
1358 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1359 LOG(INFO) << "Handshake success";
1360 handshakeSuccess_ = true;
1361 handshakeTime = socket_->getHandshakeTime();
1366 const AsyncSocketException& ex) noexcept override {
1367 LOG(INFO) << "Handshake error " << ex.what();
1368 handshakeError_ = true;
1369 handshakeTime = socket_->getHandshakeTime();
1373 void writeSuccess() noexcept override {
1378 size_t bytesWritten,
1379 const AsyncSocketException& ex) noexcept override {
1380 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
// Client that starts the handshake immediately, with a zero timeout and the
// context's default verification settings.
1385 class SSLHandshakeClient : public SSLHandshakeBase {
1388 AsyncSSLSocket::UniquePtr socket,
1389 bool preverifyResult,
1390 bool verifyResult) :
1391 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1392 socket_->sslConn(this, std::chrono::milliseconds::zero());
// Client that starts the handshake with peer verification explicitly
// disabled (NO_VERIFY).
1396 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1398 SSLHandshakeClientNoVerify(
1399 AsyncSSLSocket::UniquePtr socket,
1400 bool preverifyResult,
1401 bool verifyResult) :
1402 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1405 std::chrono::milliseconds::zero(),
1406 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Client that starts the handshake with peer verification explicitly
// enabled (VERIFY).
1410 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1412 SSLHandshakeClientDoVerify(
1413 AsyncSSLSocket::UniquePtr socket,
1414 bool preverifyResult,
1415 bool verifyResult) :
1416 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1419 std::chrono::milliseconds::zero(),
1420 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
// Server that accepts the handshake immediately, with a zero timeout and the
// context's default verification settings.
1424 class SSLHandshakeServer : public SSLHandshakeBase {
1427 AsyncSSLSocket::UniquePtr socket,
1428 bool preverifyResult,
1430 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1431 socket_->sslAccept(this, std::chrono::milliseconds::zero());
// Server that enables ClientHello parsing before accepting. On success it
// captures the client, server and shared cipher lists plus the negotiated
// cipher name, so tests can compare them.
1435 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1437 SSLHandshakeServerParseClientHello(
1438 AsyncSSLSocket::UniquePtr socket,
1439 bool preverifyResult,
1441 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1442 socket_->enableClientHelloParsing();
1443 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1446 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1449 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1450 handshakeSuccess_ = true;
1451 sock->getSSLSharedCiphers(sharedCiphers_);
1452 sock->getSSLServerCiphers(serverCiphers_);
1453 sock->getSSLClientCiphers(clientCiphers_);
1454 chosenCipher_ = sock->getNegotiatedCipherName();
// Server that accepts the handshake with peer verification explicitly
// disabled (NO_VERIFY).
1459 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1461 SSLHandshakeServerNoVerify(
1462 AsyncSSLSocket::UniquePtr socket,
1463 bool preverifyResult,
1465 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1468 std::chrono::milliseconds::zero(),
1469 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Server that accepts the handshake and requires a client certificate
// (VERIFY_REQ_CLIENT_CERT).
1473 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1475 SSLHandshakeServerDoVerify(
1476 AsyncSSLSocket::UniquePtr socket,
1477 bool preverifyResult,
1479 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1482 std::chrono::milliseconds::zero(),
1483 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// Test watchdog. It schedules itself for timeoutMS on construction; if the
// timeout expires, the test is failed and the event loop is asked to stop.
1487 class EventBaseAborter : public AsyncTimeout {
1489 EventBaseAborter(EventBase* eventBase,
1492 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1493 , eventBase_(eventBase) {
1494 scheduleTimeout(timeoutMS);
// NOTE(review): FAIL() is a fatal gtest assertion that returns from the
// current function, so terminateLoopSoon() below is never reached and the
// loop keeps running. Using ADD_FAILURE() before terminateLoopSoon() would
// record the failure and still stop the loop.
1497 void timeoutExpired() noexcept override {
1498 FAIL() << "test timed out";
1499 eventBase_->terminateLoopSoon();
1503 EventBase* eventBase_;