2 * Copyright 2012-present Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
20 #include <folly/ExceptionWrapper.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/experimental/TestUtil.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/io/async/test/TestSSLServer.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/PThread.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
37 #include <sys/types.h>
38 #include <condition_variable>
45 // The destructors of all callback classes assert that the state is
46 // STATE_SUCCEEDED, for both positive and negative tests. The tests
47 // are responsible for setting the succeeded state properly before the
48 // destructors are called.
// Base for SendMsgParamsCallback test doubles. setSocket() saves the
// socket's currently installed SendMsgParamsCallback in oldCallback_ and
// installs this object in its place. Every override below forwards to that
// saved callback, so subclasses only override what they change.
50 class SendMsgParamsCallbackBase :
51 public folly::AsyncSocket::SendMsgParamsCallback {
53 SendMsgParamsCallbackBase() {}
56 const std::shared_ptr<AsyncSSLSocket> &socket) {
// Remember the previously installed callback before replacing it, so the
// overrides can chain to it.
58 oldCallback_ = socket_->getSendMsgParamsCB();
59 socket_->setSendMsgParamCB(this);
// The caller-supplied default flags are ignored; the wrapped callback is
// queried with zero-copy disabled.
62 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
64 return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
67 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
68 oldCallback_->getAncillaryData(flags, data);
71 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
72 return oldCallback_->getAncillaryDataSize(flags);
// NOTE(review): oldCallback_ stays nullptr until setSocket() runs, and the
// forwarding overrides dereference it unconditionally.
75 std::shared_ptr<AsyncSSLSocket> socket_;
76 folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
// SendMsgParamsCallbackBase variant whose send flags can be changed at
// runtime through resetFlags(). getFlagsImpl() ends by deferring to the
// wrapped callback.
// NOTE(review): the lines that store and return the override flags are not
// visible in this listing.
79 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
81 SendMsgFlagsCallback() {}
83 void resetFlags(int flags) {
87 int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
92 return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
// SendMsgFlagsCallback that can also supply custom ancillary (control) data.
// resetData() takes ownership of the buffer by swapping it in. While the
// buffer is non-empty:
//   - getAncillaryDataSize() returns its size;
//   - getAncillaryData() copies it out.
// Otherwise the wrapped callback is consulted. The else-branch of
// getAncillaryData() is elided in this listing.
99 class SendMsgDataCallback : public SendMsgFlagsCallback {
101 SendMsgDataCallback() {}
103 void resetData(std::vector<char>&& data) {
104 ancillaryData_.swap(data);
// Assumes `data` points to at least getAncillaryDataSize() bytes, as sized
// by the caller — TODO confirm against AsyncSocket.
107 void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
108 if (ancillaryData_.size()) {
109 std::cerr << "getAncillaryData: copying data" << std::endl;
110 memcpy(data, ancillaryData_.data(), ancillaryData_.size());
112 oldCallback_->getAncillaryData(flags, data);
116 uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
117 if (ancillaryData_.size()) {
118 std::cerr << "getAncillaryDataSize: returning size" << std::endl;
119 return ancillaryData_.size();
121 return oldCallback_->getAncillaryDataSize(flags);
125 std::vector<char> ancillaryData_;
// Base write callback with simple state tracking:
//   - starts in STATE_WAITING;
//   - writeSuccess() moves it to STATE_SUCCEEDED;
//   - writeErr() moves it to STATE_FAILED and records the byte count;
//   - the destructor expects STATE_SUCCEEDED.
// setSocket() forwards the socket to the optional SendMsgParamsCallbackBase
// (mcb_).
128 class WriteCallbackBase :
129 public AsyncTransportWrapper::WriteCallback {
131 explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
132 : state(STATE_WAITING)
// Placeholder exception. NOTE(review): the line that stores `ex` here from
// writeErr() is not visible in this listing.
134 , exception(AsyncSocketException::UNKNOWN, "none")
137 ~WriteCallbackBase() override {
138 EXPECT_EQ(STATE_SUCCEEDED, state);
141 virtual void setSocket(
142 const std::shared_ptr<AsyncSSLSocket> &socket) {
145 mcb_->setSocket(socket);
149 void writeSuccess() noexcept override {
150 std::cerr << "writeSuccess" << std::endl;
151 state = STATE_SUCCEEDED;
155 size_t nBytesWritten,
156 const AsyncSocketException& ex) noexcept override {
157 std::cerr << "writeError: bytesWritten " << nBytesWritten
158 << ", exception " << ex.what() << std::endl;
160 state = STATE_FAILED;
161 this->bytesWritten = nBytesWritten;
166 std::shared_ptr<AsyncSSLSocket> socket_;
169 AsyncSocketException exception;
170 SendMsgParamsCallbackBase* mcb_;
// Write callback for tests that expect the write to fail. The destructor
// checks for STATE_FAILED with a NETWORK_ERROR whose errno is 22. It then
// sets the state to STATE_SUCCEEDED so the base-class destructor check
// passes.
173 class ExpectWriteErrorCallback :
174 public WriteCallbackBase {
176 explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
177 : WriteCallbackBase(mcb) {}
179 ~ExpectWriteErrorCallback() override {
180 EXPECT_EQ(STATE_FAILED, state);
181 EXPECT_EQ(exception.type_,
182 AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
// 22 is EINVAL on Linux. NOTE(review): consider spelling this EINVAL.
183 EXPECT_EQ(exception.errno_, 22);
184 // Suppress the assert in ~WriteCallbackBase()
185 state = STATE_SUCCEEDED;
189 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
190 /* copied from include/uapi/linux/net_tstamp.h */
191 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
192 enum SOF_TIMESTAMPING {
193 SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
194 SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
195 SOF_TIMESTAMPING_OPT_ID = (1 << 7),
196 SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
197 SOF_TIMESTAMPING_TX_ACK = (1 << 9),
198 SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
// Write callback that enables SO_TIMESTAMPING on its socket. Its destructor
// expects two kinds of error-queue message to have been seen by
// checkForTimestampNotifications():
//   - a timestamp (SCM_TIMESTAMPING);
//   - an IP_RECVERR/IPV6_RECVERR message.
// gotByteSeq_ is presumably set in the RECVERR branch, which is elided here.
201 class WriteCheckTimestampCallback :
202 public WriteCallbackBase {
204 explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
205 : WriteCallbackBase(mcb) {}
207 ~WriteCheckTimestampCallback() override {
208 EXPECT_EQ(STATE_SUCCEEDED, state);
209 EXPECT_TRUE(gotTimestamp_);
210 EXPECT_TRUE(gotByteSeq_);
214 const std::shared_ptr<AsyncSSLSocket> &socket) override {
215 WriteCallbackBase::setSocket(socket);
217 EXPECT_NE(socket_->getFd(), 0);
// Socket-wide options only: software timestamps, tagged with an ID
// (OPT_ID) and returned without the payload (OPT_TSONLY). No TX_* trigger
// bits are set here.
218 int flags = SOF_TIMESTAMPING_OPT_ID
219 | SOF_TIMESTAMPING_OPT_TSONLY
220 | SOF_TIMESTAMPING_SOFTWARE;
221 AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
222 int ret = tstampingOpt.apply(socket_->getFd(), flags);
// Reads notifications from the socket's error queue (MSG_ERRQUEUE) into a
// 1024-byte control buffer. EAGAIN is not treated as an error; it only
// means nothing is queued.
226 void checkForTimestampNotifications() noexcept {
227 int fd = socket_->getFd();
228 std::vector<char> ctrl(1024, 0);
233 memset(&msg, 0, sizeof(msg));
234 entry.iov_base = &data;
235 entry.iov_len = sizeof(data);
236 msg.msg_iov = &entry;
238 msg.msg_control = ctrl.data();
239 msg.msg_controllen = ctrl.size();
243 ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
245 if (errno != EAGAIN) {
246 auto errnoCopy = errno;
247 std::cerr << "::recvmsg exited with code " << ret
248 << ", errno: " << errnoCopy << std::endl;
249 AsyncSocketException ex(
250 AsyncSocketException::INTERNAL_ERROR,
258 for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
259 cmsg != nullptr && cmsg->cmsg_len != 0;
260 cmsg = CMSG_NXTHDR(&msg, cmsg)) {
261 if (cmsg->cmsg_level == SOL_SOCKET &&
262 cmsg->cmsg_type == SCM_TIMESTAMPING) {
263 gotTimestamp_ = true;
267 if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
268 (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
276 bool gotTimestamp_{false};
277 bool gotByteSeq_{false};
279 #endif // FOLLY_HAVE_MSG_ERRQUEUE
// Base read callback with the same WAITING/SUCCEEDED/FAILED bookkeeping as
// the write callbacks. readErr() marks it FAILED, and the destructor
// expects STATE_SUCCEEDED. wcb_ is the write callback that subclasses pass
// to socket writes.
281 class ReadCallbackBase :
282 public AsyncTransportWrapper::ReadCallback {
284 explicit ReadCallbackBase(WriteCallbackBase* wcb)
285 : wcb_(wcb), state(STATE_WAITING) {}
287 ~ReadCallbackBase() override {
288 EXPECT_EQ(STATE_SUCCEEDED, state);
292 const std::shared_ptr<AsyncSSLSocket> &socket) {
296 void setState(StateEnum s) {
304 const AsyncSocketException& ex) noexcept override {
305 std::cerr << "readError " << ex.what() << std::endl;
306 state = STATE_FAILED;
310 void readEOF() noexcept override {
311 std::cerr << "readEOF" << std::endl;
316 std::shared_ptr<AsyncSSLSocket> socket_;
317 WriteCallbackBase *wcb_;
// Echo reader. For each chunk:
//   - it is read into a lazily allocated 4096-byte buffer;
//   - it is written straight back through wcb_;
//   - the buffer is retained in `buffers` until the destructor;
//   - the callback is marked STATE_SUCCEEDED.
321 class ReadCallback : public ReadCallbackBase {
323 explicit ReadCallback(WriteCallbackBase *wcb)
324 : ReadCallbackBase(wcb)
327 ~ReadCallback() override {
328 for (std::vector<Buffer>::iterator it = buffers.begin();
333 currentBuffer.free();
336 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
337 if (!currentBuffer.buffer) {
338 currentBuffer.allocate(4096);
340 *bufReturn = currentBuffer.buffer;
341 *lenReturn = currentBuffer.length;
344 void readDataAvailable(size_t len) noexcept override {
345 std::cerr << "readDataAvailable, len " << len << std::endl;
347 currentBuffer.length = len;
349 wcb_->setSocket(socket_);
351 // Write back the same data.
352 socket_->write(wcb_, currentBuffer.buffer, len);
// The written bytes are not freed here. The buffer is kept until
// destruction, presumably because write() may still reference it.
354 buffers.push_back(currentBuffer);
355 currentBuffer.reset();
356 state = STATE_SUCCEEDED;
// Pointer/length pair. Its memory is managed explicitly through allocate()
// and free().
361 Buffer() : buffer(nullptr), length(0) {}
362 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
368 void allocate(size_t len) {
369 assert(buffer == nullptr);
370 this->buffer = static_cast<char*>(malloc(len));
382 std::vector<Buffer> buffers;
383 Buffer currentBuffer;
// Read callback that provokes a read error by handing AsyncSocket a null
// buffer. Receiving readErr() is this callback's success condition.
386 class ReadErrorCallback : public ReadCallbackBase {
388 explicit ReadErrorCallback(WriteCallbackBase *wcb)
389 : ReadCallbackBase(wcb) {}
391 // Return nullptr buffer to trigger readError()
392 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
393 *bufReturn = nullptr;
397 void readDataAvailable(size_t /* len */) noexcept override {
398 // This should never be called.
403 const AsyncSocketException& ex) noexcept override {
404 ReadCallbackBase::readErr(ex);
405 std::cerr << "ReadErrorCallback::readError" << std::endl;
// The error is expected, so override the FAILED state set by the base.
406 setState(STATE_SUCCEEDED);
// Read callback whose success condition is readEOF(). Like
// ReadErrorCallback, it hands back a null read buffer.
410 class ReadEOFCallback : public ReadCallbackBase {
412 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
414 // Return nullptr buffer to trigger readError()
// NOTE(review): this class succeeds on readEOF(), not readErr(). The comment
// above looks copied from ReadErrorCallback — confirm intent.
415 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
416 *bufReturn = nullptr;
420 void readDataAvailable(size_t /* len */) noexcept override {
421 // This should never be called.
425 void readEOF() noexcept override {
426 ReadCallbackBase::readEOF();
427 setState(STATE_SUCCEEDED);
// Echo reader that deliberately breaks the write path. It closes the raw fd
// behind the socket's back and then writes. It succeeds only if the write
// callback is already in STATE_FAILED when write() returns.
431 class WriteErrorCallback : public ReadCallback {
433 explicit WriteErrorCallback(WriteCallbackBase *wcb)
434 : ReadCallback(wcb) {}
436 void readDataAvailable(size_t len) noexcept override {
437 std::cerr << "readDataAvailable, len " << len << std::endl;
439 currentBuffer.length = len;
441 // close the socket before writing to trigger writeError().
442 ::close(socket_->getFd());
444 wcb_->setSocket(socket_);
446 // Write back the same data.
// Presumably prevents MSVC's invalid-parameter handler from aborting the
// process when writing to the closed fd (inferred from the helper's name).
447 folly::test::msvcSuppressAbortOnInvalidParams([&] {
448 socket_->write(wcb_, currentBuffer.buffer, len);
451 if (wcb_->state == STATE_FAILED) {
452 setState(STATE_SUCCEEDED);
454 state = STATE_FAILED;
457 buffers.push_back(currentBuffer);
458 currentBuffer.reset();
461 void readErr(const AsyncSocketException& ex) noexcept override {
462 std::cerr << "readError " << ex.what() << std::endl;
463 // do nothing since this is expected
// Reader constructed without a write callback. readErr() marks it FAILED
// and readEOF() marks it SUCCEEDED.
// NOTE(review): the inherited ReadCallback::readDataAvailable() dereferences
// wcb_, which is nullptr here, so this callback must never receive data.
467 class EmptyReadCallback : public ReadCallback {
469 explicit EmptyReadCallback()
470 : ReadCallback(nullptr) {}
472 void readErr(const AsyncSocketException& ex) noexcept override {
473 std::cerr << "readError " << ex.what() << std::endl;
474 state = STATE_FAILED;
480 void readEOF() noexcept override {
481 std::cerr << "readEOF" << std::endl;
485 state = STATE_SUCCEEDED;
488 std::shared_ptr<AsyncSocket> tcpSocket_;
// Server-side handshake callback.
//   - On success it hands the socket to the read callback (rcb_).
//   - On error it records the message in errorString_.
//   - Whether an outcome counts as STATE_SUCCEEDED depends on the ExpectType
//     passed to the constructor.
// handshakeSuc()/handshakeErr() update the state while holding mutex_.
// waitForHandshake() waits on cv_ until the state leaves STATE_WAITING.
// NOTE(review): the cv_ notification calls are not visible in this listing.
491 class HandshakeCallback :
492 public AsyncSSLSocket::HandshakeCB {
499 explicit HandshakeCallback(ReadCallbackBase *rcb,
500 ExpectType expect = EXPECT_SUCCESS):
501 state(STATE_WAITING),
506 const std::shared_ptr<AsyncSSLSocket> &socket) {
510 void setState(StateEnum s) {
515 // Functions inherited from AsyncSSLSocket::HandshakeCB
516 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
517 std::lock_guard<std::mutex> g(mutex_);
519 EXPECT_EQ(sock, socket_.get());
520 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
521 rcb_->setSocket(socket_);
522 sock->setReadCB(rcb_);
523 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
525 void handshakeErr(AsyncSSLSocket* /* sock */,
526 const AsyncSocketException& ex) noexcept override {
527 std::lock_guard<std::mutex> g(mutex_);
529 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
530 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
531 if (expect_ == EXPECT_ERROR) {
532 // rcb will never be invoked
533 rcb_->setState(STATE_SUCCEEDED);
535 errorString_ = ex.what();
538 void waitForHandshake() {
539 std::unique_lock<std::mutex> lock(mutex_);
540 cv_.wait(lock, [this] { return state != STATE_WAITING; });
543 ~HandshakeCallback() override {
544 EXPECT_EQ(STATE_SUCCEEDED, state);
549 state = STATE_SUCCEEDED;
552 std::shared_ptr<AsyncSSLSocket> getSocket() {
557 std::shared_ptr<AsyncSSLSocket> socket_;
558 ReadCallbackBase *rcb_;
561 std::condition_variable cv_;
562 std::string errorString_;
// Accept callback that starts a server-side SSL handshake (sslAccept) on
// each accepted socket. It uses hcb_ and an optional timeout in
// milliseconds. A non-zero timeout marks a negative test: the destructor
// then expects the handshake callback to have FAILED.
565 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
569 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
570 uint32_t timeout = 0):
571 SSLServerAcceptCallbackBase(hcb),
574 ~SSLServerAcceptCallback() override {
576 // if we set a timeout, we expect failure
577 EXPECT_EQ(hcb_->state, STATE_FAILED);
578 hcb_->setState(STATE_SUCCEEDED);
583 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// NOTE(review): `s` is already a shared_ptr<AsyncSSLSocket>, so this cast is
// just a copy.
585 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
588 hcb_->setSocket(sock);
589 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590 EXPECT_EQ(sock->getSSLState(),
591 AsyncSSLSocket::STATE_ACCEPTING);
593 state = STATE_SUCCEEDED;
// Variant that does the following before delegating to
// SSLServerAcceptCallback::connAccepted():
//   - checks that the accepted socket already has TCP_NODELAY set;
//   - clears TCP_NODELAY;
//   - reads the option back to confirm.
// Clearing it is presumably what makes the handshake "delayed" (Nagle's
// algorithm enabled) — confirm.
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
599 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600 SSLServerAcceptCallback(hcb) {}
603 const std::shared_ptr<folly::AsyncSSLSocket> &s)
606 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
608 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
610 int fd = sock->getFd();
614 // The accepted connection should already have TCP_NODELAY set
616 socklen_t valueLength = sizeof(value);
617 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
623 // Unset the TCP_NODELAY option.
625 socklen_t valueLength = sizeof(value);
626 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
629 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
633 SSLServerAcceptCallback::connAccepted(sock);
// Accept callback for the asynchronous session-cache tests. It matches
// SSLServerAcceptCallback::connAccepted() except for one check: right after
// sslAccept(), the socket may be either still in STATE_ACCEPTING or already
// in STATE_CACHE_LOOKUP (waiting on an async session lookup).
637 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
639 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
640 uint32_t timeout = 0):
641 SSLServerAcceptCallback(hcb, timeout) {}
644 const std::shared_ptr<folly::AsyncSSLSocket> &s)
646 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
// Log under this class's own name so the output identifies which accept
// path ran (the base class logs "SSLServerAcceptCallback::connAccepted").
648 std::cerr << "SSLServerAsyncCacheAcceptCallback::connAccepted" << std::endl;
650 hcb_->setSocket(sock);
651 sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
// ASSERT_* returns early on failure, leaving `state` unset so the
// destructor's check also reports it.
652 ASSERT_TRUE((sock->getSSLState() ==
653 AsyncSSLSocket::STATE_ACCEPTING) ||
654 (sock->getSSLState() ==
655 AsyncSSLSocket::STATE_CACHE_LOOKUP));
657 state = STATE_SUCCEEDED;
// Accept callback verifying that calling sslAccept() twice on one socket is
// an error. The second call must put the socket in STATE_ERROR and leave
// both handshake callbacks FAILED. Afterwards every state is forced to
// STATE_SUCCEEDED so the destructors' checks pass.
662 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
664 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
665 SSLServerAcceptCallbackBase(hcb) {}
668 const std::shared_ptr<folly::AsyncSSLSocket> &s)
670 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
672 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
674 // The first call to sslAccept() should succeed.
675 hcb_->setSocket(sock);
676 sock->sslAccept(hcb_);
677 EXPECT_EQ(sock->getSSLState(),
678 AsyncSSLSocket::STATE_ACCEPTING);
680 // The second call to sslAccept() should fail.
// callback2 is a stack local; it is marked SUCCEEDED below before it goes
// out of scope.
681 HandshakeCallback callback2(hcb_->rcb_);
682 callback2.setSocket(sock);
683 sock->sslAccept(&callback2);
684 EXPECT_EQ(sock->getSSLState(),
685 AsyncSSLSocket::STATE_ERROR);
687 // Both callbacks should be in the error state.
688 EXPECT_EQ(hcb_->state, STATE_FAILED);
689 EXPECT_EQ(callback2.state, STATE_FAILED);
691 state = STATE_SUCCEEDED;
692 hcb_->setState(STATE_SUCCEEDED);
693 callback2.setState(STATE_SUCCEEDED);
// Accept callback that postpones sslAccept() on the event base. By the time
// the accept runs, the client has presumably given up, so the handshake is
// expected to fail.
697 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
699 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
700 SSLServerAcceptCallbackBase(hcb) {}
703 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// Log under this class's own name; the previous label named a different
// class (HandshakeErrorCallback).
705 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl;
707 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
709 hcb_->setSocket(sock);
// NOTE(review): [=] implicitly captures `this` (needed for hcb_), so this
// callback must outlive the delayed task.
710 sock->getEventBase()->tryRunAfterDelay([=] {
711 std::cerr << "Delayed SSL accept, client will have close by now"
713 // SSL accept will fail
716 AsyncSSLSocket::STATE_UNINIT);
717 hcb_->socket_->sslAccept(hcb_);
718 // This registers for an event
721 AsyncSSLSocket::STATE_ACCEPTING);
723 state = STATE_SUCCEEDED;
// Accept callback for client connect-timeout tests. It starts no handshake
// and closes the accepted socket 100ms later, so the client times out
// waiting. Its state starts as STATE_SUCCEEDED because it may never be
// invoked at all.
728 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
730 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
731 // We don't care if we get invoked or not.
732 // The client may time out and give up before connAccepted() is even
734 state = STATE_SUCCEEDED;
738 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
739 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
741 // Just wait a while before closing the socket, so the client
742 // will time out waiting for the handshake to complete.
// The lambda copies the shared_ptr `s`, keeping the socket alive until it
// runs.
743 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
// TestSSLServer configured for an external session cache. When OpenSSL
// defines SSL_ERROR_WANT_SESS_CACHE_LOOKUP, it registers getSessionCallback()
// as the session-lookup callback. It also sets the session cache mode to
// SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER.
// getSessionCallback() simulates an asynchronous cache lookup.
747 class TestSSLAsyncCacheServer : public TestSSLServer {
749 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
750 int lookupDelay = 100) :
752 SSL_CTX *sslCtx = ctx_->getSSLCtx();
753 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
754 SSL_CTX_sess_set_get_cb(sslCtx,
755 TestSSLAsyncCacheServer::getSessionCallback);
757 SSL_CTX_set_session_cache_mode(
758 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
// lookupDelay_ is static, so the delay is shared by every instance.
761 lookupDelay_ = lookupDelay;
764 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
765 uint32_t getAsyncLookups() const { return asyncLookups_; }
// Process-wide counters/settings shared by all instances. Their definitions
// are outside this listing.
768 static uint32_t asyncCallbacks_;
769 static uint32_t asyncLookups_;
770 static uint32_t lookupDelay_;
772 static SSL_SESSION* getSessionCallback(SSL* ssl,
773 unsigned char* /* sess_id */,
779 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
780 if (!SSL_want_sess_cache_lookup(ssl)) {
781 // libssl.so mismatch
782 std::cerr << "no async support" << std::endl;
786 AsyncSSLSocket *sslSocket =
787 AsyncSSLSocket::getFromSSL(ssl);
788 assert(sslSocket != nullptr);
789 // Simulate an async cache by delaying the miss by lookupDelay_ ms
790 if (asyncCallbacks_ % 2 == 0) {
791 // This socket is already blocked on lookup, return miss
792 std::cerr << "returning miss" << std::endl;
794 // fresh meat - block it
795 std::cerr << "async lookup" << std::endl;
// restartSSLAccept() resumes the paused handshake once the simulated lookup
// completes.
796 sslSocket->getEventBase()->tryRunAfterDelay(
797 std::bind(&AsyncSSLSocket::restartSSLAccept,
798 sslSocket), lookupDelay_);
799 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
// Helper declarations; their definitions are outside this listing.
// NOTE(review): the name lines of the second and third declarations are not
// visible here. From their parameters, one takes a client/server SSLContext
// pair, and the other fills a client/server AsyncSSLSocket pair on an
// EventBase — confirm against the definitions.
807 void getfds(int fds[2]);
810 std::shared_ptr<folly::SSLContext> clientCtx,
811 std::shared_ptr<folly::SSLContext> serverCtx);
814 EventBase* eventBase,
815 AsyncSSLSocket::UniquePtr* clientSock,
816 AsyncSSLSocket::UniquePtr* serverSock);
// Client for the blocking-write test. On construction it starts sslConn()
// with a 100ms timeout. Once the handshake succeeds it issues a single
// writev() of iovCount_ iovecs, all pointing into one bufLen_-byte buffer.
// getIovec()/getIovecCount() presumably expose those iovecs so the receiver
// can verify the data. Handshake and write errors are reported as test
// failures.
818 class BlockingWriteClient :
819 private AsyncSSLSocket::HandshakeCB,
820 private AsyncTransportWrapper::WriteCallback {
822 explicit BlockingWriteClient(
823 AsyncSSLSocket::UniquePtr socket)
824 : socket_(std::move(socket)),
828 buf_ = std::make_unique<uint8_t[]>(bufLen_);
// Visit every byte of buf_ (the loop body is elided in this listing). The
// bound must be bufLen_: sizeof(buf_) is the size of the unique_ptr itself
// (a handful of bytes), not the length of the buffer it owns.
829 for (uint32_t n = 0; n < bufLen_; ++n) {
834 iov_ = std::make_unique<struct iovec[]>(iovCount_);
835 for (uint32_t n = 0; n < iovCount_; ++n) {
// iovec n starts at byte n of buf_. Its length is one of two complementary
// formulas, chosen by a condition elided from this listing.
836 iov_[n].iov_base = buf_.get() + n;
838 iov_[n].iov_len = n % bufLen_;
840 iov_[n].iov_len = bufLen_ - (n % bufLen_);
844 socket_->sslConn(this, std::chrono::milliseconds(100));
847 struct iovec* getIovec() const {
850 uint32_t getIovecCount() const {
855 void handshakeSuc(AsyncSSLSocket*) noexcept override {
856 socket_->writev(this, iov_.get(), iovCount_);
860 const AsyncSocketException& ex) noexcept override {
861 ADD_FAILURE() << "client handshake error: " << ex.what();
863 void writeSuccess() noexcept override {
868 const AsyncSocketException& ex) noexcept override {
869 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
873 AsyncSSLSocket::UniquePtr socket_;
876 std::unique_ptr<uint8_t[]> buf_;
877 std::unique_ptr<struct iovec[]> iov_;
// Server side of the blocking-write test. It starts sslAccept() with a
// 100ms timeout. After the handshake it waits 10ms before reading, then
// reads into one bufSize_-byte buffer, pausing 2ms after each chunk so the
// peer's writes back up. checkBuffer() verifies the received bytes against
// an iovec array.
// NOTE(review): the delayed lambdas capture `this`, so the server must
// outlive those event-loop callbacks.
880 class BlockingWriteServer :
881 private AsyncSSLSocket::HandshakeCB,
882 private AsyncTransportWrapper::ReadCallback {
884 explicit BlockingWriteServer(
885 AsyncSSLSocket::UniquePtr socket)
886 : socket_(std::move(socket)),
887 bufSize_(2500 * 2000),
889 buf_ = std::make_unique<uint8_t[]>(bufSize_);
890 socket_->sslAccept(this, std::chrono::milliseconds(100));
// Compares the first bytesRead_ bytes of buf_ against the concatenation of
// the given iovecs. It reports a mismatch, a short read, or extra data.
893 void checkBuffer(struct iovec* iov, uint32_t count) const {
895 for (uint32_t n = 0; n < count; ++n) {
896 size_t bytesLeft = bytesRead_ - idx;
897 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
898 std::min(iov[n].iov_len, bytesLeft));
900 FAIL() << "buffer mismatch at iovec " << n << "/" << count
904 if (iov[n].iov_len > bytesLeft) {
905 FAIL() << "server did not read enough data: "
906 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
907 << " in iovec " << n << "/" << count;
910 idx += iov[n].iov_len;
912 if (idx != bytesRead_) {
913 ADD_FAILURE() << "server read extra data: " << bytesRead_
914 << " bytes read; expected " << idx;
919 void handshakeSuc(AsyncSSLSocket*) noexcept override {
920 // Wait 10ms before reading, so the client's writes will initially block.
921 socket_->getEventBase()->tryRunAfterDelay(
922 [this] { socket_->setReadCB(this); }, 10);
926 const AsyncSocketException& ex) noexcept override {
927 ADD_FAILURE() << "server handshake error: " << ex.what();
929 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
930 *bufReturn = buf_.get() + bytesRead_;
931 *lenReturn = bufSize_ - bytesRead_;
// After each chunk, stop reading and resume 2ms later to keep the receive
// side slow.
933 void readDataAvailable(size_t len) noexcept override {
935 socket_->setReadCB(nullptr);
936 socket_->getEventBase()->tryRunAfterDelay(
937 [this] { socket_->setReadCB(this); }, 2);
939 void readEOF() noexcept override {
943 const AsyncSocketException& ex) noexcept override {
944 ADD_FAILURE() << "server read error: " << ex.what();
947 AsyncSSLSocket::UniquePtr socket_;
950 std::unique_ptr<uint8_t[]> buf_;
// Client side of the next-protocol negotiation tests (the class header line
// is not visible in this listing). It starts sslConn() on construction. On
// handshake success it records the negotiated protocol in nextProto,
// nextProtoLength and protocolType. `except` presumably captures a
// handshake error — TODO confirm.
// NOTE(review): protocolType is not initialized by the constructor; it is
// only set in handshakeSuc().
954 private AsyncSSLSocket::HandshakeCB,
955 private AsyncTransportWrapper::WriteCallback {
958 AsyncSSLSocket::UniquePtr socket)
959 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
960 socket_->sslConn(this);
963 const unsigned char* nextProto;
964 unsigned nextProtoLength;
965 SSLContext::NextProtocolType protocolType;
966 folly::Optional<AsyncSocketException> except;
969 void handshakeSuc(AsyncSSLSocket*) noexcept override {
970 socket_->getSelectedNextProtocol(
971 &nextProto, &nextProtoLength, &protocolType);
975 const AsyncSocketException& ex) noexcept override {
978 void writeSuccess() noexcept override {
983 const AsyncSocketException& ex) noexcept override {
984 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
988 AsyncSSLSocket::UniquePtr socket_;
// Server side of the next-protocol negotiation tests. It starts sslAccept()
// on construction and records the negotiated protocol on handshake
// success. readDataAvailable() is a no-op; a read error is reported as a
// test failure.
// NOTE(review): protocolType is only set in handshakeSuc().
992 private AsyncSSLSocket::HandshakeCB,
993 private AsyncTransportWrapper::ReadCallback {
995 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
996 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
997 socket_->sslAccept(this);
1000 const unsigned char* nextProto;
1001 unsigned nextProtoLength;
1002 SSLContext::NextProtocolType protocolType;
1003 folly::Optional<AsyncSocketException> except;
1006 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1007 socket_->getSelectedNextProtocol(
1008 &nextProto, &nextProtoLength, &protocolType);
1012 const AsyncSocketException& ex) noexcept override {
1015 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1018 void readDataAvailable(size_t /* len */) noexcept override {}
1019 void readEOF() noexcept override {
1023 const AsyncSocketException& ex) noexcept override {
1024 ADD_FAILURE() << "server read error: " << ex.what();
1027 AsyncSSLSocket::UniquePtr socket_;
// Server for the client-renegotiation test. It keeps reading after the
// handshake. readErr() sets renegotiationError_ only when both hold:
//   - the error is an SSLException;
//   - its message contains the text of SSLError::CLIENT_RENEGOTIATION.
1030 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1031 public AsyncTransportWrapper::ReadCallback {
1033 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1034 : socket_(std::move(socket)) {
1035 socket_->sslAccept(this);
1038 ~RenegotiatingServer() override {
1039 socket_->setReadCB(nullptr);
1042 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1043 LOG(INFO) << "Renegotiating server handshake success";
1044 socket_->setReadCB(this);
1048 const AsyncSocketException& ex) noexcept override {
1049 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1051 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1052 *lenReturn = sizeof(buf);
1055 void readDataAvailable(size_t /* len */) noexcept override {}
1056 void readEOF() noexcept override {}
1057 void readErr(const AsyncSocketException& ex) noexcept override {
1058 LOG(INFO) << "server got read error " << ex.what();
1059 auto exPtr = dynamic_cast<const SSLException*>(&ex);
// ASSERT_* returns early, so renegotiationError_ is only set when both
// checks pass.
1060 ASSERT_NE(nullptr, exPtr);
1061 std::string exStr(ex.what());
1062 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1063 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1064 renegotiationError_ = true;
1067 AsyncSSLSocket::UniquePtr socket_;
1068 unsigned char buf[128];
1069 bool renegotiationError_{false};
// The SNI helpers below are compiled only when OpenSSL has TLS extensions.
1072 #ifndef OPENSSL_NO_TLSEXT
// Client side of the SNI tests (the class header line is not visible in this
// listing). It starts sslConn() on construction. On handshake success it
// records whether the server reported a server-name match.
1074 private AsyncSSLSocket::HandshakeCB,
1075 private AsyncTransportWrapper::WriteCallback {
1078 AsyncSSLSocket::UniquePtr socket)
1079 : serverNameMatch(false), socket_(std::move(socket)) {
1080 socket_->sslConn(this);
1083 bool serverNameMatch;
1086 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1087 serverNameMatch = socket_->isServerNameMatch();
1091 const AsyncSocketException& ex) noexcept override {
1092 ADD_FAILURE() << "client handshake error: " << ex.what();
1094 void writeSuccess() noexcept override {
1098 size_t bytesWritten,
1099 const AsyncSocketException& ex) noexcept override {
1100 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1104 AsyncSSLSocket::UniquePtr socket_;
// Server side of the SNI tests (the class header line is not visible in this
// listing). The constructor installs serverNameCallback() on `ctx`. When the
// client's SNI host name equals expectedServerName_ (case-insensitively),
// the callback switches the connection to sniCtx_ and sets
// serverNameMatch.
1108 private AsyncSSLSocket::HandshakeCB,
1109 private AsyncTransportWrapper::ReadCallback {
1112 AsyncSSLSocket::UniquePtr socket,
1113 const std::shared_ptr<folly::SSLContext>& ctx,
1114 const std::shared_ptr<folly::SSLContext>& sniCtx,
1115 const std::string& expectedServerName)
1116 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1117 expectedServerName_(expectedServerName) {
// NOTE(review): the bound callback captures `this`, so `ctx` must not
// invoke it after this object is destroyed.
1118 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1119 std::placeholders::_1));
1120 socket_->sslAccept(this);
1123 bool serverNameMatch;
1126 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1129 const AsyncSocketException& ex) noexcept override {
1130 ADD_FAILURE() << "server handshake error: " << ex.what();
1132 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1135 void readDataAvailable(size_t /* len */) noexcept override {}
1136 void readEOF() noexcept override {
1140 const AsyncSocketException& ex) noexcept override {
1141 ADD_FAILURE() << "server read error: " << ex.what();
// Decides which SSL context serves the connection, based on the SNI host
// name sent by the client.
1144 folly::SSLContext::ServerNameCallbackResult
1145 serverNameCallback(SSL *ssl) {
1146 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1149 !strcasecmp(expectedServerName_.c_str(), sn)) {
1150 AsyncSSLSocket *sslSocket =
1151 AsyncSSLSocket::getFromSSL(ssl);
1152 sslSocket->switchServerSSLContext(sniCtx_);
1153 serverNameMatch = true;
1154 return folly::SSLContext::SERVER_NAME_FOUND;
1156 serverNameMatch = false;
1157 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1161 AsyncSSLSocket::UniquePtr socket_;
1162 std::shared_ptr<folly::SSLContext> sniCtx_;
1163 std::string expectedServerName_;
// Test client with its own SSLContext (session tickets disabled via
// SSL_OP_NO_TICKET). On each connection it writes buf_, reads the echo
// back into readbuf_, and compares the two. The SSL session from each
// connection is kept and offered on the next connect().
// hit_/miss_ are presumably updated in connectSuccess() from
// getSSLSessionReused(); those lines are elided here.
// connect(true) additionally writes before the connection completes.
1167 class SSLClient : public AsyncSocket::ConnectCallback,
1168 public AsyncTransportWrapper::WriteCallback,
1169 public AsyncTransportWrapper::ReadCallback
1172 EventBase *eventBase_;
1173 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1174 SSL_SESSION *session_;
1175 std::shared_ptr<folly::SSLContext> ctx_;
1177 folly::SocketAddress address_;
1181 uint32_t bytesRead_;
1185 uint32_t writeAfterConnectErrors_;
1187 // These settings test that we eventually drain the
1188 // socket, even if the maxReadsPerEvent_ is hit during
1189 // an event loop iteration.
1190 static constexpr size_t kMaxReadsPerEvent = 2;
1191 // 2 event loop iterations
1192 static constexpr size_t kMaxReadBufferSz =
1193 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1196 SSLClient(EventBase *eventBase,
1197 const folly::SocketAddress& address,
1199 uint32_t timeout = 0)
1200 : eventBase_(eventBase),
1202 requests_(requests),
1209 writeAfterConnectErrors_(0) {
1210 ctx_.reset(new folly::SSLContext());
1211 ctx_->setOptions(SSL_OP_NO_TICKET);
1212 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1213 memset(buf_, 'a', sizeof(buf_));
1216 ~SSLClient() override {
1218 SSL_SESSION_free(session_);
1221 EXPECT_EQ(bytesRead_, sizeof(buf_));
1225 uint32_t getHit() const { return hit_; }
1227 uint32_t getMiss() const { return miss_; }
1229 uint32_t getErrors() const { return errors_; }
1231 uint32_t getWriteAfterConnectErrors() const {
1232 return writeAfterConnectErrors_;
// Opens a new SSL socket, offering the saved session (if any) for
// resumption.
1235 void connect(bool writeNow = false) {
1236 sslSocket_ = AsyncSSLSocket::newSocket(
1238 if (session_ != nullptr) {
1239 sslSocket_->setSSLSession(session_);
1242 sslSocket_->connect(this, address_, timeout_);
1243 if (sslSocket_ && writeNow) {
1244 // write some junk, used in an error test
1245 sslSocket_->write(this, buf_, sizeof(buf_));
1249 void connectSuccess() noexcept override {
1250 std::cerr << "client SSL socket connected" << std::endl;
1251 if (sslSocket_->getSSLSessionReused()) {
// Replace the previously saved session with this connection's.
1255 if (session_ != nullptr) {
1256 SSL_SESSION_free(session_);
1258 session_ = sslSocket_->getSSLSession();
1262 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1263 sslSocket_->write(this, buf_, sizeof(buf_));
1264 sslSocket_->setReadCB(this);
// Pre-fill readbuf_ with 'b' (buf_ is all 'a') so a short or missing echo
// cannot compare equal.
1265 memset(readbuf_, 'b', sizeof(readbuf_));
1270 const AsyncSocketException& ex) noexcept override {
1271 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1276 void writeSuccess() noexcept override {
1277 std::cerr << "client write success" << std::endl;
1280 void writeErr(size_t /* bytesWritten */,
1281 const AsyncSocketException& ex) noexcept override {
1282 std::cerr << "client writeError: " << ex.what() << std::endl;
1284 writeAfterConnectErrors_++;
// Hands out at most kMaxReadBufferSz bytes per read, so draining readbuf_
// needs several reads.
1288 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1289 *bufReturn = readbuf_ + bytesRead_;
1290 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1293 void readEOF() noexcept override {
1294 std::cerr << "client readEOF" << std::endl;
1298 const AsyncSocketException& ex) noexcept override {
1299 std::cerr << "client readError: " << ex.what() << std::endl;
1302 void readDataAvailable(size_t len) noexcept override {
1303 std::cerr << "client read data: " << len << std::endl;
1305 if (bytesRead_ == sizeof(buf_)) {
1306 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1307 sslSocket_->closeNow();
1309 if (requests_ != 0) {
// Common base for the handshake tests. It:
//   - records which of the verify/success/error callbacks fired, plus the
//     handshake time;
//   - checks the preverify result against preverifyResult_;
//   - answers verification with the configured verifyResult_.
// Write errors are reported as test failures.
// NOTE(review): handshakeTime is not in the constructor's initializer list,
// so it is uninitialized until a success or error callback runs.
1317 class SSLHandshakeBase :
1318 public AsyncSSLSocket::HandshakeCB,
1319 private AsyncTransportWrapper::WriteCallback {
1321 explicit SSLHandshakeBase(
1322 AsyncSSLSocket::UniquePtr socket,
1323 bool preverifyResult,
1324 bool verifyResult) :
1325 handshakeVerify_(false),
1326 handshakeSuccess_(false),
1327 handshakeError_(false),
1328 socket_(std::move(socket)),
1329 preverifyResult_(preverifyResult),
1330 verifyResult_(verifyResult) {
// Rvalue-qualified: transfers ownership of the socket to the caller.
1333 AsyncSSLSocket::UniquePtr moveSocket() && {
1334 return std::move(socket_);
1337 bool handshakeVerify_;
1338 bool handshakeSuccess_;
1339 bool handshakeError_;
1340 std::chrono::nanoseconds handshakeTime;
1343 AsyncSSLSocket::UniquePtr socket_;
1344 bool preverifyResult_;
1347 // HandshakeCallback
1348 bool handshakeVer(AsyncSSLSocket* /* sock */,
1350 X509_STORE_CTX* /* ctx */) noexcept override {
1351 handshakeVerify_ = true;
1353 EXPECT_EQ(preverifyResult_, preverifyOk);
1354 return verifyResult_;
1357 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1358 LOG(INFO) << "Handshake success";
1359 handshakeSuccess_ = true;
1360 handshakeTime = socket_->getHandshakeTime();
1365 const AsyncSocketException& ex) noexcept override {
1366 LOG(INFO) << "Handshake error " << ex.what();
1367 handshakeError_ = true;
1368 handshakeTime = socket_->getHandshakeTime();
1372 void writeSuccess() noexcept override {
1377 size_t bytesWritten,
1378 const AsyncSocketException& ex) noexcept override {
1379 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
// Handshake-test client: starts sslConn() at construction with a zero
// timeout and the socket's default peer-verification setting.
1384 class SSLHandshakeClient : public SSLHandshakeBase {
1387 AsyncSSLSocket::UniquePtr socket,
1388 bool preverifyResult,
1389 bool verifyResult) :
1390 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1391 socket_->sslConn(this, std::chrono::milliseconds::zero());
// Handshake-test client that starts its handshake with a zero timeout and
// SSLVerifyPeerEnum::NO_VERIFY. The call line itself (presumably sslConn())
// is elided in this listing.
1395 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1397 SSLHandshakeClientNoVerify(
1398 AsyncSSLSocket::UniquePtr socket,
1399 bool preverifyResult,
1400 bool verifyResult) :
1401 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1404 std::chrono::milliseconds::zero(),
1405 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Handshake-test client that starts its handshake with a zero timeout and
// SSLVerifyPeerEnum::VERIFY. The call line itself (presumably sslConn()) is
// elided in this listing.
1409 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1411 SSLHandshakeClientDoVerify(
1412 AsyncSSLSocket::UniquePtr socket,
1413 bool preverifyResult,
1414 bool verifyResult) :
1415 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1418 std::chrono::milliseconds::zero(),
1419 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
// Handshake-test server: starts sslAccept() at construction with a zero
// timeout and the socket's default peer-verification setting.
1423 class SSLHandshakeServer : public SSLHandshakeBase {
1426 AsyncSSLSocket::UniquePtr socket,
1427 bool preverifyResult,
1429 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1430 socket_->sslAccept(this, std::chrono::milliseconds::zero());
// Handshake-test server that enables ClientHello parsing before
// sslAccept(). On success it captures the client, server and shared cipher
// lists plus the negotiated cipher name.
// NOTE(review): this handshakeSuc() override does not record handshakeTime,
// unlike SSLHandshakeBase::handshakeSuc().
1434 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1436 SSLHandshakeServerParseClientHello(
1437 AsyncSSLSocket::UniquePtr socket,
1438 bool preverifyResult,
1440 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1441 socket_->enableClientHelloParsing();
1442 socket_->sslAccept(this, std::chrono::milliseconds::zero());
1445 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1448 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1449 handshakeSuccess_ = true;
1450 sock->getSSLSharedCiphers(sharedCiphers_);
1451 sock->getSSLServerCiphers(serverCiphers_);
1452 sock->getSSLClientCiphers(clientCiphers_);
1453 chosenCipher_ = sock->getNegotiatedCipherName();
// Handshake-test server that starts its handshake with a zero timeout and
// SSLVerifyPeerEnum::NO_VERIFY. The call line itself (presumably
// sslAccept()) is elided in this listing.
1458 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1460 SSLHandshakeServerNoVerify(
1461 AsyncSSLSocket::UniquePtr socket,
1462 bool preverifyResult,
1464 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1467 std::chrono::milliseconds::zero(),
1468 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Handshake-test server that starts its handshake with a zero timeout and
// SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT. The call line itself
// (presumably sslAccept()) is elided in this listing.
1472 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1474 SSLHandshakeServerDoVerify(
1475 AsyncSSLSocket::UniquePtr socket,
1476 bool preverifyResult,
1478 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1481 std::chrono::milliseconds::zero(),
1482 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// Test watchdog: an internal AsyncTimeout scheduled on the given EventBase
// at construction. If it fires, the test is marked failed and the event
// loop is asked to stop.
1486 class EventBaseAborter : public AsyncTimeout {
1488 EventBaseAborter(EventBase* eventBase,
1491 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1492 , eventBase_(eventBase) {
1493 scheduleTimeout(timeoutMS);
1496 void timeoutExpired() noexcept override {
// Use the non-fatal ADD_FAILURE(): FAIL() returns from this function
// immediately, so terminateLoopSoon() would never run and the timed-out
// loop would not be stopped.
1497 ADD_FAILURE() << "test timed out";
1498 eventBase_->terminateLoopSoon();
1502 EventBase* eventBase_;
1505 } // namespace folly