2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
36 #include <sys/types.h>
37 #include <condition_variable>
49 // The destructors of all callback classes assert that the state is
50 // STATE_SUCCEEDED, for both positive and negative tests. The tests
51 // are responsible for setting the succeeded state properly before the
52 // destructors are called.
// Write callback used by the echo tests. writeSuccess() moves it to
// STATE_SUCCEEDED; the error path logs and records nBytesWritten in
// bytesWritten. The destructor expects STATE_SUCCEEDED.
54 class WriteCallbackBase :
55 public AsyncTransportWrapper::WriteCallback {
58 : state(STATE_WAITING)
60 , exception(AsyncSocketException::UNKNOWN, "none") {}
62 ~WriteCallbackBase() {
// Every test must leave this callback in STATE_SUCCEEDED before destruction.
63 EXPECT_EQ(STATE_SUCCEEDED, state);
67 const std::shared_ptr<AsyncSSLSocket> &socket) {
71 void writeSuccess() noexcept override {
72 std::cerr << "writeSuccess" << std::endl;
73 state = STATE_SUCCEEDED;
78 const AsyncSocketException& ex) noexcept override {
79 std::cerr << "writeError: bytesWritten " << nBytesWritten
80 << ", exception " << ex.what() << std::endl;
// Remember how many bytes reached the socket before the error.
83 this->bytesWritten = nBytesWritten;
88 std::shared_ptr<AsyncSSLSocket> socket_;
91 AsyncSocketException exception;
// Common base for the server-side read callbacks. It keeps the socket being
// read and a WriteCallbackBase (wcb_) that subclasses pass to write() when
// echoing data. The destructor expects STATE_SUCCEEDED.
94 class ReadCallbackBase :
95 public AsyncTransportWrapper::ReadCallback {
97 explicit ReadCallbackBase(WriteCallbackBase* wcb)
98 : wcb_(wcb), state(STATE_WAITING) {}
100 ~ReadCallbackBase() {
101 EXPECT_EQ(STATE_SUCCEEDED, state);
105 const std::shared_ptr<AsyncSSLSocket> &socket) {
109 void setState(StateEnum s) {
117 const AsyncSocketException& ex) noexcept override {
// Default read-error handling: log and mark the callback as failed.
118 std::cerr << "readError " << ex.what() << std::endl;
119 state = STATE_FAILED;
123 void readEOF() noexcept override {
124 std::cerr << "readEOF" << std::endl;
129 std::shared_ptr<AsyncSSLSocket> socket_;
130 WriteCallbackBase *wcb_;
// Echo callback. It reads into a 4096-byte malloc'd buffer and writes the
// same bytes back through wcb_. Each filled buffer is kept in `buffers`, and
// the callback marks itself STATE_SUCCEEDED once data has been read.
134 class ReadCallback : public ReadCallbackBase {
136 explicit ReadCallback(WriteCallbackBase *wcb)
137 : ReadCallbackBase(wcb)
// Presumably frees every buffer retained in `buffers` (loop body not shown),
// then the in-progress currentBuffer.
141 for (std::vector<Buffer>::iterator it = buffers.begin();
146 currentBuffer.free();
149 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
// Lazily allocate a fresh 4096-byte buffer whenever the previous one was
// handed off to `buffers`.
150 if (!currentBuffer.buffer) {
151 currentBuffer.allocate(4096);
153 *bufReturn = currentBuffer.buffer;
154 *lenReturn = currentBuffer.length;
157 void readDataAvailable(size_t len) noexcept override {
158 std::cerr << "readDataAvailable, len " << len << std::endl;
// Record only the bytes actually received.
160 currentBuffer.length = len;
162 wcb_->setSocket(socket_);
164 // Write back the same data.
165 socket_->write(wcb_, currentBuffer.buffer, len);
// Keep the filled buffer; reset() presumably clears currentBuffer so the
// next getReadBuffer() call allocates a new one.
167 buffers.push_back(currentBuffer);
168 currentBuffer.reset();
169 state = STATE_SUCCEEDED;
174 Buffer() : buffer(nullptr), length(0) {}
175 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
181 void allocate(size_t len) {
182 assert(buffer == nullptr);
183 this->buffer = static_cast<char*>(malloc(len));
195 std::vector<Buffer> buffers;
196 Buffer currentBuffer;
// Read callback for the negative test where reading must fail. It hands the
// socket a null buffer and treats readErr() as the expected outcome.
199 class ReadErrorCallback : public ReadCallbackBase {
201 explicit ReadErrorCallback(WriteCallbackBase *wcb)
202 : ReadCallbackBase(wcb) {}
204 // Return nullptr buffer to trigger readError()
205 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
206 *bufReturn = nullptr;
210 void readDataAvailable(size_t /* len */) noexcept override {
211 // This should never be called.
216 const AsyncSocketException& ex) noexcept override {
217 ReadCallbackBase::readErr(ex);
218 std::cerr << "ReadErrorCallback::readError" << std::endl;
// The base marked STATE_FAILED; override it, since an error is expected here.
219 setState(STATE_SUCCEEDED);
// Read callback for the negative test that expects end-of-file. Only
// readEOF() moves it to STATE_SUCCEEDED.
223 class ReadEOFCallback : public ReadCallbackBase {
225 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
// NOTE(review): the comment below mentions readError(), but this class only
// succeeds through readEOF(); confirm which path the null buffer should hit.
227 // Return nullptr buffer to trigger readError()
228 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
229 *bufReturn = nullptr;
233 void readDataAvailable(size_t /* len */) noexcept override {
234 // This should never be called.
238 void readEOF() noexcept override {
239 ReadCallbackBase::readEOF();
240 setState(STATE_SUCCEEDED);
// Read callback that provokes a write failure. When data arrives it closes
// the socket's fd directly, then echoes the data via wcb_. It succeeds only
// if that write left wcb_ in STATE_FAILED. Read errors are expected and
// ignored.
244 class WriteErrorCallback : public ReadCallback {
246 explicit WriteErrorCallback(WriteCallbackBase *wcb)
247 : ReadCallback(wcb) {}
249 void readDataAvailable(size_t len) noexcept override {
250 std::cerr << "readDataAvailable, len " << len << std::endl;
252 currentBuffer.length = len;
254 // close the socket before writing to trigger writeError().
255 ::close(socket_->getFd());
257 wcb_->setSocket(socket_);
259 // Write back the same data.
// Presumably wrapped because writing to a closed fd would otherwise abort
// under MSVC's invalid-parameter handler.
260 folly::test::msvcSuppressAbortOnInvalidParams([&] {
261 socket_->write(wcb_, currentBuffer.buffer, len);
// The write error must already have been reported by the time write()
// returns; otherwise this callback fails.
264 if (wcb_->state == STATE_FAILED) {
265 setState(STATE_SUCCEEDED);
267 state = STATE_FAILED;
270 buffers.push_back(currentBuffer);
271 currentBuffer.reset();
274 void readErr(const AsyncSocketException& ex) noexcept override {
275 std::cerr << "readError " << ex.what() << std::endl;
276 // do nothing since this is expected
// Read callback with no write callback (it passes nullptr as wcb). A read
// error fails it and readEOF() marks it succeeded. It also holds a plain
// AsyncSocket in tcpSocket_.
280 class EmptyReadCallback : public ReadCallback {
282 explicit EmptyReadCallback()
283 : ReadCallback(nullptr) {}
285 void readErr(const AsyncSocketException& ex) noexcept override {
286 std::cerr << "readError " << ex.what() << std::endl;
287 state = STATE_FAILED;
293 void readEOF() noexcept override {
294 std::cerr << "readEOF" << std::endl;
298 state = STATE_SUCCEEDED;
301 std::shared_ptr<AsyncSocket> tcpSocket_;
// Server-side handshake callback. expect_ decides whether a successful or a
// failed handshake counts as success. On success it installs rcb_ as the
// socket's read callback. Both outcomes update `state` under mutex_.
// waitForHandshake() blocks until `state` leaves STATE_WAITING, so it relies
// on cv_ being notified when the state changes.
304 class HandshakeCallback :
305 public AsyncSSLSocket::HandshakeCB {
312 explicit HandshakeCallback(ReadCallbackBase *rcb,
313 ExpectType expect = EXPECT_SUCCESS):
314 state(STATE_WAITING),
319 const std::shared_ptr<AsyncSSLSocket> &socket) {
323 void setState(StateEnum s) {
328 // Functions inherited from AsyncSSLSocket::HandshakeCB
329 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
330 std::lock_guard<std::mutex> g(mutex_);
332 EXPECT_EQ(sock, socket_.get());
// NOTE(review): this is the handshake-success path; the message says
// "connectionAccepted".
333 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
334 rcb_->setSocket(socket_);
335 sock->setReadCB(rcb_);
336 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
338 void handshakeErr(AsyncSSLSocket* /* sock */,
339 const AsyncSocketException& ex) noexcept override {
340 std::lock_guard<std::mutex> g(mutex_);
342 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
343 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
344 if (expect_ == EXPECT_ERROR) {
345 // rcb will never be invoked
346 rcb_->setState(STATE_SUCCEEDED);
// Keep the error text for the test to inspect.
348 errorString_ = ex.what();
351 void waitForHandshake() {
352 std::unique_lock<std::mutex> lock(mutex_);
353 cv_.wait(lock, [this] { return state != STATE_WAITING; });
356 ~HandshakeCallback() {
357 EXPECT_EQ(STATE_SUCCEEDED, state);
362 state = STATE_SUCCEEDED;
365 std::shared_ptr<AsyncSSLSocket> getSocket() {
370 std::shared_ptr<AsyncSSLSocket> socket_;
371 ReadCallbackBase *rcb_;
374 std::condition_variable cv_;
375 std::string errorString_;
// Base accept callback for the test servers. It wraps each accepted fd in a
// server-side AsyncSSLSocket on base_ using ctx_, then hands that socket to
// the subclass via connAccepted(). The destructor expects STATE_SUCCEEDED.
378 class SSLServerAcceptCallbackBase:
379 public folly::AsyncServerSocket::AcceptCallback {
381 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
382 state(STATE_WAITING), hcb_(hcb) {}
384 ~SSLServerAcceptCallbackBase() {
385 EXPECT_EQ(STATE_SUCCEEDED, state);
388 void acceptError(const std::exception& ex) noexcept override {
389 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
390 << ex.what() << std::endl;
391 state = STATE_FAILED;
394 void connectionAccepted(
395 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
397 socket_->detachEventBase();
399 printf("Connection accepted\n");
401 // Create a AsyncSSLSocket object with the fd. The socket should be
402 // added to the event base and in the state of accepting SSL connection.
403 socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
404 } catch (const std::exception &e) {
// LOG streams do not expand printf-style placeholders. The old message printed
// a literal "%s" and appended e.what() directly to the fd number.
405 LOG(ERROR) << "Exception " << e.what() << " caught while creating an "
406 "AsyncSSLSocket object with socket " << fd;
412 connAccepted(socket_);
// Subclass hook: receives each newly created server-side SSL socket.
415 virtual void connAccepted(
416 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
419 socket_->detachEventBase();
423 HandshakeCallback *hcb_;
424 std::shared_ptr<folly::SSLContext> ctx_;
425 std::shared_ptr<AsyncSSLSocket> socket_;
426 folly::EventBase* base_;
// Accept callback that starts a server-side SSL handshake (sslAccept) on each
// accepted socket, passing timeout_ as the handshake timeout. When a timeout
// was configured, the destructor expects the handshake to have failed.
429 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
433 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
434 uint32_t timeout = 0):
435 SSLServerAcceptCallbackBase(hcb),
438 virtual ~SSLServerAcceptCallback() {
440 // if we set a timeout, we expect failure
441 EXPECT_EQ(hcb_->state, STATE_FAILED);
442 hcb_->setState(STATE_SUCCEEDED);
446 // Override of SSLServerAcceptCallbackBase::connAccepted
448 const std::shared_ptr<folly::AsyncSSLSocket> &s)
450 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
451 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
453 hcb_->setSocket(sock);
454 sock->sslAccept(hcb_, timeout_);
// sslAccept() should leave the socket mid-handshake.
455 EXPECT_EQ(sock->getSSLState(),
456 AsyncSSLSocket::STATE_ACCEPTING);
458 state = STATE_SUCCEEDED;
// Accept callback that checks TCP_NODELAY on the accepted fd, clears it and
// re-reads it, then defers to SSLServerAcceptCallback::connAccepted.
462 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
464 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
465 SSLServerAcceptCallback(hcb) {}
467 // Override of SSLServerAcceptCallbackBase::connAccepted
469 const std::shared_ptr<folly::AsyncSSLSocket> &s)
472 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
474 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
476 int fd = sock->getFd();
480 // The accepted connection should already have TCP_NODELAY set
482 socklen_t valueLength = sizeof(value);
483 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
489 // Unset the TCP_NODELAY option.
491 socklen_t valueLength = sizeof(value);
492 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
// Read the option back after setting it.
495 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
499 SSLServerAcceptCallback::connAccepted(sock);
// Like SSLServerAcceptCallback, but the handshake may pause for an
// asynchronous session-cache lookup. After sslAccept() the socket may
// therefore be in either STATE_ACCEPTING or STATE_CACHE_LOOKUP.
503 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
505 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
506 uint32_t timeout = 0):
507 SSLServerAcceptCallback(hcb, timeout) {}
509 // Override of SSLServerAcceptCallbackBase::connAccepted
511 const std::shared_ptr<folly::AsyncSSLSocket> &s)
513 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
// NOTE(review): log text reuses the parent class's name.
515 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
517 hcb_->setSocket(sock);
518 sock->sslAccept(hcb_, timeout_);
519 ASSERT_TRUE((sock->getSSLState() ==
520 AsyncSSLSocket::STATE_ACCEPTING) ||
521 (sock->getSSLState() ==
522 AsyncSSLSocket::STATE_CACHE_LOOKUP));
524 state = STATE_SUCCEEDED;
// Negative test: calls sslAccept() twice on the same socket. The second call
// must put the socket into STATE_ERROR and fail both handshake callbacks.
// All states are then forced to STATE_SUCCEEDED so the destructors' checks
// pass.
529 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
531 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
532 SSLServerAcceptCallbackBase(hcb) {}
534 // Override of SSLServerAcceptCallbackBase::connAccepted
536 const std::shared_ptr<folly::AsyncSSLSocket> &s)
538 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
540 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
542 // The first call to sslAccept() should succeed.
543 hcb_->setSocket(sock);
544 sock->sslAccept(hcb_);
545 EXPECT_EQ(sock->getSSLState(),
546 AsyncSSLSocket::STATE_ACCEPTING);
548 // The second call to sslAccept() should fail.
549 HandshakeCallback callback2(hcb_->rcb_);
550 callback2.setSocket(sock);
551 sock->sslAccept(&callback2);
552 EXPECT_EQ(sock->getSSLState(),
553 AsyncSSLSocket::STATE_ERROR);
555 // Both callbacks should be in the error state.
556 EXPECT_EQ(hcb_->state, STATE_FAILED);
557 EXPECT_EQ(callback2.state, STATE_FAILED);
559 state = STATE_SUCCEEDED;
560 hcb_->setState(STATE_SUCCEEDED);
561 callback2.setState(STATE_SUCCEEDED);
// Accept callback for the handshake-timeout test. It does not start the
// handshake immediately: it schedules sslAccept() on the socket's event base
// after a delay, by which time the client is expected to have closed.
565 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
567 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
568 SSLServerAcceptCallbackBase(hcb) {}
570 // Override of SSLServerAcceptCallbackBase::connAccepted
572 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// Log under this class's own name. The message was copied verbatim from
// HandshakeErrorCallback, so test logs pointed at the wrong callback.
574 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl;
576 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
578 hcb_->setSocket(sock);
579 sock->getEventBase()->tryRunAfterDelay([=] {
580 std::cerr << "Delayed SSL accept, client will have close by now"
582 // SSL accept will fail
585 AsyncSSLSocket::STATE_UNINIT);
586 hcb_->socket_->sslAccept(hcb_);
587 // This registers for an event
590 AsyncSSLSocket::STATE_ACCEPTING);
592 state = STATE_SUCCEEDED;
// Server side of the client connect-timeout test. It never starts a
// handshake; it closes the accepted socket after a 100 ms delay
// (tryRunAfterDelay), so the client times out waiting for the handshake.
597 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
599 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
600 // We don't care if we get invoked or not.
601 // The client may time out and give up before connAccepted() is even
603 state = STATE_SUCCEEDED;
606 // Override of SSLServerAcceptCallbackBase::connAccepted
608 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
609 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
611 // Just wait a while before closing the socket, so the client
612 // will time out waiting for the handshake to complete.
// `s` is captured by value (shared_ptr copy), keeping the socket alive until
// the delayed close runs.
613 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
// Test SSL server. It owns an AsyncServerSocket driven by evb_ on a separate
// pthread whose entry point is Main(); after the loop exits, Main() detaches
// the accept callback. The destructor calls stopAccepting() from inside the
// event-base thread and then joins the server thread.
617 class TestSSLServer {
620 std::shared_ptr<folly::SSLContext> ctx_;
621 SSLServerAcceptCallbackBase *acb_;
622 std::shared_ptr<folly::AsyncServerSocket> socket_;
623 folly::SocketAddress address_;
626 static void *Main(void *ctx) {
627 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
629 self->acb_->detach();
630 std::cerr << "Server thread exited event loop" << std::endl;
635 // Create a TestSSLServer.
636 // This immediately starts listening on the given port.
// NOTE(review): the constructor takes no port argument; the bound address is
// presumably exposed through getAddress() (address_).
637 explicit TestSSLServer(
638 SSLServerAcceptCallbackBase* acb,
639 bool enableTFO = false);
// stopAccepting() must run on the server's event-base thread.
643 evb_.runInEventBaseThread([&](){
644 socket_->stopAccepting();
646 std::cerr << "Waiting for server thread to exit" << std::endl;
647 pthread_join(thread_, nullptr);
650 EventBase &getEventBase() { return evb_; }
652 const folly::SocketAddress& getAddress() const {
// TestSSLServer variant that disables OpenSSL's internal session cache and
// installs getSessionCallback, which simulates an asynchronous external
// cache. Depending on the parity of asyncCallbacks_, it either returns a miss
// or reports SSL_SESSION_CB_WOULD_BLOCK and schedules restartSSLAccept()
// after lookupDelay_. That path exists only when built with
// SSL_ERROR_WANT_SESS_CACHE_LOOKUP.
657 class TestSSLAsyncCacheServer : public TestSSLServer {
659 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
660 int lookupDelay = 100) :
662 SSL_CTX *sslCtx = ctx_->getSSLCtx();
663 SSL_CTX_sess_set_get_cb(sslCtx,
664 TestSSLAsyncCacheServer::getSessionCallback);
// Server-side caching only, with no internal store, so every lookup goes
// through getSessionCallback.
665 SSL_CTX_set_session_cache_mode(
666 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
// NOTE(review): lookupDelay_ is static, so this affects all instances.
669 lookupDelay_ = lookupDelay;
672 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
673 uint32_t getAsyncLookups() const { return asyncLookups_; }
676 static uint32_t asyncCallbacks_;
677 static uint32_t asyncLookups_;
678 static uint32_t lookupDelay_;
680 static SSL_SESSION* getSessionCallback(SSL* ssl,
681 unsigned char* /* sess_id */,
686 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
687 if (!SSL_want_sess_cache_lookup(ssl)) {
688 // libssl.so mismatch
689 std::cerr << "no async support" << std::endl;
693 AsyncSSLSocket *sslSocket =
694 AsyncSSLSocket::getFromSSL(ssl);
695 assert(sslSocket != nullptr);
696 // Simulate an async cache by delaying the miss by lookupDelay_ ms
697 if (asyncCallbacks_ % 2 == 0) {
698 // This socket is already blocked on lookup, return miss
699 std::cerr << "returning miss" << std::endl;
701 // fresh meat - block it
702 std::cerr << "async lookup" << std::endl;
703 sslSocket->getEventBase()->tryRunAfterDelay(
704 std::bind(&AsyncSSLSocket::restartSSLAccept,
705 sslSocket), lookupDelay_);
706 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
// Out-of-line test helpers (declarations only; defined elsewhere).
// getfds() presumably fills fds with a connected descriptor pair; TODO confirm.
714 void getfds(int fds[2]);
// Presumably builds client/server SSL contexts; TODO confirm.
717 std::shared_ptr<folly::SSLContext> clientCtx,
718 std::shared_ptr<folly::SSLContext> serverCtx);
// Presumably creates a connected client/server AsyncSSLSocket pair on
// eventBase and returns them through the out-params; TODO confirm.
721 EventBase* eventBase,
722 AsyncSSLSocket::UniquePtr* clientSock,
723 AsyncSSLSocket::UniquePtr* serverSock);
// Client half of the blocking-write test. After the TLS handshake it issues a
// single writev() of iovCount_ iovecs, all pointing into buf_ at different
// offsets and lengths. The server side deliberately delays reading, so this
// write initially blocks. The iovecs are exposed through getIovec() and
// getIovecCount().
725 class BlockingWriteClient :
726 private AsyncSSLSocket::HandshakeCB,
727 private AsyncTransportWrapper::WriteCallback {
729 explicit BlockingWriteClient(
730 AsyncSSLSocket::UniquePtr socket)
731 : socket_(std::move(socket)),
735 buf_.reset(new uint8_t[bufLen_]);
// buf_ is a std::unique_ptr<uint8_t[]>, so sizeof(buf_) is the size of the
// smart-pointer object, not the array length. The old bound stopped the fill
// loop after a handful of bytes; iterate over all bufLen_ bytes instead.
736 for (uint32_t n = 0; n < bufLen_; ++n) {
741 iov_.reset(new struct iovec[iovCount_]);
// Each iovec starts n bytes into buf_, with a length of either n % bufLen_
// or bufLen_ - (n % bufLen_).
742 for (uint32_t n = 0; n < iovCount_; ++n) {
743 iov_[n].iov_base = buf_.get() + n;
745 iov_[n].iov_len = n % bufLen_;
747 iov_[n].iov_len = bufLen_ - (n % bufLen_);
751 socket_->sslConn(this, 100);
754 struct iovec* getIovec() const {
757 uint32_t getIovecCount() const {
762 void handshakeSuc(AsyncSSLSocket*) noexcept override {
763 socket_->writev(this, iov_.get(), iovCount_);
767 const AsyncSocketException& ex) noexcept override {
768 ADD_FAILURE() << "client handshake error: " << ex.what();
770 void writeSuccess() noexcept override {
775 const AsyncSocketException& ex) noexcept override {
776 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
780 AsyncSSLSocket::UniquePtr socket_;
783 std::unique_ptr<uint8_t[]> buf_;
784 std::unique_ptr<struct iovec[]> iov_;
// Server half of the blocking-write test. It accepts the handshake and waits
// 10 ms before reading, so the client's writes initially block. It then
// reads into one large buffer in short bursts: after each chunk it removes
// its read callback and re-arms it 2 ms later. checkBuffer() compares the
// received bytes against a list of iovecs (presumably the client's).
787 class BlockingWriteServer :
788 private AsyncSSLSocket::HandshakeCB,
789 private AsyncTransportWrapper::ReadCallback {
791 explicit BlockingWriteServer(
792 AsyncSSLSocket::UniquePtr socket)
793 : socket_(std::move(socket)),
794 bufSize_(2500 * 2000),
796 buf_.reset(new uint8_t[bufSize_]);
797 socket_->sslAccept(this, 100);
// Walks the iovecs in order against the received bytes. It reports the first
// mismatch or short read, and flags any bytes received beyond the iovecs.
800 void checkBuffer(struct iovec* iov, uint32_t count) const {
802 for (uint32_t n = 0; n < count; ++n) {
803 size_t bytesLeft = bytesRead_ - idx;
804 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
805 std::min(iov[n].iov_len, bytesLeft));
807 FAIL() << "buffer mismatch at iovec " << n << "/" << count
811 if (iov[n].iov_len > bytesLeft) {
812 FAIL() << "server did not read enough data: "
813 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
814 << " in iovec " << n << "/" << count;
817 idx += iov[n].iov_len;
819 if (idx != bytesRead_) {
820 ADD_FAILURE() << "server read extra data: " << bytesRead_
821 << " bytes read; expected " << idx;
826 void handshakeSuc(AsyncSSLSocket*) noexcept override {
827 // Wait 10ms before reading, so the client's writes will initially block.
828 socket_->getEventBase()->tryRunAfterDelay(
829 [this] { socket_->setReadCB(this); }, 10);
833 const AsyncSocketException& ex) noexcept override {
834 ADD_FAILURE() << "server handshake error: " << ex.what();
// Always append to the unread tail of buf_.
836 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
837 *bufReturn = buf_.get() + bytesRead_;
838 *lenReturn = bufSize_ - bytesRead_;
840 void readDataAvailable(size_t len) noexcept override {
// Pause reading and resume 2 ms later, so the client keeps hitting a full
// socket buffer.
842 socket_->setReadCB(nullptr);
843 socket_->getEventBase()->tryRunAfterDelay(
844 [this] { socket_->setReadCB(this); }, 2);
846 void readEOF() noexcept override {
850 const AsyncSocketException& ex) noexcept override {
851 ADD_FAILURE() << "server read error: " << ex.what();
854 AsyncSSLSocket::UniquePtr socket_;
857 std::unique_ptr<uint8_t[]> buf_;
861 private AsyncSSLSocket::HandshakeCB,
862 private AsyncTransportWrapper::WriteCallback {
// Client that starts the TLS handshake on construction. On success it records
// the negotiated next protocol (nextProto, nextProtoLength, protocolType) for
// the test to inspect. Handshake and write errors fail the test.
865 AsyncSSLSocket::UniquePtr socket)
866 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
867 socket_->sslConn(this);
870 const unsigned char* nextProto;
871 unsigned nextProtoLength;
872 SSLContext::NextProtocolType protocolType;
875 void handshakeSuc(AsyncSSLSocket*) noexcept override {
876 socket_->getSelectedNextProtocol(
877 &nextProto, &nextProtoLength, &protocolType);
881 const AsyncSocketException& ex) noexcept override {
882 ADD_FAILURE() << "client handshake error: " << ex.what();
884 void writeSuccess() noexcept override {
889 const AsyncSocketException& ex) noexcept override {
890 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
894 AsyncSSLSocket::UniquePtr socket_;
898 private AsyncSSLSocket::HandshakeCB,
899 private AsyncTransportWrapper::ReadCallback {
// Server that accepts the TLS handshake on construction. On success it
// records the negotiated next protocol for the test to inspect. Handshake and
// read errors fail the test.
901 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
902 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
903 socket_->sslAccept(this);
906 const unsigned char* nextProto;
907 unsigned nextProtoLength;
908 SSLContext::NextProtocolType protocolType;
911 void handshakeSuc(AsyncSSLSocket*) noexcept override {
912 socket_->getSelectedNextProtocol(
913 &nextProto, &nextProtoLength, &protocolType);
917 const AsyncSocketException& ex) noexcept override {
918 ADD_FAILURE() << "server handshake error: " << ex.what();
920 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
923 void readDataAvailable(size_t /* len */) noexcept override {}
924 void readEOF() noexcept override {
928 const AsyncSocketException& ex) noexcept override {
929 ADD_FAILURE() << "server read error: " << ex.what();
932 AsyncSSLSocket::UniquePtr socket_;
// Server that accepts a handshake and then reads into a 128-byte scratch
// buffer. It expects a read error: an SSLException whose message contains the
// CLIENT_RENEGOTIATION error text, which sets renegotiationError_.
935 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
936 public AsyncTransportWrapper::ReadCallback {
938 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
939 : socket_(std::move(socket)) {
940 socket_->sslAccept(this);
943 ~RenegotiatingServer() {
// Unregister before socket_ is destroyed along with this object.
944 socket_->setReadCB(nullptr);
947 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
948 LOG(INFO) << "Renegotiating server handshake success";
949 socket_->setReadCB(this);
953 const AsyncSocketException& ex) noexcept override {
954 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
956 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
957 *lenReturn = sizeof(buf);
960 void readDataAvailable(size_t /* len */) noexcept override {}
961 void readEOF() noexcept override {}
962 void readErr(const AsyncSocketException& ex) noexcept override {
963 LOG(INFO) << "server got read error " << ex.what();
// The error must be an SSLException carrying the renegotiation error text.
964 auto exPtr = dynamic_cast<const SSLException*>(&ex);
965 ASSERT_NE(nullptr, exPtr);
966 std::string exStr(ex.what());
967 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
968 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
969 renegotiationError_ = true;
972 AsyncSSLSocket::UniquePtr socket_;
973 unsigned char buf[128];
974 bool renegotiationError_{false};
// SNI (server name indication) test helpers; compiled only when
// OPENSSL_NO_TLSEXT is not defined.
977 #ifndef OPENSSL_NO_TLSEXT
979 private AsyncSSLSocket::HandshakeCB,
980 private AsyncTransportWrapper::WriteCallback {
// Client that starts the TLS handshake on construction. On success it stores
// the socket's isServerNameMatch() result in serverNameMatch.
983 AsyncSSLSocket::UniquePtr socket)
984 : serverNameMatch(false), socket_(std::move(socket)) {
985 socket_->sslConn(this);
988 bool serverNameMatch;
991 void handshakeSuc(AsyncSSLSocket*) noexcept override {
992 serverNameMatch = socket_->isServerNameMatch();
996 const AsyncSocketException& ex) noexcept override {
997 ADD_FAILURE() << "client handshake error: " << ex.what();
999 void writeSuccess() noexcept override {
1003 size_t bytesWritten,
1004 const AsyncSocketException& ex) noexcept override {
1005 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1009 AsyncSSLSocket::UniquePtr socket_;
1013 private AsyncSSLSocket::HandshakeCB,
1014 private AsyncTransportWrapper::ReadCallback {
// Server that registers serverNameCallback on ctx and then accepts the
// handshake. If the client's SNI host name case-insensitively equals
// expectedServerName_, it switches the socket to sniCtx_ and sets
// serverNameMatch.
1017 AsyncSSLSocket::UniquePtr socket,
1018 const std::shared_ptr<folly::SSLContext>& ctx,
1019 const std::shared_ptr<folly::SSLContext>& sniCtx,
1020 const std::string& expectedServerName)
1021 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1022 expectedServerName_(expectedServerName) {
// NOTE(review): `this` is bound into ctx's callback; ctx must not invoke it
// after this object is destroyed.
1023 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1024 std::placeholders::_1));
1025 socket_->sslAccept(this);
1028 bool serverNameMatch;
1031 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1034 const AsyncSocketException& ex) noexcept override {
1035 ADD_FAILURE() << "server handshake error: " << ex.what();
1037 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1040 void readDataAvailable(size_t /* len */) noexcept override {}
1041 void readEOF() noexcept override {
1045 const AsyncSocketException& ex) noexcept override {
1046 ADD_FAILURE() << "server read error: " << ex.what();
1049 folly::SSLContext::ServerNameCallbackResult
1050 serverNameCallback(SSL *ssl) {
1051 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1054 !strcasecmp(expectedServerName_.c_str(), sn)) {
1055 AsyncSSLSocket *sslSocket =
1056 AsyncSSLSocket::getFromSSL(ssl);
1057 sslSocket->switchServerSSLContext(sniCtx_);
1058 serverNameMatch = true;
1059 return folly::SSLContext::SERVER_NAME_FOUND;
1061 serverNameMatch = false;
1062 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1066 AsyncSSLSocket::UniquePtr socket_;
1067 std::shared_ptr<folly::SSLContext> sniCtx_;
1068 std::string expectedServerName_;
// Echo client used by the session-cache tests. It connects with SSL (reusing
// session_ if one was saved), writes buf_ (filled with 'a'), and reads the
// echo into readbuf_ in capped chunks. It closes once sizeof(buf_) bytes have
// arrived. hit_/miss_/errors_ are presumably session-reuse and error
// counters; TODO confirm against the lines that update them.
1072 class SSLClient : public AsyncSocket::ConnectCallback,
1073 public AsyncTransportWrapper::WriteCallback,
1074 public AsyncTransportWrapper::ReadCallback
1077 EventBase *eventBase_;
1078 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1079 SSL_SESSION *session_;
1080 std::shared_ptr<folly::SSLContext> ctx_;
1082 folly::SocketAddress address_;
1086 uint32_t bytesRead_;
1090 uint32_t writeAfterConnectErrors_;
1092 // These settings test that we eventually drain the
1093 // socket, even if the maxReadsPerEvent_ is hit during
1094 // an event loop iteration.
1095 static constexpr size_t kMaxReadsPerEvent = 2;
1096 // 2 event loop iterations
1097 static constexpr size_t kMaxReadBufferSz =
1098 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1101 SSLClient(EventBase *eventBase,
1102 const folly::SocketAddress& address,
1104 uint32_t timeout = 0)
1105 : eventBase_(eventBase),
1107 requests_(requests),
1114 writeAfterConnectErrors_(0) {
// Session tickets are disabled via SSL_OP_NO_TICKET.
1115 ctx_.reset(new folly::SSLContext());
1116 ctx_->setOptions(SSL_OP_NO_TICKET);
1117 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1118 memset(buf_, 'a', sizeof(buf_));
1123 SSL_SESSION_free(session_);
1126 EXPECT_EQ(bytesRead_, sizeof(buf_));
1130 uint32_t getHit() const { return hit_; }
1132 uint32_t getMiss() const { return miss_; }
1134 uint32_t getErrors() const { return errors_; }
1136 uint32_t getWriteAfterConnectErrors() const {
1137 return writeAfterConnectErrors_;
1140 void connect(bool writeNow = false) {
1141 sslSocket_ = AsyncSSLSocket::newSocket(
1143 if (session_ != nullptr) {
1144 sslSocket_->setSSLSession(session_);
1147 sslSocket_->connect(this, address_, timeout_);
1148 if (sslSocket_ && writeNow) {
1149 // write some junk, used in an error test
1150 sslSocket_->write(this, buf_, sizeof(buf_));
1154 void connectSuccess() noexcept override {
1155 std::cerr << "client SSL socket connected" << std::endl;
1156 if (sslSocket_->getSSLSessionReused()) {
// Replace any previously saved session with the one from this connection.
1160 if (session_ != nullptr) {
1161 SSL_SESSION_free(session_);
1163 session_ = sslSocket_->getSSLSession();
1167 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1168 sslSocket_->write(this, buf_, sizeof(buf_));
1169 sslSocket_->setReadCB(this);
1170 memset(readbuf_, 'b', sizeof(readbuf_));
1175 const AsyncSocketException& ex) noexcept override {
1176 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1181 void writeSuccess() noexcept override {
1182 std::cerr << "client write success" << std::endl;
1185 void writeErr(size_t /* bytesWritten */,
1186 const AsyncSocketException& ex) noexcept override {
1187 std::cerr << "client writeError: " << ex.what() << std::endl;
1189 writeAfterConnectErrors_++;
// Cap each read at kMaxReadBufferSz so draining the echo needs several
// reads, exercising the maxReadsPerEvent limit.
1193 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1194 *bufReturn = readbuf_ + bytesRead_;
1195 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1198 void readEOF() noexcept override {
1199 std::cerr << "client readEOF" << std::endl;
1203 const AsyncSocketException& ex) noexcept override {
1204 std::cerr << "client readError: " << ex.what() << std::endl;
1207 void readDataAvailable(size_t len) noexcept override {
1208 std::cerr << "client read data: " << len << std::endl;
// Once the full payload is echoed, verify it and close the connection.
1210 if (bytesRead_ == sizeof(buf_)) {
1211 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1212 sslSocket_->closeNow();
1214 if (requests_ != 0) {
// Base for the certificate-verification handshake tests. It records whether
// handshakeVer, handshakeSuc or handshakeErr fired, plus the handshake time.
// handshakeVer checks OpenSSL's preverify result against preverifyResult_ and
// returns verifyResult_ as the verification decision. moveSocket() && lets a
// test take the socket back from an rvalue helper.
1222 class SSLHandshakeBase :
1223 public AsyncSSLSocket::HandshakeCB,
1224 private AsyncTransportWrapper::WriteCallback {
1226 explicit SSLHandshakeBase(
1227 AsyncSSLSocket::UniquePtr socket,
1228 bool preverifyResult,
1229 bool verifyResult) :
1230 handshakeVerify_(false),
1231 handshakeSuccess_(false),
1232 handshakeError_(false),
1233 socket_(std::move(socket)),
1234 preverifyResult_(preverifyResult),
1235 verifyResult_(verifyResult) {
1238 AsyncSSLSocket::UniquePtr moveSocket() && {
1239 return std::move(socket_);
1242 bool handshakeVerify_;
1243 bool handshakeSuccess_;
1244 bool handshakeError_;
1245 std::chrono::nanoseconds handshakeTime;
1248 AsyncSSLSocket::UniquePtr socket_;
1249 bool preverifyResult_;
1252 // HandshakeCallback
1253 bool handshakeVer(AsyncSSLSocket* /* sock */,
1255 X509_STORE_CTX* /* ctx */) noexcept override {
1256 handshakeVerify_ = true;
1258 EXPECT_EQ(preverifyResult_, preverifyOk);
1259 return verifyResult_;
1262 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1263 LOG(INFO) << "Handshake success";
1264 handshakeSuccess_ = true;
1265 handshakeTime = socket_->getHandshakeTime();
1270 const AsyncSocketException& ex) noexcept override {
1271 LOG(INFO) << "Handshake error " << ex.what();
1272 handshakeError_ = true;
1273 handshakeTime = socket_->getHandshakeTime();
1277 void writeSuccess() noexcept override {
1282 size_t bytesWritten,
1283 const AsyncSocketException& ex) noexcept override {
1284 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
// Client variant: starts sslConn() immediately with a timeout argument of 0
// and the default verify mode.
1289 class SSLHandshakeClient : public SSLHandshakeBase {
1292 AsyncSSLSocket::UniquePtr socket,
1293 bool preverifyResult,
1294 bool verifyResult) :
1295 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1296 socket_->sslConn(this, 0);
// Client variant: starts sslConn() with peer verification set to NO_VERIFY.
1300 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1302 SSLHandshakeClientNoVerify(
1303 AsyncSSLSocket::UniquePtr socket,
1304 bool preverifyResult,
1305 bool verifyResult) :
1306 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1307 socket_->sslConn(this, 0,
1308 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Client variant: starts sslConn() with peer verification set to VERIFY.
1312 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1314 SSLHandshakeClientDoVerify(
1315 AsyncSSLSocket::UniquePtr socket,
1316 bool preverifyResult,
1317 bool verifyResult) :
1318 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1319 socket_->sslConn(this, 0,
1320 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
// Server variant: starts sslAccept() immediately with a timeout argument of 0
// and the default verify mode.
1324 class SSLHandshakeServer : public SSLHandshakeBase {
1327 AsyncSSLSocket::UniquePtr socket,
1328 bool preverifyResult,
1330 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1331 socket_->sslAccept(this, 0);
// Server variant that enables ClientHello parsing before sslAccept(). On
// success it records the shared, server and client cipher lists and the
// negotiated cipher name. Unlike the base class, this override does not set
// handshakeTime.
1335 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1337 SSLHandshakeServerParseClientHello(
1338 AsyncSSLSocket::UniquePtr socket,
1339 bool preverifyResult,
1341 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1342 socket_->enableClientHelloParsing();
1343 socket_->sslAccept(this, 0);
1346 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1349 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1350 handshakeSuccess_ = true;
1351 sock->getSSLSharedCiphers(sharedCiphers_);
1352 sock->getSSLServerCiphers(serverCiphers_);
1353 sock->getSSLClientCiphers(clientCiphers_);
1354 chosenCipher_ = sock->getNegotiatedCipherName();
// Server variant: starts sslAccept() with peer verification set to NO_VERIFY.
1359 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1361 SSLHandshakeServerNoVerify(
1362 AsyncSSLSocket::UniquePtr socket,
1363 bool preverifyResult,
1365 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1366 socket_->sslAccept(this, 0,
1367 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Server variant: starts sslAccept() with verification mode
// VERIFY_REQ_CLIENT_CERT.
1371 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1373 SSLHandshakeServerDoVerify(
1374 AsyncSSLSocket::UniquePtr socket,
1375 bool preverifyResult,
1377 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1378 socket_->sslAccept(this, 0,
1379 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// Watchdog for tests that run an EventBase loop. It schedules itself for
// timeoutMS on construction. If it fires, it records a test failure and asks
// the loop to terminate.
1383 class EventBaseAborter : public AsyncTimeout {
1385 EventBaseAborter(EventBase* eventBase,
1388 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1389 , eventBase_(eventBase) {
1390 scheduleTimeout(timeoutMS);
1393 void timeoutExpired() noexcept override {
// Use ADD_FAILURE() rather than FAIL(). FAIL() returns from this function
// immediately, so terminateLoopSoon() below was unreachable and a timed-out
// test never stopped its event loop from here.
1394 ADD_FAILURE() << "test timed out";
1395 eventBase_->terminateLoopSoon();
1399 EventBase* eventBase_;