2 * Copyright 2017 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
37 #include <sys/types.h>
38 #include <condition_variable>
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both positive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
// SendMsgParamsCallback used by the tests as a pass-through wrapper: once
// attached to a socket it remembers the callback that was installed before it
// and forwards every query to that previous callback. Subclasses override
// individual queries to inject their own flags or ancillary data.
49 class SendMsgParamsCallbackBase :
50 public folly::AsyncSocket::SendMsgParamsCallback {
52 SendMsgParamsCallbackBase() {}
// Attach to a socket: save the callback currently installed on it, then
// install this object in its place.
55 const std::shared_ptr<AsyncSSLSocket> &socket) {
57 oldCallback_ = socket_->getSendMsgParamsCB();
58 socket_->setSendMsgParamCB(this);
// Flags come from the previously installed callback; defaultFlags is unused.
61 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
63 return oldCallback_->getFlags(flags);
// Ancillary data is produced by the previously installed callback.
66 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67 oldCallback_->getAncillaryData(flags, data);
// Ancillary data size is reported by the previously installed callback.
70 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71 return oldCallback_->getAncillaryDataSize(flags);
// Socket this callback is attached to.
74 std::shared_ptr<AsyncSSLSocket> socket_;
// Callback that was installed on the socket before this one. It starts as
// nullptr and the forwarding methods above dereference it, so they rely on
// the socket having been attached first.
75 folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
80 SendMsgFlagsCallback() {}
82 void resetFlags(int flags) {
86 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
91 return oldCallback_->getFlags(flags);
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
100 SendMsgDataCallback() {}
102 void resetData(std::vector<char>&& data) {
103 ancillaryData_.swap(data);
106 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107 if (ancillaryData_.size()) {
108 std::cerr << "getAncillaryData: copying data" << std::endl;
109 memcpy(data, ancillaryData_.data(), ancillaryData_.size());
111 oldCallback_->getAncillaryData(flags, data);
115 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116 if (ancillaryData_.size()) {
117 std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118 return ancillaryData_.size();
120 return oldCallback_->getAncillaryDataSize(flags);
124 std::vector<char> ancillaryData_;
// Write callback that records the outcome of a write in `state`. The
// destructor asserts that the final state is STATE_SUCCEEDED, so a test that
// expects a failure must reset the state before this object is destroyed.
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
// mcb: optional send-msg-params callback that setSocket() attaches to the
// same socket.
130 explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131 : state(STATE_WAITING)
133 , exception(AsyncSocketException::UNKNOWN, "none")
// Fails the enclosing test unless the callback ended in STATE_SUCCEEDED.
136 ~WriteCallbackBase() override {
137 EXPECT_EQ(STATE_SUCCEEDED, state);
// Associates this callback with a socket and forwards the socket to mcb_.
140 virtual void setSocket(
141 const std::shared_ptr<AsyncSSLSocket> &socket) {
144 mcb_->setSocket(socket);
// Marks the write as succeeded.
148 void writeSuccess() noexcept override {
149 std::cerr << "writeSuccess" << std::endl;
150 state = STATE_SUCCEEDED;
// Write failure path: logs the error, marks the write as failed and records
// how many bytes had been written.
154 size_t nBytesWritten,
155 const AsyncSocketException& ex) noexcept override {
156 std::cerr << "writeError: bytesWritten " << nBytesWritten
157 << ", exception " << ex.what() << std::endl;
159 state = STATE_FAILED;
160 this->bytesWritten = nBytesWritten;
// Socket this callback is associated with.
165 std::shared_ptr<AsyncSSLSocket> socket_;
// Constructed as (UNKNOWN, "none"); ExpectWriteErrorCallback inspects it after
// a failed write.
168 AsyncSocketException exception;
// Optional send-msg-params callback supplied at construction; may be nullptr.
169 SendMsgParamsCallbackBase* mcb_;
// Write callback for negative tests: the write is expected to fail.
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
175 explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176 : WriteCallbackBase(mcb) {}
// Verifies that the write failed with a NETWORK_ERROR carrying errno 22
// (presumably EINVAL on Linux -- confirm), then flips the state to
// STATE_SUCCEEDED so the base-class destructor assertion passes.
178 ~ExpectWriteErrorCallback() override {
179 EXPECT_EQ(STATE_FAILED, state);
180 EXPECT_EQ(exception.type_,
181 AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
182 EXPECT_EQ(exception.errno_, 22);
183 // Suppress the assert in ~WriteCallbackBase()
184 state = STATE_SUCCEEDED;
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
191 enum SOF_TIMESTAMPING {
192 SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193 SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194 SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195 SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196 SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197 SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
200 class WriteCheckTimestampCallback :
201 public WriteCallbackBase {
203 explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204 : WriteCallbackBase(mcb) {}
206 ~WriteCheckTimestampCallback() override {
207 EXPECT_EQ(STATE_SUCCEEDED, state);
208 EXPECT_TRUE(gotTimestamp_);
209 EXPECT_TRUE(gotByteSeq_);
213 const std::shared_ptr<AsyncSSLSocket> &socket) override {
214 WriteCallbackBase::setSocket(socket);
216 EXPECT_NE(socket_->getFd(), 0);
217 int flags = SOF_TIMESTAMPING_OPT_ID
218 | SOF_TIMESTAMPING_OPT_TSONLY
219 | SOF_TIMESTAMPING_SOFTWARE;
220 AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221 int ret = tstampingOpt.apply(socket_->getFd(), flags);
225 void checkForTimestampNotifications() noexcept {
226 int fd = socket_->getFd();
227 std::vector<char> ctrl(1024, 0);
232 memset(&msg, 0, sizeof(msg));
233 entry.iov_base = &data;
234 entry.iov_len = sizeof(data);
235 msg.msg_iov = &entry;
237 msg.msg_control = ctrl.data();
238 msg.msg_controllen = ctrl.size();
242 ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
244 if (errno != EAGAIN) {
245 auto errnoCopy = errno;
246 std::cerr << "::recvmsg exited with code " << ret
247 << ", errno: " << errnoCopy << std::endl;
248 AsyncSocketException ex(
249 AsyncSocketException::INTERNAL_ERROR,
257 for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258 cmsg != nullptr && cmsg->cmsg_len != 0;
259 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260 if (cmsg->cmsg_level == SOL_SOCKET &&
261 cmsg->cmsg_type == SCM_TIMESTAMPING) {
262 gotTimestamp_ = true;
266 if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267 (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
275 bool gotTimestamp_{false};
276 bool gotByteSeq_{false};
278 #endif // MSG_ERRQUEUE
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
283 explicit ReadCallbackBase(WriteCallbackBase* wcb)
284 : wcb_(wcb), state(STATE_WAITING) {}
286 ~ReadCallbackBase() override {
287 EXPECT_EQ(STATE_SUCCEEDED, state);
291 const std::shared_ptr<AsyncSSLSocket> &socket) {
295 void setState(StateEnum s) {
303 const AsyncSocketException& ex) noexcept override {
304 std::cerr << "readError " << ex.what() << std::endl;
305 state = STATE_FAILED;
309 void readEOF() noexcept override {
310 std::cerr << "readEOF" << std::endl;
315 std::shared_ptr<AsyncSSLSocket> socket_;
316 WriteCallbackBase *wcb_;
320 class ReadCallback : public ReadCallbackBase {
322 explicit ReadCallback(WriteCallbackBase *wcb)
323 : ReadCallbackBase(wcb)
326 ~ReadCallback() override {
327 for (std::vector<Buffer>::iterator it = buffers.begin();
332 currentBuffer.free();
335 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336 if (!currentBuffer.buffer) {
337 currentBuffer.allocate(4096);
339 *bufReturn = currentBuffer.buffer;
340 *lenReturn = currentBuffer.length;
343 void readDataAvailable(size_t len) noexcept override {
344 std::cerr << "readDataAvailable, len " << len << std::endl;
346 currentBuffer.length = len;
348 wcb_->setSocket(socket_);
350 // Write back the same data.
351 socket_->write(wcb_, currentBuffer.buffer, len);
353 buffers.push_back(currentBuffer);
354 currentBuffer.reset();
355 state = STATE_SUCCEEDED;
360 Buffer() : buffer(nullptr), length(0) {}
361 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
367 void allocate(size_t len) {
368 assert(buffer == nullptr);
369 this->buffer = static_cast<char*>(malloc(len));
381 std::vector<Buffer> buffers;
382 Buffer currentBuffer;
385 class ReadErrorCallback : public ReadCallbackBase {
387 explicit ReadErrorCallback(WriteCallbackBase *wcb)
388 : ReadCallbackBase(wcb) {}
390 // Return nullptr buffer to trigger readError()
391 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392 *bufReturn = nullptr;
396 void readDataAvailable(size_t /* len */) noexcept override {
397 // This should never be called.
402 const AsyncSocketException& ex) noexcept override {
403 ReadCallbackBase::readErr(ex);
404 std::cerr << "ReadErrorCallback::readError" << std::endl;
405 setState(STATE_SUCCEEDED);
409 class ReadEOFCallback : public ReadCallbackBase {
411 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
413 // Return nullptr buffer to trigger readError()
414 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415 *bufReturn = nullptr;
419 void readDataAvailable(size_t /* len */) noexcept override {
420 // This should never be called.
424 void readEOF() noexcept override {
425 ReadCallbackBase::readEOF();
426 setState(STATE_SUCCEEDED);
430 class WriteErrorCallback : public ReadCallback {
432 explicit WriteErrorCallback(WriteCallbackBase *wcb)
433 : ReadCallback(wcb) {}
435 void readDataAvailable(size_t len) noexcept override {
436 std::cerr << "readDataAvailable, len " << len << std::endl;
438 currentBuffer.length = len;
440 // close the socket before writing to trigger writeError().
441 ::close(socket_->getFd());
443 wcb_->setSocket(socket_);
445 // Write back the same data.
446 folly::test::msvcSuppressAbortOnInvalidParams([&] {
447 socket_->write(wcb_, currentBuffer.buffer, len);
450 if (wcb_->state == STATE_FAILED) {
451 setState(STATE_SUCCEEDED);
453 state = STATE_FAILED;
456 buffers.push_back(currentBuffer);
457 currentBuffer.reset();
460 void readErr(const AsyncSocketException& ex) noexcept override {
461 std::cerr << "readError " << ex.what() << std::endl;
462 // do nothing since this is expected
466 class EmptyReadCallback : public ReadCallback {
468 explicit EmptyReadCallback()
469 : ReadCallback(nullptr) {}
471 void readErr(const AsyncSocketException& ex) noexcept override {
472 std::cerr << "readError " << ex.what() << std::endl;
473 state = STATE_FAILED;
479 void readEOF() noexcept override {
480 std::cerr << "readEOF" << std::endl;
484 state = STATE_SUCCEEDED;
487 std::shared_ptr<AsyncSocket> tcpSocket_;
// Server-side handshake callback. It records whether the handshake outcome
// matched the expectation given at construction (`expect`), and on success it
// installs the supplied read callback on the socket. The destructor asserts
// that the final state is STATE_SUCCEEDED.
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
// rcb: read callback to install once the handshake succeeds.
// expect: whether the handshake is expected to succeed or to fail.
498 explicit HandshakeCallback(ReadCallbackBase *rcb,
499 ExpectType expect = EXPECT_SUCCESS):
500 state(STATE_WAITING),
505 const std::shared_ptr<AsyncSSLSocket> &socket) {
509 void setState(StateEnum s) {
514 // Functions inherited from AsyncSSLSocket::HandshakeCB
// Success path: hands the socket to the read callback, starts reading, and
// marks the state SUCCEEDED only if success was expected.
515 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516 std::lock_guard<std::mutex> g(mutex_);
518 EXPECT_EQ(sock, socket_.get());
519 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520 rcb_->setSocket(socket_);
521 sock->setReadCB(rcb_);
522 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
// Error path: marks the state SUCCEEDED only if an error was expected, and
// keeps the error text in errorString_.
524 void handshakeErr(AsyncSSLSocket* /* sock */,
525 const AsyncSocketException& ex) noexcept override {
526 std::lock_guard<std::mutex> g(mutex_);
528 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530 if (expect_ == EXPECT_ERROR) {
531 // rcb will never be invoked
532 rcb_->setState(STATE_SUCCEEDED);
534 errorString_ = ex.what();
// Blocks the caller until the state has left STATE_WAITING.
537 void waitForHandshake() {
538 std::unique_lock<std::mutex> lock(mutex_);
539 cv_.wait(lock, [this] { return state != STATE_WAITING; });
// Fails the enclosing test unless the callback ended in STATE_SUCCEEDED.
542 ~HandshakeCallback() override {
543 EXPECT_EQ(STATE_SUCCEEDED, state);
548 state = STATE_SUCCEEDED;
551 std::shared_ptr<AsyncSSLSocket> getSocket() {
// Socket whose handshake is being observed.
556 std::shared_ptr<AsyncSSLSocket> socket_;
// Read callback installed on the socket after a successful handshake.
557 ReadCallbackBase *rcb_;
// Used by waitForHandshake() together with mutex_.
560 std::condition_variable cv_;
// what() text of the handshake error, if one was reported.
561 std::string errorString_;
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
568 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569 uint32_t timeout = 0):
570 SSLServerAcceptCallbackBase(hcb),
573 ~SSLServerAcceptCallback() override {
575 // if we set a timeout, we expect failure
576 EXPECT_EQ(hcb_->state, STATE_FAILED);
577 hcb_->setState(STATE_SUCCEEDED);
581 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
583 const std::shared_ptr<folly::AsyncSSLSocket> &s)
585 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
588 hcb_->setSocket(sock);
589 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590 EXPECT_EQ(sock->getSSLState(),
591 AsyncSSLSocket::STATE_ACCEPTING);
593 state = STATE_SUCCEEDED;
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
599 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600 SSLServerAcceptCallback(hcb) {}
602 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
604 const std::shared_ptr<folly::AsyncSSLSocket> &s)
607 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
609 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
611 int fd = sock->getFd();
615 // The accepted connection should already have TCP_NODELAY set
617 socklen_t valueLength = sizeof(value);
618 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
624 // Unset the TCP_NODELAY option.
626 socklen_t valueLength = sizeof(value);
627 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
630 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
634 SSLServerAcceptCallback::connAccepted(sock);
// Accept callback for the async session-cache tests. It behaves like
// SSLServerAcceptCallback, except that right after sslAccept() the socket may
// legitimately be in STATE_CACHE_LOOKUP instead of STATE_ACCEPTING.
638 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
// hcb: handshake callback passed to sslAccept().
// timeout: forwarded to SSLServerAcceptCallback.
640 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
641 uint32_t timeout = 0):
642 SSLServerAcceptCallback(hcb, timeout) {}
644 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
646 const std::shared_ptr<folly::AsyncSSLSocket> &s)
648 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
// Log under this class's own name (the message previously named the parent
// class), so test output shows which accept callback actually ran.
650 std::cerr << "SSLServerAsyncCacheAcceptCallback::connAccepted" << std::endl;
652 hcb_->setSocket(sock);
653 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
654 ASSERT_TRUE((sock->getSSLState() ==
655 AsyncSSLSocket::STATE_ACCEPTING) ||
656 (sock->getSSLState() ==
657 AsyncSSLSocket::STATE_CACHE_LOOKUP));
659 state = STATE_SUCCEEDED;
664 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
666 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
667 SSLServerAcceptCallbackBase(hcb) {}
669 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
671 const std::shared_ptr<folly::AsyncSSLSocket> &s)
673 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
675 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
677 // The first call to sslAccept() should succeed.
678 hcb_->setSocket(sock);
679 sock->sslAccept(hcb_);
680 EXPECT_EQ(sock->getSSLState(),
681 AsyncSSLSocket::STATE_ACCEPTING);
683 // The second call to sslAccept() should fail.
684 HandshakeCallback callback2(hcb_->rcb_);
685 callback2.setSocket(sock);
686 sock->sslAccept(&callback2);
687 EXPECT_EQ(sock->getSSLState(),
688 AsyncSSLSocket::STATE_ERROR);
690 // Both callbacks should be in the error state.
691 EXPECT_EQ(hcb_->state, STATE_FAILED);
692 EXPECT_EQ(callback2.state, STATE_FAILED);
694 state = STATE_SUCCEEDED;
695 hcb_->setState(STATE_SUCCEEDED);
696 callback2.setState(STATE_SUCCEEDED);
// Accept callback that deliberately postpones sslAccept(): the accept is
// scheduled on the socket's event base with a delay so that, per the log
// message below, the client has already closed by the time it runs.
700 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
// hcb: handshake callback passed to the delayed sslAccept().
702 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
703 SSLServerAcceptCallbackBase(hcb) {}
705 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
707 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// Log under this class's own name (the message previously named
// HandshakeErrorCallback), so test output shows which callback actually ran.
709 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl;
711 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
713 hcb_->setSocket(sock);
714 sock->getEventBase()->tryRunAfterDelay([=] {
715 std::cerr << "Delayed SSL accept, client will have close by now"
717 // SSL accept will fail
720 AsyncSSLSocket::STATE_UNINIT);
721 hcb_->socket_->sslAccept(hcb_);
722 // This registers for an event
725 AsyncSSLSocket::STATE_ACCEPTING);
727 state = STATE_SUCCEEDED;
// Accept callback for client connect-timeout tests: it never starts an SSL
// handshake, it only closes the accepted socket after a delay.
732 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
// No handshake callback is used (nullptr). The state is pre-set to
// STATE_SUCCEEDED because connAccepted() may legitimately never run.
734 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
735 // We don't care if we get invoked or not.
736 // The client may time out and give up before connAccepted() is even
738 state = STATE_SUCCEEDED;
741 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
743 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
744 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
746 // Just wait a while before closing the socket, so the client
747 // will time out waiting for the handshake to complete.
// The lambda captures `s` by copy, so it holds its own shared_ptr to the
// socket until it runs. The delay argument is 100 (presumably milliseconds
// -- confirm against EventBase::tryRunAfterDelay).
748 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
752 class TestSSLAsyncCacheServer : public TestSSLServer {
754 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
755 int lookupDelay = 100) :
757 SSL_CTX *sslCtx = ctx_->getSSLCtx();
758 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
759 SSL_CTX_sess_set_get_cb(sslCtx,
760 TestSSLAsyncCacheServer::getSessionCallback);
762 SSL_CTX_set_session_cache_mode(
763 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
766 lookupDelay_ = lookupDelay;
769 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
770 uint32_t getAsyncLookups() const { return asyncLookups_; }
773 static uint32_t asyncCallbacks_;
774 static uint32_t asyncLookups_;
775 static uint32_t lookupDelay_;
777 static SSL_SESSION* getSessionCallback(SSL* ssl,
778 unsigned char* /* sess_id */,
784 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
785 if (!SSL_want_sess_cache_lookup(ssl)) {
786 // libssl.so mismatch
787 std::cerr << "no async support" << std::endl;
791 AsyncSSLSocket *sslSocket =
792 AsyncSSLSocket::getFromSSL(ssl);
793 assert(sslSocket != nullptr);
794 // Simulate an async cache by delaying the miss by lookupDelay_ ms (100 by default)
795 if (asyncCallbacks_ % 2 == 0) {
796 // This socket is already blocked on lookup, return miss
797 std::cerr << "returning miss" << std::endl;
799 // fresh meat - block it
800 std::cerr << "async lookup" << std::endl;
801 sslSocket->getEventBase()->tryRunAfterDelay(
802 std::bind(&AsyncSSLSocket::restartSSLAccept,
803 sslSocket), lookupDelay_);
804 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
812 void getfds(int fds[2]);
815 std::shared_ptr<folly::SSLContext> clientCtx,
816 std::shared_ptr<folly::SSLContext> serverCtx);
819 EventBase* eventBase,
820 AsyncSSLSocket::UniquePtr* clientSock,
821 AsyncSSLSocket::UniquePtr* serverSock);
823 class BlockingWriteClient :
824 private AsyncSSLSocket::HandshakeCB,
825 private AsyncTransportWrapper::WriteCallback {
// Takes ownership of the socket, allocates a bufLen_-byte payload buffer and
// an array of iovCount_ iovecs pointing into it at increasing offsets, then
// starts the client-side SSL handshake.
827 explicit BlockingWriteClient(
828 AsyncSSLSocket::UniquePtr socket)
829 : socket_(std::move(socket)),
833 buf_.reset(new uint8_t[bufLen_]);
// Fill the whole buffer. buf_ is a std::unique_ptr, so sizeof(buf_) is the
// size of the pointer object rather than of the bufLen_-byte array it owns.
// Bounding the loop by sizeof(buf_) left most of the payload uninitialized.
834 for (uint32_t n = 0; n < bufLen_; ++n) {
839 iov_.reset(new struct iovec[iovCount_]);
840 for (uint32_t n = 0; n < iovCount_; ++n) {
841 iov_[n].iov_base = buf_.get() + n;
843 iov_[n].iov_len = n % bufLen_;
845 iov_[n].iov_len = bufLen_ - (n % bufLen_);
// Start the handshake with a std::chrono::milliseconds(100) timeout argument.
849 socket_->sslConn(this, std::chrono::milliseconds(100));
852 struct iovec* getIovec() const {
855 uint32_t getIovecCount() const {
860 void handshakeSuc(AsyncSSLSocket*) noexcept override {
861 socket_->writev(this, iov_.get(), iovCount_);
865 const AsyncSocketException& ex) noexcept override {
866 ADD_FAILURE() << "client handshake error: " << ex.what();
868 void writeSuccess() noexcept override {
873 const AsyncSocketException& ex) noexcept override {
874 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
878 AsyncSSLSocket::UniquePtr socket_;
881 std::unique_ptr<uint8_t[]> buf_;
882 std::unique_ptr<struct iovec[]> iov_;
885 class BlockingWriteServer :
886 private AsyncSSLSocket::HandshakeCB,
887 private AsyncTransportWrapper::ReadCallback {
889 explicit BlockingWriteServer(
890 AsyncSSLSocket::UniquePtr socket)
891 : socket_(std::move(socket)),
892 bufSize_(2500 * 2000),
894 buf_.reset(new uint8_t[bufSize_]);
895 socket_->sslAccept(this, std::chrono::milliseconds(100));
898 void checkBuffer(struct iovec* iov, uint32_t count) const {
900 for (uint32_t n = 0; n < count; ++n) {
901 size_t bytesLeft = bytesRead_ - idx;
902 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
903 std::min(iov[n].iov_len, bytesLeft));
905 FAIL() << "buffer mismatch at iovec " << n << "/" << count
909 if (iov[n].iov_len > bytesLeft) {
910 FAIL() << "server did not read enough data: "
911 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
912 << " in iovec " << n << "/" << count;
915 idx += iov[n].iov_len;
917 if (idx != bytesRead_) {
918 ADD_FAILURE() << "server read extra data: " << bytesRead_
919 << " bytes read; expected " << idx;
924 void handshakeSuc(AsyncSSLSocket*) noexcept override {
925 // Wait 10ms before reading, so the client's writes will initially block.
926 socket_->getEventBase()->tryRunAfterDelay(
927 [this] { socket_->setReadCB(this); }, 10);
931 const AsyncSocketException& ex) noexcept override {
932 ADD_FAILURE() << "server handshake error: " << ex.what();
934 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
935 *bufReturn = buf_.get() + bytesRead_;
936 *lenReturn = bufSize_ - bytesRead_;
938 void readDataAvailable(size_t len) noexcept override {
940 socket_->setReadCB(nullptr);
941 socket_->getEventBase()->tryRunAfterDelay(
942 [this] { socket_->setReadCB(this); }, 2);
944 void readEOF() noexcept override {
948 const AsyncSocketException& ex) noexcept override {
949 ADD_FAILURE() << "server read error: " << ex.what();
952 AsyncSSLSocket::UniquePtr socket_;
955 std::unique_ptr<uint8_t[]> buf_;
959 private AsyncSSLSocket::HandshakeCB,
960 private AsyncTransportWrapper::WriteCallback {
963 AsyncSSLSocket::UniquePtr socket)
964 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
965 socket_->sslConn(this);
968 const unsigned char* nextProto;
969 unsigned nextProtoLength;
970 SSLContext::NextProtocolType protocolType;
971 folly::Optional<AsyncSocketException> except;
974 void handshakeSuc(AsyncSSLSocket*) noexcept override {
975 socket_->getSelectedNextProtocol(
976 &nextProto, &nextProtoLength, &protocolType);
980 const AsyncSocketException& ex) noexcept override {
983 void writeSuccess() noexcept override {
988 const AsyncSocketException& ex) noexcept override {
989 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
993 AsyncSSLSocket::UniquePtr socket_;
997 private AsyncSSLSocket::HandshakeCB,
998 private AsyncTransportWrapper::ReadCallback {
1000 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
1001 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
1002 socket_->sslAccept(this);
1005 const unsigned char* nextProto;
1006 unsigned nextProtoLength;
1007 SSLContext::NextProtocolType protocolType;
1008 folly::Optional<AsyncSocketException> except;
1011 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1012 socket_->getSelectedNextProtocol(
1013 &nextProto, &nextProtoLength, &protocolType);
1017 const AsyncSocketException& ex) noexcept override {
1020 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1023 void readDataAvailable(size_t /* len */) noexcept override {}
1024 void readEOF() noexcept override {
1028 const AsyncSocketException& ex) noexcept override {
1029 ADD_FAILURE() << "server read error: " << ex.what();
1032 AsyncSSLSocket::UniquePtr socket_;
// Server endpoint for the client-renegotiation test. It accepts an SSL
// connection, starts reading after the handshake, and sets
// renegotiationError_ when a read error arrives that is an SSLException whose
// text contains the CLIENT_RENEGOTIATION error message.
1035 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1036 public AsyncTransportWrapper::ReadCallback {
// Takes ownership of the socket and immediately starts the server handshake.
1038 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1039 : socket_(std::move(socket)) {
1040 socket_->sslAccept(this);
// Detach this object as read callback before it is destroyed.
1043 ~RenegotiatingServer() override {
1044 socket_->setReadCB(nullptr);
// On handshake success, start receiving read events.
1047 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1048 LOG(INFO) << "Renegotiating server handshake success";
1049 socket_->setReadCB(this);
// A handshake error is a test failure.
1053 const AsyncSocketException& ex) noexcept override {
1054 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
// Offers the fixed-size member buffer for incoming data.
1056 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1057 *lenReturn = sizeof(buf);
// Received data and EOF are ignored; only read errors matter to this test.
1060 void readDataAvailable(size_t /* len */) noexcept override {}
1061 void readEOF() noexcept override {}
// The read error must be an SSLException whose message contains the text of
// SSLError::CLIENT_RENEGOTIATION; only then is renegotiationError_ set.
// ASSERT_NE returns from this function on failure, leaving the flag false.
1062 void readErr(const AsyncSocketException& ex) noexcept override {
1063 LOG(INFO) << "server got read error " << ex.what();
1064 auto exPtr = dynamic_cast<const SSLException*>(&ex);
1065 ASSERT_NE(nullptr, exPtr);
1066 std::string exStr(ex.what());
1067 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1068 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1069 renegotiationError_ = true;
// Owned server-side socket.
1072 AsyncSSLSocket::UniquePtr socket_;
// Read buffer handed out by getReadBuffer().
1073 unsigned char buf[128];
// True once the expected client-renegotiation read error has been seen.
1074 bool renegotiationError_{false};
1077 #ifndef OPENSSL_NO_TLSEXT
1079 private AsyncSSLSocket::HandshakeCB,
1080 private AsyncTransportWrapper::WriteCallback {
1083 AsyncSSLSocket::UniquePtr socket)
1084 : serverNameMatch(false), socket_(std::move(socket)) {
1085 socket_->sslConn(this);
1088 bool serverNameMatch;
1091 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1092 serverNameMatch = socket_->isServerNameMatch();
1096 const AsyncSocketException& ex) noexcept override {
1097 ADD_FAILURE() << "client handshake error: " << ex.what();
1099 void writeSuccess() noexcept override {
1103 size_t bytesWritten,
1104 const AsyncSocketException& ex) noexcept override {
1105 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1109 AsyncSSLSocket::UniquePtr socket_;
1113 private AsyncSSLSocket::HandshakeCB,
1114 private AsyncTransportWrapper::ReadCallback {
1117 AsyncSSLSocket::UniquePtr socket,
1118 const std::shared_ptr<folly::SSLContext>& ctx,
1119 const std::shared_ptr<folly::SSLContext>& sniCtx,
1120 const std::string& expectedServerName)
1121 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1122 expectedServerName_(expectedServerName) {
1123 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1124 std::placeholders::_1));
1125 socket_->sslAccept(this);
1128 bool serverNameMatch;
1131 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1134 const AsyncSocketException& ex) noexcept override {
1135 ADD_FAILURE() << "server handshake error: " << ex.what();
1137 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1140 void readDataAvailable(size_t /* len */) noexcept override {}
1141 void readEOF() noexcept override {
1145 const AsyncSocketException& ex) noexcept override {
1146 ADD_FAILURE() << "server read error: " << ex.what();
1149 folly::SSLContext::ServerNameCallbackResult
1150 serverNameCallback(SSL *ssl) {
1151 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1154 !strcasecmp(expectedServerName_.c_str(), sn)) {
1155 AsyncSSLSocket *sslSocket =
1156 AsyncSSLSocket::getFromSSL(ssl);
1157 sslSocket->switchServerSSLContext(sniCtx_);
1158 serverNameMatch = true;
1159 return folly::SSLContext::SERVER_NAME_FOUND;
1161 serverNameMatch = false;
1162 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1166 AsyncSSLSocket::UniquePtr socket_;
1167 std::shared_ptr<folly::SSLContext> sniCtx_;
1168 std::string expectedServerName_;
1172 class SSLClient : public AsyncSocket::ConnectCallback,
1173 public AsyncTransportWrapper::WriteCallback,
1174 public AsyncTransportWrapper::ReadCallback
1177 EventBase *eventBase_;
1178 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1179 SSL_SESSION *session_;
1180 std::shared_ptr<folly::SSLContext> ctx_;
1182 folly::SocketAddress address_;
1186 uint32_t bytesRead_;
1190 uint32_t writeAfterConnectErrors_;
1192 // These settings test that we eventually drain the
1193 // socket, even if the maxReadsPerEvent_ is hit during
1194 // an event loop iteration.
1195 static constexpr size_t kMaxReadsPerEvent = 2;
1196 // 2 event loop iterations
1197 static constexpr size_t kMaxReadBufferSz =
1198 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1201 SSLClient(EventBase *eventBase,
1202 const folly::SocketAddress& address,
1204 uint32_t timeout = 0)
1205 : eventBase_(eventBase),
1207 requests_(requests),
1214 writeAfterConnectErrors_(0) {
1215 ctx_.reset(new folly::SSLContext());
1216 ctx_->setOptions(SSL_OP_NO_TICKET);
1217 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1218 memset(buf_, 'a', sizeof(buf_));
1221 ~SSLClient() override {
1223 SSL_SESSION_free(session_);
1226 EXPECT_EQ(bytesRead_, sizeof(buf_));
1230 uint32_t getHit() const { return hit_; }
1232 uint32_t getMiss() const { return miss_; }
1234 uint32_t getErrors() const { return errors_; }
1236 uint32_t getWriteAfterConnectErrors() const {
1237 return writeAfterConnectErrors_;
1240 void connect(bool writeNow = false) {
1241 sslSocket_ = AsyncSSLSocket::newSocket(
1243 if (session_ != nullptr) {
1244 sslSocket_->setSSLSession(session_);
1247 sslSocket_->connect(this, address_, timeout_);
1248 if (sslSocket_ && writeNow) {
1249 // write some junk, used in an error test
1250 sslSocket_->write(this, buf_, sizeof(buf_));
1254 void connectSuccess() noexcept override {
1255 std::cerr << "client SSL socket connected" << std::endl;
1256 if (sslSocket_->getSSLSessionReused()) {
1260 if (session_ != nullptr) {
1261 SSL_SESSION_free(session_);
1263 session_ = sslSocket_->getSSLSession();
1267 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1268 sslSocket_->write(this, buf_, sizeof(buf_));
1269 sslSocket_->setReadCB(this);
1270 memset(readbuf_, 'b', sizeof(readbuf_));
1275 const AsyncSocketException& ex) noexcept override {
1276 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1281 void writeSuccess() noexcept override {
1282 std::cerr << "client write success" << std::endl;
1285 void writeErr(size_t /* bytesWritten */,
1286 const AsyncSocketException& ex) noexcept override {
1287 std::cerr << "client writeError: " << ex.what() << std::endl;
1289 writeAfterConnectErrors_++;
1293 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1294 *bufReturn = readbuf_ + bytesRead_;
1295 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1298 void readEOF() noexcept override {
1299 std::cerr << "client readEOF" << std::endl;
1303 const AsyncSocketException& ex) noexcept override {
1304 std::cerr << "client readError: " << ex.what() << std::endl;
1307 void readDataAvailable(size_t len) noexcept override {
1308 std::cerr << "client read data: " << len << std::endl;
1310 if (bytesRead_ == sizeof(buf_)) {
1311 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1312 sslSocket_->closeNow();
1314 if (requests_ != 0) {
1322 class SSLHandshakeBase :
1323 public AsyncSSLSocket::HandshakeCB,
1324 private AsyncTransportWrapper::WriteCallback {
1326 explicit SSLHandshakeBase(
1327 AsyncSSLSocket::UniquePtr socket,
1328 bool preverifyResult,
1329 bool verifyResult) :
1330 handshakeVerify_(false),
1331 handshakeSuccess_(false),
1332 handshakeError_(false),
1333 socket_(std::move(socket)),
1334 preverifyResult_(preverifyResult),
1335 verifyResult_(verifyResult) {
1338 AsyncSSLSocket::UniquePtr moveSocket() && {
1339 return std::move(socket_);
1342 bool handshakeVerify_;
1343 bool handshakeSuccess_;
1344 bool handshakeError_;
1345 std::chrono::nanoseconds handshakeTime;
1348 AsyncSSLSocket::UniquePtr socket_;
1349 bool preverifyResult_;
1352 // HandshakeCallback
1353 bool handshakeVer(AsyncSSLSocket* /* sock */,
1355 X509_STORE_CTX* /* ctx */) noexcept override {
1356 handshakeVerify_ = true;
1358 EXPECT_EQ(preverifyResult_, preverifyOk);
1359 return verifyResult_;
1362 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1363 LOG(INFO) << "Handshake success";
1364 handshakeSuccess_ = true;
1365 handshakeTime = socket_->getHandshakeTime();
1370 const AsyncSocketException& ex) noexcept override {
1371 LOG(INFO) << "Handshake error " << ex.what();
1372 handshakeError_ = true;
1373 handshakeTime = socket_->getHandshakeTime();
1377 void writeSuccess() noexcept override {
1382 size_t bytesWritten,
1383 const AsyncSocketException& ex) noexcept override {
1384 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1389 class SSLHandshakeClient : public SSLHandshakeBase {
1392 AsyncSSLSocket::UniquePtr socket,
1393 bool preverifyResult,
1394 bool verifyResult) :
1395 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1396 socket_->sslConn(this, std::chrono::milliseconds::zero());
1400 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1402 SSLHandshakeClientNoVerify(
1403 AsyncSSLSocket::UniquePtr socket,
1404 bool preverifyResult,
1405 bool verifyResult) :
1406 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1409 std::chrono::milliseconds::zero(),
1410 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1414 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1416 SSLHandshakeClientDoVerify(
1417 AsyncSSLSocket::UniquePtr socket,
1418 bool preverifyResult,
1419 bool verifyResult) :
1420 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1423 std::chrono::milliseconds::zero(),
1424 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1428 class SSLHandshakeServer : public SSLHandshakeBase {
1431 AsyncSSLSocket::UniquePtr socket,
1432 bool preverifyResult,
1434 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1435 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1439 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1441 SSLHandshakeServerParseClientHello(
1442 AsyncSSLSocket::UniquePtr socket,
1443 bool preverifyResult,
1445 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1446 socket_->enableClientHelloParsing();
1447 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1450 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1453 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1454 handshakeSuccess_ = true;
1455 sock->getSSLSharedCiphers(sharedCiphers_);
1456 sock->getSSLServerCiphers(serverCiphers_);
1457 sock->getSSLClientCiphers(clientCiphers_);
1458 chosenCipher_ = sock->getNegotiatedCipherName();
1463 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1465 SSLHandshakeServerNoVerify(
1466 AsyncSSLSocket::UniquePtr socket,
1467 bool preverifyResult,
1469 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1472 std::chrono::milliseconds::zero(),
1473 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1477 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1479 SSLHandshakeServerDoVerify(
1480 AsyncSSLSocket::UniquePtr socket,
1481 bool preverifyResult,
1483 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1486 std::chrono::milliseconds::zero(),
1487 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// Watchdog for tests: schedules a timeout on the given EventBase and, if it
// fires, reports a test failure and asks the event loop to stop.
1491 class EventBaseAborter : public AsyncTimeout {
// eventBase: loop to guard. The timeout is registered as an INTERNAL event
// and scheduled immediately with timeoutMS.
1493 EventBaseAborter(EventBase* eventBase,
1496 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1497 , eventBase_(eventBase) {
1498 scheduleTimeout(timeoutMS);
// Use the non-fatal ADD_FAILURE() here. FAIL() records a fatal failure and
// returns from the current function, which made the terminateLoopSoon() call
// below unreachable and left a timed-out test's event loop running.
1501 void timeoutExpired() noexcept override {
1502 ADD_FAILURE() << "test timed out";
1503 eventBase_->terminateLoopSoon();
// Loop to terminate when the timeout expires; not owned.
1507 EventBase* eventBase_;