2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/portability/GTest.h>
31 #include <folly/portability/Sockets.h>
32 #include <folly/portability/Unistd.h>
35 #include <sys/types.h>
36 #include <condition_variable>
48 // The destructors of all callback classes assert that the state is
49 // STATE_SUCCEEDED, for both positive and negative tests. The tests
50 // are responsible for setting the succeeded state properly before the
51 // destructors are called.
// Write callback that records how an asynchronous write finished.
// `state` starts as STATE_WAITING and writeSuccess() sets STATE_SUCCEEDED;
// writeErr() records the byte count and detaches socket_ from its EventBase.
53 class WriteCallbackBase :
54 public AsyncTransportWrapper::WriteCallback {
57 : state(STATE_WAITING)
59 , exception(AsyncSocketException::UNKNOWN, "none") {}
// Anything other than STATE_SUCCEEDED at destruction is reported as a failure.
61 ~WriteCallbackBase() {
62 EXPECT_EQ(STATE_SUCCEEDED, state);
66 const std::shared_ptr<AsyncSSLSocket> &socket) {
70 void writeSuccess() noexcept override {
71 std::cerr << "writeSuccess" << std::endl;
72 state = STATE_SUCCEEDED;
77 const AsyncSocketException& ex) noexcept override {
78 std::cerr << "writeError: bytesWritten " << bytesWritten
79 << ", exception " << ex.what() << std::endl;
// Keep the partial byte count so it can be inspected after the failure.
82 this->bytesWritten = bytesWritten;
// Release the socket from its EventBase once the write has failed.
85 socket_->detachEventBase();
// Socket used by writeErr(); callers hand it in via setSocket().
88 std::shared_ptr<AsyncSSLSocket> socket_;
// Last error seen; initialized to an UNKNOWN/"none" placeholder.
91 AsyncSocketException exception;
// Base read callback for the read-side test helpers. It holds the write
// callback used for echoing (wcb_) and tracks the read outcome in `state`,
// which starts as STATE_WAITING and must be STATE_SUCCEEDED at destruction.
94 class ReadCallbackBase :
95 public AsyncTransportWrapper::ReadCallback {
97 explicit ReadCallbackBase(WriteCallbackBase* wcb)
98 : wcb_(wcb), state(STATE_WAITING) {}
100 ~ReadCallbackBase() {
101 EXPECT_EQ(STATE_SUCCEEDED, state);
105 const std::shared_ptr<AsyncSSLSocket> &socket) {
109 void setState(StateEnum s) {
117 const AsyncSocketException& ex) noexcept override {
// Default error handling: mark the read as failed and detach the socket
// from its EventBase. Subclasses that expect an error override this.
118 std::cerr << "readError " << ex.what() << std::endl;
119 state = STATE_FAILED;
121 socket_->detachEventBase();
// Default EOF handling: detach the socket from its EventBase. Subclasses
// that expect EOF mark STATE_SUCCEEDED themselves.
124 void readEOF() noexcept override {
125 std::cerr << "readEOF" << std::endl;
128 socket_->detachEventBase();
131 std::shared_ptr<AsyncSSLSocket> socket_;
// Write callback passed to socket_->write() when echoing data back.
132 WriteCallbackBase *wcb_;
// Echoing read callback: every chunk read from socket_ is written straight
// back through wcb_, and then kept in `buffers`.
136 class ReadCallback : public ReadCallbackBase {
138 explicit ReadCallback(WriteCallbackBase *wcb)
139 : ReadCallbackBase(wcb)
// NOTE(review): the function enclosing this loop is only partly visible here;
// it walks the stored buffers and then frees currentBuffer.
143 for (std::vector<Buffer>::iterator it = buffers.begin();
148 currentBuffer.free();
// Lazily allocates a 4096-byte buffer on first use and hands out all of it.
151 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152 if (!currentBuffer.buffer) {
153 currentBuffer.allocate(4096);
155 *bufReturn = currentBuffer.buffer;
156 *lenReturn = currentBuffer.length;
159 void readDataAvailable(size_t len) noexcept override {
160 std::cerr << "readDataAvailable, len " << len << std::endl;
// Shrink the recorded length to the bytes actually read.
162 currentBuffer.length = len;
164 wcb_->setSocket(socket_);
166 // Write back the same data.
167 socket_->write(wcb_, currentBuffer.buffer, len);
// Keep the chunk in `buffers` and reset currentBuffer so that the next
// getReadBuffer() call allocates a new one (assumes reset() clears the
// pointer without freeing it — TODO confirm).
169 buffers.push_back(currentBuffer);
170 currentBuffer.reset();
171 state = STATE_SUCCEEDED;
// (pointer, length) pair over malloc()ed memory. Buffers are copied by value
// (e.g. into `buffers`), so the memory is released explicitly via free().
176 Buffer() : buffer(nullptr), length(0) {}
177 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
// Allocates `length` bytes with malloc(); must not already hold a buffer.
183 void allocate(size_t length) {
184 assert(buffer == nullptr);
185 this->buffer = static_cast<char*>(malloc(length));
186 this->length = length;
197 std::vector<Buffer> buffers;
198 Buffer currentBuffer;
// Read callback for tests that expect a read error: it supplies a null read
// buffer and treats readErr() as the successful outcome.
201 class ReadErrorCallback : public ReadCallbackBase {
203 explicit ReadErrorCallback(WriteCallbackBase *wcb)
204 : ReadCallbackBase(wcb) {}
206 // Return nullptr buffer to trigger readError()
207 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208 *bufReturn = nullptr;
212 void readDataAvailable(size_t /* len */) noexcept override {
213 // This should never be called.
218 const AsyncSocketException& ex) noexcept override {
219 ReadCallbackBase::readErr(ex);
220 std::cerr << "ReadErrorCallback::readError" << std::endl;
// The error is expected, so replace the STATE_FAILED the base just set.
221 setState(STATE_SUCCEEDED);
// Read callback for tests that expect the peer to close the connection: it
// supplies a null read buffer and treats readEOF() as the successful outcome.
225 class ReadEOFCallback : public ReadCallbackBase {
227 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
229 // Return nullptr buffer to trigger readError()
// if any data arrives; the expected path for this class is readEOF().
230 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
231 *bufReturn = nullptr;
235 void readDataAvailable(size_t /* len */) noexcept override {
236 // This should never be called.
240 void readEOF() noexcept override {
241 ReadCallbackBase::readEOF();
// EOF is the expected outcome.
242 setState(STATE_SUCCEEDED);
// Read callback that forces the echo write to fail: it closes the underlying
// fd before writing the received data back. It succeeds only if the write
// callback has reached STATE_FAILED.
246 class WriteErrorCallback : public ReadCallback {
248 explicit WriteErrorCallback(WriteCallbackBase *wcb)
249 : ReadCallback(wcb) {}
251 void readDataAvailable(size_t len) noexcept override {
252 std::cerr << "readDataAvailable, len " << len << std::endl;
254 currentBuffer.length = len;
256 // close the socket before writing to trigger writeError().
257 ::close(socket_->getFd());
259 wcb_->setSocket(socket_);
261 // Write back the same data.
262 socket_->write(wcb_, currentBuffer.buffer, len);
// This check runs right after write() returns, so it assumes the write
// error has already been delivered to wcb_ synchronously.
264 if (wcb_->state == STATE_FAILED) {
265 setState(STATE_SUCCEEDED);
267 state = STATE_FAILED;
270 buffers.push_back(currentBuffer);
271 currentBuffer.reset();
// Unlike the base class, this override changes no state and does not detach.
274 void readErr(const AsyncSocketException& ex) noexcept override {
275 std::cerr << "readError " << ex.what() << std::endl;
276 // do nothing since this is expected
// Read callback with no write callback that operates on a plain AsyncSocket
// (tcpSocket_). readEOF() counts as success and readErr() as failure; both
// detach tcpSocket_ from its EventBase.
280 class EmptyReadCallback : public ReadCallback {
282 explicit EmptyReadCallback()
283 : ReadCallback(nullptr) {}
285 void readErr(const AsyncSocketException& ex) noexcept override {
286 std::cerr << "readError " << ex.what() << std::endl;
287 state = STATE_FAILED;
289 tcpSocket_->detachEventBase();
292 void readEOF() noexcept override {
293 std::cerr << "readEOF" << std::endl;
296 tcpSocket_->detachEventBase();
297 state = STATE_SUCCEEDED;
// No setter is visible; presumably assigned by the test before any callback
// fires — TODO confirm.
300 std::shared_ptr<AsyncSocket> tcpSocket_;
// Server-side SSL handshake callback. On success it installs rcb_ as the
// socket's read callback. `state` records whether the outcome matched the
// expectation passed to the constructor (EXPECT_SUCCESS by default).
303 class HandshakeCallback :
304 public AsyncSSLSocket::HandshakeCB {
311 explicit HandshakeCallback(ReadCallbackBase *rcb,
312 ExpectType expect = EXPECT_SUCCESS):
313 state(STATE_WAITING),
318 const std::shared_ptr<AsyncSSLSocket> &socket) {
322 void setState(StateEnum s) {
327 // Functions inherited from AsyncSSLSocketHandshakeCallback
328 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
// mutex_ is shared with waitForHandshake(), which waits on cv_ for `state`.
329 std::lock_guard<std::mutex> g(mutex_);
331 EXPECT_EQ(sock, socket_.get());
332 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
// Hand the connection to the read callback for the data phase.
333 rcb_->setSocket(socket_);
334 sock->setReadCB(rcb_);
335 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
337 void handshakeErr(AsyncSSLSocket* /* sock */,
338 const AsyncSocketException& ex) noexcept override {
339 std::lock_guard<std::mutex> g(mutex_);
341 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
342 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
343 if (expect_ == EXPECT_ERROR) {
344 // rcb will never be invoked
345 rcb_->setState(STATE_SUCCEEDED);
347 errorString_ = ex.what();
// Blocks until a handshake callback has moved `state` off STATE_WAITING.
// NOTE(review): no cv_ notification is visible in this listing; confirm that
// both handshakeSuc() and handshakeErr() signal cv_.
350 void waitForHandshake() {
351 std::unique_lock<std::mutex> lock(mutex_);
352 cv_.wait(lock, [this] { return state != STATE_WAITING; });
355 ~HandshakeCallback() {
356 EXPECT_EQ(STATE_SUCCEEDED, state);
361 state = STATE_SUCCEEDED;
364 std::shared_ptr<AsyncSSLSocket> getSocket() {
369 std::shared_ptr<AsyncSSLSocket> socket_;
370 ReadCallbackBase *rcb_;
373 std::condition_variable cv_;
// Message of the last handshake error.
374 std::string errorString_;
// Base accept callback for the test server. It wraps each accepted fd in an
// AsyncSSLSocket built from ctx_ on base_ and passes it to connAccepted().
// `state` starts as STATE_WAITING and must be STATE_SUCCEEDED at destruction.
377 class SSLServerAcceptCallbackBase:
378 public folly::AsyncServerSocket::AcceptCallback {
380 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
381 state(STATE_WAITING), hcb_(hcb) {}
383 ~SSLServerAcceptCallbackBase() {
384 EXPECT_EQ(STATE_SUCCEEDED, state);
387 void acceptError(const std::exception& ex) noexcept override {
388 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
389 << ex.what() << std::endl;
390 state = STATE_FAILED;
393 void connectionAccepted(
394 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
395 printf("Connection accepted\n");
396 std::shared_ptr<AsyncSSLSocket> sslSock;
398 // Create a AsyncSSLSocket object with the fd. The socket should be
399 // added to the event base and in the state of accepting SSL connection.
400 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
401 } catch (const std::exception &e) {
// LOG() is stream-based, so a printf-style "%s" is never expanded. Stream
// the fd and the exception text as separate, delimited fields instead.
402 LOG(ERROR) << "Exception caught while creating a AsyncSSLSocket "
403 "object with socket " << fd << ": " << e.what();
409 connAccepted(sslSock);
// Implemented by subclasses; called with each newly created AsyncSSLSocket.
412 virtual void connAccepted(
413 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
416 HandshakeCallback *hcb_;
417 std::shared_ptr<folly::SSLContext> ctx_;
418 folly::EventBase* base_;
// Accept callback that starts a server-side SSL handshake on every accepted
// socket. It uses hcb_ as the handshake callback and timeout_ (constructor
// argument, 0 by default) as the accept timeout.
421 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
425 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
426 uint32_t timeout = 0):
427 SSLServerAcceptCallbackBase(hcb),
430 virtual ~SSLServerAcceptCallback() {
432 // if we set a timeout, we expect failure
433 EXPECT_EQ(hcb_->state, STATE_FAILED);
// Reset so that HandshakeCallback's own destructor check passes.
434 hcb_->setState(STATE_SUCCEEDED);
438 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
440 const std::shared_ptr<folly::AsyncSSLSocket> &s)
442 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
443 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
445 hcb_->setSocket(sock);
446 sock->sslAccept(hcb_, timeout_);
// The handshake is still in progress when sslAccept() returns.
447 EXPECT_EQ(sock->getSSLState(),
448 AsyncSSLSocket::STATE_ACCEPTING);
450 state = STATE_SUCCEEDED;
// Accept callback that checks TCP_NODELAY is already set on the accepted
// socket, clears it, reads the option back, and then starts the SSL
// handshake through SSLServerAcceptCallback::connAccepted().
454 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
456 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
457 SSLServerAcceptCallback(hcb) {}
459 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
461 const std::shared_ptr<folly::AsyncSSLSocket> &s)
464 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
466 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
468 int fd = sock->getFd();
472 // The accepted connection should already have TCP_NODELAY set
474 socklen_t valueLength = sizeof(value);
475 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
481 // Unset the TCP_NODELAY option.
483 socklen_t valueLength = sizeof(value);
484 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
// Read the option back after changing it.
487 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
491 SSLServerAcceptCallback::connAccepted(sock);
// Accept callback for the async session-cache tests. It behaves like
// SSLServerAcceptCallback, except that right after sslAccept() the socket may
// be either in STATE_ACCEPTING or in STATE_CACHE_LOOKUP.
495 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
497 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
498 uint32_t timeout = 0):
499 SSLServerAcceptCallback(hcb, timeout) {}
501 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
503 const std::shared_ptr<folly::AsyncSSLSocket> &s)
505 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
// Log under this class's own name so its output can be told apart from
// SSLServerAcceptCallback's.
507 std::cerr << "SSLServerAsyncCacheAcceptCallback::connAccepted" << std::endl;
509 hcb_->setSocket(sock);
510 sock->sslAccept(hcb_, timeout_);
// ASSERT_TRUE returns early on failure, leaving `state` at STATE_WAITING,
// so the base destructor's check also reports the problem.
511 ASSERT_TRUE((sock->getSSLState() ==
512 AsyncSSLSocket::STATE_ACCEPTING) ||
513 (sock->getSSLState() ==
514 AsyncSSLSocket::STATE_CACHE_LOOKUP));
516 state = STATE_SUCCEEDED;
// Accept callback that checks a second sslAccept() on a socket that is
// already accepting. The second call must fail and leave both handshake
// callbacks in the error state. It then detaches the socket and marks
// everything succeeded so the destructor checks pass.
521 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
523 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
524 SSLServerAcceptCallbackBase(hcb) {}
526 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
528 const std::shared_ptr<folly::AsyncSSLSocket> &s)
530 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
532 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
534 // The first call to sslAccept() should succeed.
535 hcb_->setSocket(sock);
536 sock->sslAccept(hcb_);
537 EXPECT_EQ(sock->getSSLState(),
538 AsyncSSLSocket::STATE_ACCEPTING);
540 // The second call to sslAccept() should fail.
// NOTE(review): callback2 lives on this stack frame. Presumably the
// detachEventBase() below ensures the socket never calls it afterwards.
541 HandshakeCallback callback2(hcb_->rcb_);
542 callback2.setSocket(sock);
543 sock->sslAccept(&callback2);
544 EXPECT_EQ(sock->getSSLState(),
545 AsyncSSLSocket::STATE_ERROR);
547 // Both callbacks should be in the error state.
548 EXPECT_EQ(hcb_->state, STATE_FAILED);
549 EXPECT_EQ(callback2.state, STATE_FAILED);
551 sock->detachEventBase();
553 state = STATE_SUCCEEDED;
554 hcb_->setState(STATE_SUCCEEDED);
555 callback2.setState(STATE_SUCCEEDED);
// Accept callback that postpones sslAccept() on the socket's EventBase until
// the client has already closed the connection, so the server-side accept
// fails.
559 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
561 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
562 SSLServerAcceptCallbackBase(hcb) {}
564 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
566 const std::shared_ptr<folly::AsyncSSLSocket> &s)
// Log under this class's own name so its output can be told apart from
// HandshakeErrorCallback's.
568 std::cerr << "HandshakeTimeoutCallback::connAccepted" << std::endl;
570 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
572 hcb_->setSocket(sock);
573 sock->getEventBase()->tryRunAfterDelay([=] {
574 std::cerr << "Delayed SSL accept, client will have close by now"
576 // SSL accept will fail
579 AsyncSSLSocket::STATE_UNINIT);
580 hcb_->socket_->sslAccept(hcb_);
581 // This registers for an event
584 AsyncSSLSocket::STATE_ACCEPTING);
586 state = STATE_SUCCEEDED;
// Accept callback for client connect-timeout tests. It never starts a
// handshake and closes the accepted socket after a 100 ms delay.
591 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
593 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
594 // We don't care if we get invoked or not.
595 // The client may time out and give up before connAccepted() is even
// called.
597 state = STATE_SUCCEEDED;
600 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
602 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
603 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
605 // Just wait a while before closing the socket, so the client
606 // will time out waiting for the handshake to complete.
// Capturing `s` by value copies the shared_ptr, which keeps the socket
// alive until the delayed close runs.
607 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
// SSL test server that runs its own EventBase (evb_) on a separate pthread
// (thread_, entry point Main). Accepted connections are meant for acb_;
// the wiring is done in the out-of-line constructor.
611 class TestSSLServer {
614 std::shared_ptr<folly::SSLContext> ctx_;
615 SSLServerAcceptCallbackBase *acb_;
616 std::shared_ptr<folly::AsyncServerSocket> socket_;
617 folly::SocketAddress address_;
// pthread entry point; `ctx` is the owning TestSSLServer instance.
620 static void *Main(void *ctx) {
621 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
623 std::cerr << "Server thread exited event loop" << std::endl;
628 // Create a TestSSLServer.
629 // This immediately starts listening on the given port.
// NOTE(review): this declaration takes no port argument; confirm which
// address the out-of-line constructor binds to.
630 explicit TestSSLServer(
631 SSLServerAcceptCallbackBase* acb,
632 bool enableTFO = false);
// Shutdown: stop accepting from inside the server's EventBase thread, then
// wait for the server thread to exit.
636 evb_.runInEventBaseThread([&](){
637 socket_->stopAccepting();
639 std::cerr << "Waiting for server thread to exit" << std::endl;
640 pthread_join(thread_, nullptr);
643 EventBase &getEventBase() { return evb_; }
645 const folly::SocketAddress& getAddress() const {
// TestSSLServer variant that replaces OpenSSL's internal session cache with
// getSessionCallback(), which simulates an asynchronous external cache.
650 class TestSSLAsyncCacheServer : public TestSSLServer {
652 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
653 int lookupDelay = 100) :
655 SSL_CTX *sslCtx = ctx_->getSSLCtx();
656 SSL_CTX_sess_set_get_cb(sslCtx,
657 TestSSLAsyncCacheServer::getSessionCallback);
// With the internal cache disabled, session lookups are left to the get
// callback installed above.
658 SSL_CTX_set_session_cache_mode(
659 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
// NOTE(review): lookupDelay_ is static, so this value is shared by every
// TestSSLAsyncCacheServer instance.
662 lookupDelay_ = lookupDelay;
665 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
666 uint32_t getAsyncLookups() const { return asyncLookups_; }
669 static uint32_t asyncCallbacks_;
670 static uint32_t asyncLookups_;
671 static uint32_t lookupDelay_;
// Session-get callback installed on the SSL_CTX by the constructor.
673 static SSL_SESSION* getSessionCallback(SSL* ssl,
674 unsigned char* /* sess_id */,
679 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
680 if (!SSL_want_sess_cache_lookup(ssl)) {
681 // libssl.so mismatch
682 std::cerr << "no async support" << std::endl;
686 AsyncSSLSocket *sslSocket =
687 AsyncSSLSocket::getFromSSL(ssl);
688 assert(sslSocket != nullptr);
689 // Going to simulate an async cache by delaying the miss by lookupDelay_ ms
// (100 ms by default).
690 if (asyncCallbacks_ % 2 == 0) {
691 // This socket is already blocked on lookup, return miss
692 std::cerr << "returning miss" << std::endl;
694 // fresh meat - block it
695 std::cerr << "async lookup" << std::endl;
// Resume the accept on the socket's EventBase after the simulated delay.
696 sslSocket->getEventBase()->tryRunAfterDelay(
697 std::bind(&AsyncSSLSocket::restartSSLAccept,
698 sslSocket), lookupDelay_);
699 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
// Declarations of test helpers that are defined out of line.
707 void getfds(int fds[2]);
// Takes a client and a server SSLContext.
710 std::shared_ptr<folly::SSLContext> clientCtx,
711 std::shared_ptr<folly::SSLContext> serverCtx);
// clientSock/serverSock are presumably output parameters receiving the two
// connected ends — TODO confirm against the definition.
714 EventBase* eventBase,
715 AsyncSSLSocket::UniquePtr* clientSock,
716 AsyncSSLSocket::UniquePtr* serverSock);
// SSL client that, once the handshake completes, issues a single writev() of
// iovCount_ iovecs that point into buf_. It is meant to run against a server
// that delays reading, so the writes initially block.
718 class BlockingWriteClient :
719 private AsyncSSLSocket::HandshakeCB,
720 private AsyncTransportWrapper::WriteCallback {
722 explicit BlockingWriteClient(
723 AsyncSSLSocket::UniquePtr socket)
724 : socket_(std::move(socket)),
728 buf_.reset(new uint8_t[bufLen_]);
// Fill all bufLen_ bytes of the allocation. sizeof(buf_) would only be the
// size of the unique_ptr object itself and would leave most of the buffer
// uninitialized.
729 for (uint32_t n = 0; n < bufLen_; ++n) {
734 iov_.reset(new struct iovec[iovCount_]);
// iovec n starts at offset n of buf_; its length depends on n.
735 for (uint32_t n = 0; n < iovCount_; ++n) {
736 iov_[n].iov_base = buf_.get() + n;
738 iov_[n].iov_len = n % bufLen_;
740 iov_[n].iov_len = bufLen_ - (n % bufLen_);
// Start the client-side handshake with a 100 ms timeout.
744 socket_->sslConn(this, 100);
747 struct iovec* getIovec() const {
750 uint32_t getIovecCount() const {
// Handshake done: send every iovec in one writev() call.
755 void handshakeSuc(AsyncSSLSocket*) noexcept override {
756 socket_->writev(this, iov_.get(), iovCount_);
760 const AsyncSocketException& ex) noexcept override {
761 ADD_FAILURE() << "client handshake error: " << ex.what();
763 void writeSuccess() noexcept override {
768 const AsyncSocketException& ex) noexcept override {
769 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
773 AsyncSSLSocket::UniquePtr socket_;
776 std::unique_ptr<uint8_t[]> buf_;
777 std::unique_ptr<struct iovec[]> iov_;
// SSL server that reads slowly into a 2500 * 2000 byte buffer so the peer's
// writes block. checkBuffer() can then verify what was read against the
// client's iovecs.
780 class BlockingWriteServer :
781 private AsyncSSLSocket::HandshakeCB,
782 private AsyncTransportWrapper::ReadCallback {
784 explicit BlockingWriteServer(
785 AsyncSSLSocket::UniquePtr socket)
786 : socket_(std::move(socket)),
787 bufSize_(2500 * 2000),
789 buf_.reset(new uint8_t[bufSize_]);
790 socket_->sslAccept(this, 100);
// Compares the bytes read so far against the iovecs, in order. It reports a
// failure on the first mismatch or short read, and also when more bytes were
// read than the iovecs describe.
793 void checkBuffer(struct iovec* iov, uint32_t count) const {
795 for (uint32_t n = 0; n < count; ++n) {
796 size_t bytesLeft = bytesRead_ - idx;
797 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
798 std::min(iov[n].iov_len, bytesLeft));
800 FAIL() << "buffer mismatch at iovec " << n << "/" << count
// FAIL() returns from checkBuffer(), so the first problem ends the check.
804 if (iov[n].iov_len > bytesLeft) {
805 FAIL() << "server did not read enough data: "
806 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
807 << " in iovec " << n << "/" << count;
810 idx += iov[n].iov_len;
812 if (idx != bytesRead_) {
813 ADD_FAILURE() << "server read extra data: " << bytesRead_
814 << " bytes read; expected " << idx;
819 void handshakeSuc(AsyncSSLSocket*) noexcept override {
820 // Wait 10ms before reading, so the client's writes will initially block.
821 socket_->getEventBase()->tryRunAfterDelay(
822 [this] { socket_->setReadCB(this); }, 10);
826 const AsyncSocketException& ex) noexcept override {
827 ADD_FAILURE() << "server handshake error: " << ex.what();
// Read directly into the unused tail of buf_.
829 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
830 *bufReturn = buf_.get() + bytesRead_;
831 *lenReturn = bufSize_ - bytesRead_;
// Throttle: stop reading after each chunk and resume 2 ms later.
833 void readDataAvailable(size_t len) noexcept override {
835 socket_->setReadCB(nullptr);
836 socket_->getEventBase()->tryRunAfterDelay(
837 [this] { socket_->setReadCB(this); }, 2);
839 void readEOF() noexcept override {
843 const AsyncSocketException& ex) noexcept override {
844 ADD_FAILURE() << "server read error: " << ex.what();
847 AsyncSSLSocket::UniquePtr socket_;
850 std::unique_ptr<uint8_t[]> buf_;
// SSL client that starts a handshake on construction and, on success, records
// the negotiated next protocol.
854 private AsyncSSLSocket::HandshakeCB,
855 private AsyncTransportWrapper::WriteCallback {
858 AsyncSSLSocket::UniquePtr socket)
859 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
860 socket_->sslConn(this);
// Negotiated protocol, filled in by handshakeSuc().
// NOTE(review): protocolType is not in the constructor's initializer list,
// so its value is indeterminate until handshakeSuc() runs.
863 const unsigned char* nextProto;
864 unsigned nextProtoLength;
865 SSLContext::NextProtocolType protocolType;
868 void handshakeSuc(AsyncSSLSocket*) noexcept override {
869 socket_->getSelectedNextProtocol(
870 &nextProto, &nextProtoLength, &protocolType);
874 const AsyncSocketException& ex) noexcept override {
875 ADD_FAILURE() << "client handshake error: " << ex.what();
877 void writeSuccess() noexcept override {
882 const AsyncSocketException& ex) noexcept override {
883 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
887 AsyncSSLSocket::UniquePtr socket_;
// SSL server that accepts a handshake on construction and, on success,
// records the negotiated next protocol.
891 private AsyncSSLSocket::HandshakeCB,
892 private AsyncTransportWrapper::ReadCallback {
894 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
895 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
896 socket_->sslAccept(this);
// Negotiated protocol, filled in by handshakeSuc().
// NOTE(review): protocolType is not in the constructor's initializer list,
// so its value is indeterminate until handshakeSuc() runs.
899 const unsigned char* nextProto;
900 unsigned nextProtoLength;
901 SSLContext::NextProtocolType protocolType;
904 void handshakeSuc(AsyncSSLSocket*) noexcept override {
905 socket_->getSelectedNextProtocol(
906 &nextProto, &nextProtoLength, &protocolType);
910 const AsyncSocketException& ex) noexcept override {
911 ADD_FAILURE() << "server handshake error: " << ex.what();
913 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
916 void readDataAvailable(size_t /* len */) noexcept override {}
917 void readEOF() noexcept override {
921 const AsyncSocketException& ex) noexcept override {
922 ADD_FAILURE() << "server read error: " << ex.what();
925 AsyncSSLSocket::UniquePtr socket_;
// SSL server that starts reading after the handshake and expects a read error
// that is an SSLException carrying the CLIENT_RENEGOTIATION error text.
// renegotiationError_ records that this error was seen.
928 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
929 public AsyncTransportWrapper::ReadCallback {
931 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
932 : socket_(std::move(socket)) {
933 socket_->sslAccept(this);
// Unregister this object as the read callback before it goes away.
936 ~RenegotiatingServer() {
937 socket_->setReadCB(nullptr);
940 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
941 LOG(INFO) << "Renegotiating server handshake success";
942 socket_->setReadCB(this);
946 const AsyncSocketException& ex) noexcept override {
947 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
949 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
950 *lenReturn = sizeof(buf);
953 void readDataAvailable(size_t /* len */) noexcept override {}
954 void readEOF() noexcept override {}
// renegotiationError_ is set only if both checks pass; each ASSERT_* returns
// from this function on failure.
955 void readErr(const AsyncSocketException& ex) noexcept override {
956 LOG(INFO) << "server got read error " << ex.what();
957 auto exPtr = dynamic_cast<const SSLException*>(&ex);
958 ASSERT_NE(nullptr, exPtr);
959 std::string exStr(ex.what());
960 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
961 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
962 renegotiationError_ = true;
965 AsyncSSLSocket::UniquePtr socket_;
966 unsigned char buf[128];
967 bool renegotiationError_{false};
970 #ifndef OPENSSL_NO_TLSEXT
// SSL client that starts a handshake on construction and, on success, stores
// socket_->isServerNameMatch() in serverNameMatch.
972 private AsyncSSLSocket::HandshakeCB,
973 private AsyncTransportWrapper::WriteCallback {
976 AsyncSSLSocket::UniquePtr socket)
977 : serverNameMatch(false), socket_(std::move(socket)) {
978 socket_->sslConn(this);
981 bool serverNameMatch;
984 void handshakeSuc(AsyncSSLSocket*) noexcept override {
985 serverNameMatch = socket_->isServerNameMatch();
989 const AsyncSocketException& ex) noexcept override {
990 ADD_FAILURE() << "client handshake error: " << ex.what();
992 void writeSuccess() noexcept override {
997 const AsyncSocketException& ex) noexcept override {
998 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1002 AsyncSSLSocket::UniquePtr socket_;
// SSL server that installs a server-name callback on `ctx`. When the
// requested host name equals expectedServerName_ (case-insensitively), it
// switches the connection to sniCtx_ and sets serverNameMatch.
1006 private AsyncSSLSocket::HandshakeCB,
1007 private AsyncTransportWrapper::ReadCallback {
1010 AsyncSSLSocket::UniquePtr socket,
1011 const std::shared_ptr<folly::SSLContext>& ctx,
1012 const std::shared_ptr<folly::SSLContext>& sniCtx,
1013 const std::string& expectedServerName)
1014 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1015 expectedServerName_(expectedServerName) {
// NOTE(review): the callback binds a raw `this`, so this object must outlive
// any handshake that `ctx` serves.
1016 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1017 std::placeholders::_1));
1018 socket_->sslAccept(this);
1021 bool serverNameMatch;
1024 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1027 const AsyncSocketException& ex) noexcept override {
1028 ADD_FAILURE() << "server handshake error: " << ex.what();
1030 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1033 void readDataAvailable(size_t /* len */) noexcept override {}
1034 void readEOF() noexcept override {
1038 const AsyncSocketException& ex) noexcept override {
1039 ADD_FAILURE() << "server read error: " << ex.what();
// Server-name callback: compares the requested host name with
// expectedServerName_ and selects sniCtx_ on a match.
1042 folly::SSLContext::ServerNameCallbackResult
1043 serverNameCallback(SSL *ssl) {
1044 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1047 !strcasecmp(expectedServerName_.c_str(), sn)) {
1048 AsyncSSLSocket *sslSocket =
1049 AsyncSSLSocket::getFromSSL(ssl);
1050 sslSocket->switchServerSSLContext(sniCtx_);
1051 serverNameMatch = true;
1052 return folly::SSLContext::SERVER_NAME_FOUND;
1054 serverNameMatch = false;
1055 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1059 AsyncSSLSocket::UniquePtr socket_;
1060 std::shared_ptr<folly::SSLContext> sniCtx_;
1061 std::string expectedServerName_;
// SSL client that connects, writes buf_, reads the echoed bytes back into
// readbuf_ and compares them, then closes. It saves the SSL session so later
// connect() calls can offer it for resumption; getHit()/getMiss() report
// session-reuse counters.
1065 class SSLClient : public AsyncSocket::ConnectCallback,
1066 public AsyncTransportWrapper::WriteCallback,
1067 public AsyncTransportWrapper::ReadCallback
1070 EventBase *eventBase_;
1071 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1072 SSL_SESSION *session_;
1073 std::shared_ptr<folly::SSLContext> ctx_;
1075 folly::SocketAddress address_;
1079 uint32_t bytesRead_;
1083 uint32_t writeAfterConnectErrors_;
1085 // These settings test that we eventually drain the
1086 // socket, even if the maxReadsPerEvent_ is hit during
1087 // a event loop iteration.
1088 static constexpr size_t kMaxReadsPerEvent = 2;
1089 // 2 event loop iterations
1090 static constexpr size_t kMaxReadBufferSz =
1091 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1094 SSLClient(EventBase *eventBase,
1095 const folly::SocketAddress& address,
1097 uint32_t timeout = 0)
1098 : eventBase_(eventBase),
1100 requests_(requests),
1107 writeAfterConnectErrors_(0) {
// Dedicated client context with session tickets disabled (SSL_OP_NO_TICKET),
// so resumption goes through the SSL_SESSION saved in session_.
1108 ctx_.reset(new folly::SSLContext());
1109 ctx_->setOptions(SSL_OP_NO_TICKET);
1110 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1111 memset(buf_, 'a', sizeof(buf_));
1116 SSL_SESSION_free(session_);
1119 EXPECT_EQ(bytesRead_, sizeof(buf_));
1123 uint32_t getHit() const { return hit_; }
1125 uint32_t getMiss() const { return miss_; }
1127 uint32_t getErrors() const { return errors_; }
1129 uint32_t getWriteAfterConnectErrors() const {
1130 return writeAfterConnectErrors_;
1133 void connect(bool writeNow = false) {
1134 sslSocket_ = AsyncSSLSocket::newSocket(
// Offer the previously saved session, if any, for resumption.
1136 if (session_ != nullptr) {
1137 sslSocket_->setSSLSession(session_);
1140 sslSocket_->connect(this, address_, timeout_);
1141 if (sslSocket_ && writeNow) {
1142 // write some junk, used in an error test
1143 sslSocket_->write(this, buf_, sizeof(buf_));
1147 void connectSuccess() noexcept override {
1148 std::cerr << "client SSL socket connected" << std::endl;
1149 if (sslSocket_->getSSLSessionReused()) {
// Replace the saved session with this connection's session.
1153 if (session_ != nullptr) {
1154 SSL_SESSION_free(session_);
1156 session_ = sslSocket_->getSSLSession();
// Limit reads per event (and per buffer, see getReadBuffer()) so the echo
// takes several loop iterations to drain.
1160 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1161 sslSocket_->write(this, buf_, sizeof(buf_));
1162 sslSocket_->setReadCB(this);
// Pre-fill with 'b' (buf_ holds 'a'), so the final memcmp() catches any
// bytes that were never actually read.
1163 memset(readbuf_, 'b', sizeof(readbuf_));
1168 const AsyncSocketException& ex) noexcept override {
1169 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1174 void writeSuccess() noexcept override {
1175 std::cerr << "client write success" << std::endl;
1178 void writeErr(size_t /* bytesWritten */,
1179 const AsyncSocketException& ex) noexcept override {
1180 std::cerr << "client writeError: " << ex.what() << std::endl;
1182 writeAfterConnectErrors_++;
// Read into the unfilled tail of readbuf_, at most kMaxReadBufferSz bytes
// at a time.
1186 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1187 *bufReturn = readbuf_ + bytesRead_;
1188 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1191 void readEOF() noexcept override {
1192 std::cerr << "client readEOF" << std::endl;
1196 const AsyncSocketException& ex) noexcept override {
1197 std::cerr << "client readError: " << ex.what() << std::endl;
1200 void readDataAvailable(size_t len) noexcept override {
1201 std::cerr << "client read data: " << len << std::endl;
// Once the whole payload has come back, verify it and close.
1203 if (bytesRead_ == sizeof(buf_)) {
1204 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1205 sslSocket_->closeNow();
1207 if (requests_ != 0) {
// Shared handshake callback for the certificate-verification tests. It
// records which callbacks fired and the handshake time, checks the preverify
// result against preverifyResult_, and answers verification with
// verifyResult_.
1215 class SSLHandshakeBase :
1216 public AsyncSSLSocket::HandshakeCB,
1217 private AsyncTransportWrapper::WriteCallback {
1219 explicit SSLHandshakeBase(
1220 AsyncSSLSocket::UniquePtr socket,
1221 bool preverifyResult,
1222 bool verifyResult) :
1223 handshakeVerify_(false),
1224 handshakeSuccess_(false),
1225 handshakeError_(false),
1226 socket_(std::move(socket)),
1227 preverifyResult_(preverifyResult),
1228 verifyResult_(verifyResult) {
// Rvalue-qualified: only callable on an expiring object, e.g.
// std::move(handshake).moveSocket(); hands over ownership of the socket.
1231 AsyncSSLSocket::UniquePtr moveSocket() && {
1232 return std::move(socket_);
1235 bool handshakeVerify_;
1236 bool handshakeSuccess_;
1237 bool handshakeError_;
// Set by handshakeSuc()/handshakeErr(); not initialized by the constructor.
1238 std::chrono::nanoseconds handshakeTime;
1241 AsyncSSLSocket::UniquePtr socket_;
1242 bool preverifyResult_;
1245 // HandshakeCallback
1246 bool handshakeVer(AsyncSSLSocket* /* sock */,
1248 X509_STORE_CTX* /* ctx */) noexcept override {
1249 handshakeVerify_ = true;
1251 EXPECT_EQ(preverifyResult_, preverifyOk);
1252 return verifyResult_;
1255 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1256 LOG(INFO) << "Handshake success";
1257 handshakeSuccess_ = true;
1258 handshakeTime = socket_->getHandshakeTime();
1263 const AsyncSocketException& ex) noexcept override {
1264 LOG(INFO) << "Handshake error " << ex.what();
1265 handshakeError_ = true;
1266 handshakeTime = socket_->getHandshakeTime();
1270 void writeSuccess() noexcept override {
1275 size_t bytesWritten,
1276 const AsyncSocketException& ex) noexcept override {
1277 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
// Client side: starts sslConn() with timeout argument 0 and no explicit
// verify-peer override.
1282 class SSLHandshakeClient : public SSLHandshakeBase {
1285 AsyncSSLSocket::UniquePtr socket,
1286 bool preverifyResult,
1287 bool verifyResult) :
1288 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1289 socket_->sslConn(this, 0);
// Client side: starts sslConn() with peer verification explicitly disabled
// (NO_VERIFY).
1293 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1295 SSLHandshakeClientNoVerify(
1296 AsyncSSLSocket::UniquePtr socket,
1297 bool preverifyResult,
1298 bool verifyResult) :
1299 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1300 socket_->sslConn(this, 0,
1301 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Client side: starts sslConn() with peer verification explicitly enabled
// (VERIFY).
1305 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1307 SSLHandshakeClientDoVerify(
1308 AsyncSSLSocket::UniquePtr socket,
1309 bool preverifyResult,
1310 bool verifyResult) :
1311 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1312 socket_->sslConn(this, 0,
1313 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
// Server side: starts sslAccept() with timeout argument 0 and no explicit
// verify-peer override.
1317 class SSLHandshakeServer : public SSLHandshakeBase {
1320 AsyncSSLSocket::UniquePtr socket,
1321 bool preverifyResult,
1323 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1324 socket_->sslAccept(this, 0);
// Server side with ClientHello parsing enabled. On success it captures the
// client, server, and shared cipher lists plus the negotiated cipher name.
1328 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1330 SSLHandshakeServerParseClientHello(
1331 AsyncSSLSocket::UniquePtr socket,
1332 bool preverifyResult,
1334 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
// Enabled before sslAccept(), presumably so the ClientHello is captured
// during this handshake.
1335 socket_->enableClientHelloParsing();
1336 socket_->sslAccept(this, 0);
1339 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
// NOTE(review): unlike SSLHandshakeBase::handshakeSuc(), this override does
// not record handshakeTime.
1342 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1343 handshakeSuccess_ = true;
1344 sock->getSSLSharedCiphers(sharedCiphers_);
1345 sock->getSSLServerCiphers(serverCiphers_);
1346 sock->getSSLClientCiphers(clientCiphers_);
1347 chosenCipher_ = sock->getNegotiatedCipherName();
// Server side: starts sslAccept() with peer verification explicitly disabled
// (NO_VERIFY).
1352 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1354 SSLHandshakeServerNoVerify(
1355 AsyncSSLSocket::UniquePtr socket,
1356 bool preverifyResult,
1358 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1359 socket_->sslAccept(this, 0,
1360 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
// Server side: starts sslAccept() requiring a client certificate
// (VERIFY_REQ_CLIENT_CERT).
1364 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1366 SSLHandshakeServerDoVerify(
1367 AsyncSSLSocket::UniquePtr socket,
1368 bool preverifyResult,
1370 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1371 socket_->sslAccept(this, 0,
1372 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
// Watchdog timeout scheduled on construction. If it fires, the test has run
// too long: it stops the EventBase loop and records a test failure.
1376 class EventBaseAborter : public AsyncTimeout {
1378 EventBaseAborter(EventBase* eventBase,
1381 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1382 , eventBase_(eventBase) {
1383 scheduleTimeout(timeoutMS);
1386 void timeoutExpired() noexcept override {
// Stop the loop first: FAIL() returns from the current function, so a
// statement placed after it would never run and the loop would keep going.
1387 eventBase_->terminateLoopSoon();
1388 FAIL() << "test timed out";
1392 EventBase* eventBase_;