2 * Copyright 2016 Facebook, Inc.
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/portability/Sockets.h>
31 #include <folly/portability/Unistd.h>
33 #include <gtest/gtest.h>
37 #include <sys/types.h>
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both positive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
56 : state(STATE_WAITING)
58 , exception(AsyncSocketException::UNKNOWN, "none") {}
60 ~WriteCallbackBase() {
61 EXPECT_EQ(STATE_SUCCEEDED, state);
65 const std::shared_ptr<AsyncSSLSocket> &socket) {
69 void writeSuccess() noexcept override {
70 std::cerr << "writeSuccess" << std::endl;
71 state = STATE_SUCCEEDED;
76 const AsyncSocketException& ex) noexcept override {
77 std::cerr << "writeError: bytesWritten " << bytesWritten
78 << ", exception " << ex.what() << std::endl;
81 this->bytesWritten = bytesWritten;
84 socket_->detachEventBase();
87 std::shared_ptr<AsyncSSLSocket> socket_;
90 AsyncSocketException exception;
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
96 explicit ReadCallbackBase(WriteCallbackBase* wcb)
97 : wcb_(wcb), state(STATE_WAITING) {}
100 EXPECT_EQ(STATE_SUCCEEDED, state);
104 const std::shared_ptr<AsyncSSLSocket> &socket) {
108 void setState(StateEnum s) {
116 const AsyncSocketException& ex) noexcept override {
117 std::cerr << "readError " << ex.what() << std::endl;
118 state = STATE_FAILED;
120 socket_->detachEventBase();
123 void readEOF() noexcept override {
124 std::cerr << "readEOF" << std::endl;
127 socket_->detachEventBase();
130 std::shared_ptr<AsyncSSLSocket> socket_;
131 WriteCallbackBase *wcb_;
135 class ReadCallback : public ReadCallbackBase {
137 explicit ReadCallback(WriteCallbackBase *wcb)
138 : ReadCallbackBase(wcb)
142 for (std::vector<Buffer>::iterator it = buffers.begin();
147 currentBuffer.free();
150 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
151 if (!currentBuffer.buffer) {
152 currentBuffer.allocate(4096);
154 *bufReturn = currentBuffer.buffer;
155 *lenReturn = currentBuffer.length;
158 void readDataAvailable(size_t len) noexcept override {
159 std::cerr << "readDataAvailable, len " << len << std::endl;
161 currentBuffer.length = len;
163 wcb_->setSocket(socket_);
165 // Write back the same data.
166 socket_->write(wcb_, currentBuffer.buffer, len);
168 buffers.push_back(currentBuffer);
169 currentBuffer.reset();
170 state = STATE_SUCCEEDED;
175 Buffer() : buffer(nullptr), length(0) {}
176 Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
182 void allocate(size_t length) {
183 assert(buffer == nullptr);
184 this->buffer = static_cast<char*>(malloc(length));
185 this->length = length;
196 std::vector<Buffer> buffers;
197 Buffer currentBuffer;
200 class ReadErrorCallback : public ReadCallbackBase {
202 explicit ReadErrorCallback(WriteCallbackBase *wcb)
203 : ReadCallbackBase(wcb) {}
205 // Return nullptr buffer to trigger readError()
206 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
207 *bufReturn = nullptr;
211 void readDataAvailable(size_t /* len */) noexcept override {
212 // This should never to called.
217 const AsyncSocketException& ex) noexcept override {
218 ReadCallbackBase::readErr(ex);
219 std::cerr << "ReadErrorCallback::readError" << std::endl;
220 setState(STATE_SUCCEEDED);
224 class ReadEOFCallback : public ReadCallbackBase {
226 explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
228 // Return nullptr buffer to trigger readError()
229 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
230 *bufReturn = nullptr;
234 void readDataAvailable(size_t /* len */) noexcept override {
235 // This should never to called.
239 void readEOF() noexcept override {
240 ReadCallbackBase::readEOF();
241 setState(STATE_SUCCEEDED);
245 class WriteErrorCallback : public ReadCallback {
247 explicit WriteErrorCallback(WriteCallbackBase *wcb)
248 : ReadCallback(wcb) {}
250 void readDataAvailable(size_t len) noexcept override {
251 std::cerr << "readDataAvailable, len " << len << std::endl;
253 currentBuffer.length = len;
255 // close the socket before writing to trigger writeError().
256 ::close(socket_->getFd());
258 wcb_->setSocket(socket_);
260 // Write back the same data.
261 socket_->write(wcb_, currentBuffer.buffer, len);
263 if (wcb_->state == STATE_FAILED) {
264 setState(STATE_SUCCEEDED);
266 state = STATE_FAILED;
269 buffers.push_back(currentBuffer);
270 currentBuffer.reset();
273 void readErr(const AsyncSocketException& ex) noexcept override {
274 std::cerr << "readError " << ex.what() << std::endl;
275 // do nothing since this is expected
279 class EmptyReadCallback : public ReadCallback {
281 explicit EmptyReadCallback()
282 : ReadCallback(nullptr) {}
284 void readErr(const AsyncSocketException& ex) noexcept override {
285 std::cerr << "readError " << ex.what() << std::endl;
286 state = STATE_FAILED;
288 tcpSocket_->detachEventBase();
291 void readEOF() noexcept override {
292 std::cerr << "readEOF" << std::endl;
295 tcpSocket_->detachEventBase();
296 state = STATE_SUCCEEDED;
299 std::shared_ptr<AsyncSocket> tcpSocket_;
302 class HandshakeCallback :
303 public AsyncSSLSocket::HandshakeCB {
310 explicit HandshakeCallback(ReadCallbackBase *rcb,
311 ExpectType expect = EXPECT_SUCCESS):
312 state(STATE_WAITING),
317 const std::shared_ptr<AsyncSSLSocket> &socket) {
321 void setState(StateEnum s) {
326 // Functions inherited from AsyncSSLSocketHandshakeCallback
327 void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
328 std::lock_guard<std::mutex> g(mutex_);
330 EXPECT_EQ(sock, socket_.get());
331 std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
332 rcb_->setSocket(socket_);
333 sock->setReadCB(rcb_);
334 state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
336 void handshakeErr(AsyncSSLSocket* /* sock */,
337 const AsyncSocketException& ex) noexcept override {
338 std::lock_guard<std::mutex> g(mutex_);
340 std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
341 state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
342 if (expect_ == EXPECT_ERROR) {
343 // rcb will never be invoked
344 rcb_->setState(STATE_SUCCEEDED);
346 errorString_ = ex.what();
349 void waitForHandshake() {
350 std::unique_lock<std::mutex> lock(mutex_);
351 cv_.wait(lock, [this] { return state != STATE_WAITING; });
354 ~HandshakeCallback() {
355 EXPECT_EQ(STATE_SUCCEEDED, state);
360 state = STATE_SUCCEEDED;
363 std::shared_ptr<AsyncSSLSocket> getSocket() {
368 std::shared_ptr<AsyncSSLSocket> socket_;
369 ReadCallbackBase *rcb_;
372 std::condition_variable cv_;
373 std::string errorString_;
376 class SSLServerAcceptCallbackBase:
377 public folly::AsyncServerSocket::AcceptCallback {
379 explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
380 state(STATE_WAITING), hcb_(hcb) {}
382 ~SSLServerAcceptCallbackBase() {
383 EXPECT_EQ(STATE_SUCCEEDED, state);
386 void acceptError(const std::exception& ex) noexcept override {
387 std::cerr << "SSLServerAcceptCallbackBase::acceptError "
388 << ex.what() << std::endl;
389 state = STATE_FAILED;
392 void connectionAccepted(
393 int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
394 printf("Connection accepted\n");
395 std::shared_ptr<AsyncSSLSocket> sslSock;
397 // Create a AsyncSSLSocket object with the fd. The socket should be
398 // added to the event base and in the state of accepting SSL connection.
399 sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
400 } catch (const std::exception &e) {
401 LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
402 "object with socket " << e.what() << fd;
408 connAccepted(sslSock);
411 virtual void connAccepted(
412 const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
415 HandshakeCallback *hcb_;
416 std::shared_ptr<folly::SSLContext> ctx_;
417 folly::EventBase* base_;
420 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
424 explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
425 uint32_t timeout = 0):
426 SSLServerAcceptCallbackBase(hcb),
429 virtual ~SSLServerAcceptCallback() {
431 // if we set a timeout, we expect failure
432 EXPECT_EQ(hcb_->state, STATE_FAILED);
433 hcb_->setState(STATE_SUCCEEDED);
437 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
439 const std::shared_ptr<folly::AsyncSSLSocket> &s)
441 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
442 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
444 hcb_->setSocket(sock);
445 sock->sslAccept(hcb_, timeout_);
446 EXPECT_EQ(sock->getSSLState(),
447 AsyncSSLSocket::STATE_ACCEPTING);
449 state = STATE_SUCCEEDED;
453 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
455 explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
456 SSLServerAcceptCallback(hcb) {}
458 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
460 const std::shared_ptr<folly::AsyncSSLSocket> &s)
463 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
465 std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
467 int fd = sock->getFd();
471 // The accepted connection should already have TCP_NODELAY set
473 socklen_t valueLength = sizeof(value);
474 int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
480 // Unset the TCP_NODELAY option.
482 socklen_t valueLength = sizeof(value);
483 int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
486 rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
490 SSLServerAcceptCallback::connAccepted(sock);
494 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
496 explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
497 uint32_t timeout = 0):
498 SSLServerAcceptCallback(hcb, timeout) {}
500 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
502 const std::shared_ptr<folly::AsyncSSLSocket> &s)
504 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
506 std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
508 hcb_->setSocket(sock);
509 sock->sslAccept(hcb_, timeout_);
510 ASSERT_TRUE((sock->getSSLState() ==
511 AsyncSSLSocket::STATE_ACCEPTING) ||
512 (sock->getSSLState() ==
513 AsyncSSLSocket::STATE_CACHE_LOOKUP));
515 state = STATE_SUCCEEDED;
520 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
522 explicit HandshakeErrorCallback(HandshakeCallback *hcb):
523 SSLServerAcceptCallbackBase(hcb) {}
525 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
527 const std::shared_ptr<folly::AsyncSSLSocket> &s)
529 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
531 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
533 // The first call to sslAccept() should succeed.
534 hcb_->setSocket(sock);
535 sock->sslAccept(hcb_);
536 EXPECT_EQ(sock->getSSLState(),
537 AsyncSSLSocket::STATE_ACCEPTING);
539 // The second call to sslAccept() should fail.
540 HandshakeCallback callback2(hcb_->rcb_);
541 callback2.setSocket(sock);
542 sock->sslAccept(&callback2);
543 EXPECT_EQ(sock->getSSLState(),
544 AsyncSSLSocket::STATE_ERROR);
546 // Both callbacks should be in the error state.
547 EXPECT_EQ(hcb_->state, STATE_FAILED);
548 EXPECT_EQ(callback2.state, STATE_FAILED);
550 sock->detachEventBase();
552 state = STATE_SUCCEEDED;
553 hcb_->setState(STATE_SUCCEEDED);
554 callback2.setState(STATE_SUCCEEDED);
558 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
560 explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
561 SSLServerAcceptCallbackBase(hcb) {}
563 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
565 const std::shared_ptr<folly::AsyncSSLSocket> &s)
567 std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
569 auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
571 hcb_->setSocket(sock);
572 sock->getEventBase()->tryRunAfterDelay([=] {
573 std::cerr << "Delayed SSL accept, client will have close by now"
575 // SSL accept will fail
578 AsyncSSLSocket::STATE_UNINIT);
579 hcb_->socket_->sslAccept(hcb_);
580 // This registers for an event
583 AsyncSSLSocket::STATE_ACCEPTING);
585 state = STATE_SUCCEEDED;
590 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
592 ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
593 // We don't care if we get invoked or not.
594 // The client may time out and give up before connAccepted() is even
596 state = STATE_SUCCEEDED;
599 // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
601 const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
602 std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
604 // Just wait a while before closing the socket, so the client
605 // will time out waiting for the handshake to complete.
606 s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
610 class TestSSLServer {
613 std::shared_ptr<folly::SSLContext> ctx_;
614 SSLServerAcceptCallbackBase *acb_;
615 std::shared_ptr<folly::AsyncServerSocket> socket_;
616 folly::SocketAddress address_;
619 static void *Main(void *ctx) {
620 TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
622 std::cerr << "Server thread exited event loop" << std::endl;
627 // Create a TestSSLServer.
628 // This immediately starts listening on the given port.
629 explicit TestSSLServer(
630 SSLServerAcceptCallbackBase* acb,
631 bool enableTFO = false);
635 evb_.runInEventBaseThread([&](){
636 socket_->stopAccepting();
638 std::cerr << "Waiting for server thread to exit" << std::endl;
639 pthread_join(thread_, nullptr);
642 EventBase &getEventBase() { return evb_; }
644 const folly::SocketAddress& getAddress() const {
649 class TestSSLAsyncCacheServer : public TestSSLServer {
651 explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
652 int lookupDelay = 100) :
654 SSL_CTX *sslCtx = ctx_->getSSLCtx();
655 SSL_CTX_sess_set_get_cb(sslCtx,
656 TestSSLAsyncCacheServer::getSessionCallback);
657 SSL_CTX_set_session_cache_mode(
658 sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
661 lookupDelay_ = lookupDelay;
664 uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
665 uint32_t getAsyncLookups() const { return asyncLookups_; }
668 static uint32_t asyncCallbacks_;
669 static uint32_t asyncLookups_;
670 static uint32_t lookupDelay_;
672 static SSL_SESSION* getSessionCallback(SSL* ssl,
673 unsigned char* /* sess_id */,
678 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
679 if (!SSL_want_sess_cache_lookup(ssl)) {
680 // libssl.so mismatch
681 std::cerr << "no async support" << std::endl;
685 AsyncSSLSocket *sslSocket =
686 AsyncSSLSocket::getFromSSL(ssl);
687 assert(sslSocket != nullptr);
688 // Going to simulate an async cache by just running delaying the miss 100ms
689 if (asyncCallbacks_ % 2 == 0) {
690 // This socket is already blocked on lookup, return miss
691 std::cerr << "returning miss" << std::endl;
693 // fresh meat - block it
694 std::cerr << "async lookup" << std::endl;
695 sslSocket->getEventBase()->tryRunAfterDelay(
696 std::bind(&AsyncSSLSocket::restartSSLAccept,
697 sslSocket), lookupDelay_);
698 *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
706 void getfds(int fds[2]);
709 std::shared_ptr<folly::SSLContext> clientCtx,
710 std::shared_ptr<folly::SSLContext> serverCtx);
713 EventBase* eventBase,
714 AsyncSSLSocket::UniquePtr* clientSock,
715 AsyncSSLSocket::UniquePtr* serverSock);
717 class BlockingWriteClient :
718 private AsyncSSLSocket::HandshakeCB,
719 private AsyncTransportWrapper::WriteCallback {
721 explicit BlockingWriteClient(
722 AsyncSSLSocket::UniquePtr socket)
723 : socket_(std::move(socket)),
727 buf_.reset(new uint8_t[bufLen_]);
728 for (uint32_t n = 0; n < sizeof(buf_); ++n) {
733 iov_.reset(new struct iovec[iovCount_]);
734 for (uint32_t n = 0; n < iovCount_; ++n) {
735 iov_[n].iov_base = buf_.get() + n;
737 iov_[n].iov_len = n % bufLen_;
739 iov_[n].iov_len = bufLen_ - (n % bufLen_);
743 socket_->sslConn(this, 100);
746 struct iovec* getIovec() const {
749 uint32_t getIovecCount() const {
754 void handshakeSuc(AsyncSSLSocket*) noexcept override {
755 socket_->writev(this, iov_.get(), iovCount_);
759 const AsyncSocketException& ex) noexcept override {
760 ADD_FAILURE() << "client handshake error: " << ex.what();
762 void writeSuccess() noexcept override {
767 const AsyncSocketException& ex) noexcept override {
768 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
772 AsyncSSLSocket::UniquePtr socket_;
775 std::unique_ptr<uint8_t[]> buf_;
776 std::unique_ptr<struct iovec[]> iov_;
779 class BlockingWriteServer :
780 private AsyncSSLSocket::HandshakeCB,
781 private AsyncTransportWrapper::ReadCallback {
783 explicit BlockingWriteServer(
784 AsyncSSLSocket::UniquePtr socket)
785 : socket_(std::move(socket)),
786 bufSize_(2500 * 2000),
788 buf_.reset(new uint8_t[bufSize_]);
789 socket_->sslAccept(this, 100);
792 void checkBuffer(struct iovec* iov, uint32_t count) const {
794 for (uint32_t n = 0; n < count; ++n) {
795 size_t bytesLeft = bytesRead_ - idx;
796 int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
797 std::min(iov[n].iov_len, bytesLeft));
799 FAIL() << "buffer mismatch at iovec " << n << "/" << count
803 if (iov[n].iov_len > bytesLeft) {
804 FAIL() << "server did not read enough data: "
805 << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
806 << " in iovec " << n << "/" << count;
809 idx += iov[n].iov_len;
811 if (idx != bytesRead_) {
812 ADD_FAILURE() << "server read extra data: " << bytesRead_
813 << " bytes read; expected " << idx;
818 void handshakeSuc(AsyncSSLSocket*) noexcept override {
819 // Wait 10ms before reading, so the client's writes will initially block.
820 socket_->getEventBase()->tryRunAfterDelay(
821 [this] { socket_->setReadCB(this); }, 10);
825 const AsyncSocketException& ex) noexcept override {
826 ADD_FAILURE() << "server handshake error: " << ex.what();
828 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
829 *bufReturn = buf_.get() + bytesRead_;
830 *lenReturn = bufSize_ - bytesRead_;
832 void readDataAvailable(size_t len) noexcept override {
834 socket_->setReadCB(nullptr);
835 socket_->getEventBase()->tryRunAfterDelay(
836 [this] { socket_->setReadCB(this); }, 2);
838 void readEOF() noexcept override {
842 const AsyncSocketException& ex) noexcept override {
843 ADD_FAILURE() << "server read error: " << ex.what();
846 AsyncSSLSocket::UniquePtr socket_;
849 std::unique_ptr<uint8_t[]> buf_;
853 private AsyncSSLSocket::HandshakeCB,
854 private AsyncTransportWrapper::WriteCallback {
857 AsyncSSLSocket::UniquePtr socket)
858 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
859 socket_->sslConn(this);
862 const unsigned char* nextProto;
863 unsigned nextProtoLength;
864 SSLContext::NextProtocolType protocolType;
867 void handshakeSuc(AsyncSSLSocket*) noexcept override {
868 socket_->getSelectedNextProtocol(
869 &nextProto, &nextProtoLength, &protocolType);
873 const AsyncSocketException& ex) noexcept override {
874 ADD_FAILURE() << "client handshake error: " << ex.what();
876 void writeSuccess() noexcept override {
881 const AsyncSocketException& ex) noexcept override {
882 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
886 AsyncSSLSocket::UniquePtr socket_;
890 private AsyncSSLSocket::HandshakeCB,
891 private AsyncTransportWrapper::ReadCallback {
893 explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
894 : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
895 socket_->sslAccept(this);
898 const unsigned char* nextProto;
899 unsigned nextProtoLength;
900 SSLContext::NextProtocolType protocolType;
903 void handshakeSuc(AsyncSSLSocket*) noexcept override {
904 socket_->getSelectedNextProtocol(
905 &nextProto, &nextProtoLength, &protocolType);
909 const AsyncSocketException& ex) noexcept override {
910 ADD_FAILURE() << "server handshake error: " << ex.what();
912 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
915 void readDataAvailable(size_t /* len */) noexcept override {}
916 void readEOF() noexcept override {
920 const AsyncSocketException& ex) noexcept override {
921 ADD_FAILURE() << "server read error: " << ex.what();
924 AsyncSSLSocket::UniquePtr socket_;
927 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
928 public AsyncTransportWrapper::ReadCallback {
930 explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
931 : socket_(std::move(socket)) {
932 socket_->sslAccept(this);
935 ~RenegotiatingServer() {
936 socket_->setReadCB(nullptr);
939 void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
940 LOG(INFO) << "Renegotiating server handshake success";
941 socket_->setReadCB(this);
945 const AsyncSocketException& ex) noexcept override {
946 ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
948 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
949 *lenReturn = sizeof(buf);
952 void readDataAvailable(size_t /* len */) noexcept override {}
953 void readEOF() noexcept override {}
954 void readErr(const AsyncSocketException& ex) noexcept override {
955 LOG(INFO) << "server got read error " << ex.what();
956 auto exPtr = dynamic_cast<const SSLException*>(&ex);
957 ASSERT_NE(nullptr, exPtr);
958 std::string exStr(ex.what());
959 SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
960 ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
961 renegotiationError_ = true;
964 AsyncSSLSocket::UniquePtr socket_;
965 unsigned char buf[128];
966 bool renegotiationError_{false};
969 #ifndef OPENSSL_NO_TLSEXT
971 private AsyncSSLSocket::HandshakeCB,
972 private AsyncTransportWrapper::WriteCallback {
975 AsyncSSLSocket::UniquePtr socket)
976 : serverNameMatch(false), socket_(std::move(socket)) {
977 socket_->sslConn(this);
980 bool serverNameMatch;
983 void handshakeSuc(AsyncSSLSocket*) noexcept override {
984 serverNameMatch = socket_->isServerNameMatch();
988 const AsyncSocketException& ex) noexcept override {
989 ADD_FAILURE() << "client handshake error: " << ex.what();
991 void writeSuccess() noexcept override {
996 const AsyncSocketException& ex) noexcept override {
997 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1001 AsyncSSLSocket::UniquePtr socket_;
1005 private AsyncSSLSocket::HandshakeCB,
1006 private AsyncTransportWrapper::ReadCallback {
1009 AsyncSSLSocket::UniquePtr socket,
1010 const std::shared_ptr<folly::SSLContext>& ctx,
1011 const std::shared_ptr<folly::SSLContext>& sniCtx,
1012 const std::string& expectedServerName)
1013 : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1014 expectedServerName_(expectedServerName) {
1015 ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1016 std::placeholders::_1));
1017 socket_->sslAccept(this);
1020 bool serverNameMatch;
1023 void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1026 const AsyncSocketException& ex) noexcept override {
1027 ADD_FAILURE() << "server handshake error: " << ex.what();
1029 void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1032 void readDataAvailable(size_t /* len */) noexcept override {}
1033 void readEOF() noexcept override {
1037 const AsyncSocketException& ex) noexcept override {
1038 ADD_FAILURE() << "server read error: " << ex.what();
1041 folly::SSLContext::ServerNameCallbackResult
1042 serverNameCallback(SSL *ssl) {
1043 const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1046 !strcasecmp(expectedServerName_.c_str(), sn)) {
1047 AsyncSSLSocket *sslSocket =
1048 AsyncSSLSocket::getFromSSL(ssl);
1049 sslSocket->switchServerSSLContext(sniCtx_);
1050 serverNameMatch = true;
1051 return folly::SSLContext::SERVER_NAME_FOUND;
1053 serverNameMatch = false;
1054 return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1058 AsyncSSLSocket::UniquePtr socket_;
1059 std::shared_ptr<folly::SSLContext> sniCtx_;
1060 std::string expectedServerName_;
1064 class SSLClient : public AsyncSocket::ConnectCallback,
1065 public AsyncTransportWrapper::WriteCallback,
1066 public AsyncTransportWrapper::ReadCallback
1069 EventBase *eventBase_;
1070 std::shared_ptr<AsyncSSLSocket> sslSocket_;
1071 SSL_SESSION *session_;
1072 std::shared_ptr<folly::SSLContext> ctx_;
1074 folly::SocketAddress address_;
1078 uint32_t bytesRead_;
1082 uint32_t writeAfterConnectErrors_;
1084 // These settings test that we eventually drain the
1085 // socket, even if the maxReadsPerEvent_ is hit during
1086 // a event loop iteration.
1087 static constexpr size_t kMaxReadsPerEvent = 2;
1088 // 2 event loop iterations
1089 static constexpr size_t kMaxReadBufferSz =
1090 sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1093 SSLClient(EventBase *eventBase,
1094 const folly::SocketAddress& address,
1096 uint32_t timeout = 0)
1097 : eventBase_(eventBase),
1099 requests_(requests),
1106 writeAfterConnectErrors_(0) {
1107 ctx_.reset(new folly::SSLContext());
1108 ctx_->setOptions(SSL_OP_NO_TICKET);
1109 ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1110 memset(buf_, 'a', sizeof(buf_));
1115 SSL_SESSION_free(session_);
1118 EXPECT_EQ(bytesRead_, sizeof(buf_));
1122 uint32_t getHit() const { return hit_; }
1124 uint32_t getMiss() const { return miss_; }
1126 uint32_t getErrors() const { return errors_; }
1128 uint32_t getWriteAfterConnectErrors() const {
1129 return writeAfterConnectErrors_;
1132 void connect(bool writeNow = false) {
1133 sslSocket_ = AsyncSSLSocket::newSocket(
1135 if (session_ != nullptr) {
1136 sslSocket_->setSSLSession(session_);
1139 sslSocket_->connect(this, address_, timeout_);
1140 if (sslSocket_ && writeNow) {
1141 // write some junk, used in an error test
1142 sslSocket_->write(this, buf_, sizeof(buf_));
1146 void connectSuccess() noexcept override {
1147 std::cerr << "client SSL socket connected" << std::endl;
1148 if (sslSocket_->getSSLSessionReused()) {
1152 if (session_ != nullptr) {
1153 SSL_SESSION_free(session_);
1155 session_ = sslSocket_->getSSLSession();
1159 sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1160 sslSocket_->write(this, buf_, sizeof(buf_));
1161 sslSocket_->setReadCB(this);
1162 memset(readbuf_, 'b', sizeof(readbuf_));
1167 const AsyncSocketException& ex) noexcept override {
1168 std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1173 void writeSuccess() noexcept override {
1174 std::cerr << "client write success" << std::endl;
1177 void writeErr(size_t /* bytesWritten */,
1178 const AsyncSocketException& ex) noexcept override {
1179 std::cerr << "client writeError: " << ex.what() << std::endl;
1181 writeAfterConnectErrors_++;
1185 void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1186 *bufReturn = readbuf_ + bytesRead_;
1187 *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1190 void readEOF() noexcept override {
1191 std::cerr << "client readEOF" << std::endl;
1195 const AsyncSocketException& ex) noexcept override {
1196 std::cerr << "client readError: " << ex.what() << std::endl;
1199 void readDataAvailable(size_t len) noexcept override {
1200 std::cerr << "client read data: " << len << std::endl;
1202 if (bytesRead_ == sizeof(buf_)) {
1203 EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1204 sslSocket_->closeNow();
1206 if (requests_ != 0) {
1214 class SSLHandshakeBase :
1215 public AsyncSSLSocket::HandshakeCB,
1216 private AsyncTransportWrapper::WriteCallback {
1218 explicit SSLHandshakeBase(
1219 AsyncSSLSocket::UniquePtr socket,
1220 bool preverifyResult,
1221 bool verifyResult) :
1222 handshakeVerify_(false),
1223 handshakeSuccess_(false),
1224 handshakeError_(false),
1225 socket_(std::move(socket)),
1226 preverifyResult_(preverifyResult),
1227 verifyResult_(verifyResult) {
1230 AsyncSSLSocket::UniquePtr moveSocket() && {
1231 return std::move(socket_);
1234 bool handshakeVerify_;
1235 bool handshakeSuccess_;
1236 bool handshakeError_;
1237 std::chrono::nanoseconds handshakeTime;
1240 AsyncSSLSocket::UniquePtr socket_;
1241 bool preverifyResult_;
1244 // HandshakeCallback
1245 bool handshakeVer(AsyncSSLSocket* /* sock */,
1247 X509_STORE_CTX* /* ctx */) noexcept override {
1248 handshakeVerify_ = true;
1250 EXPECT_EQ(preverifyResult_, preverifyOk);
1251 return verifyResult_;
1254 void handshakeSuc(AsyncSSLSocket*) noexcept override {
1255 LOG(INFO) << "Handshake success";
1256 handshakeSuccess_ = true;
1257 handshakeTime = socket_->getHandshakeTime();
1262 const AsyncSocketException& ex) noexcept override {
1263 LOG(INFO) << "Handshake error " << ex.what();
1264 handshakeError_ = true;
1265 handshakeTime = socket_->getHandshakeTime();
1269 void writeSuccess() noexcept override {
1274 size_t bytesWritten,
1275 const AsyncSocketException& ex) noexcept override {
1276 ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1281 class SSLHandshakeClient : public SSLHandshakeBase {
1284 AsyncSSLSocket::UniquePtr socket,
1285 bool preverifyResult,
1286 bool verifyResult) :
1287 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1288 socket_->sslConn(this, 0);
1292 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1294 SSLHandshakeClientNoVerify(
1295 AsyncSSLSocket::UniquePtr socket,
1296 bool preverifyResult,
1297 bool verifyResult) :
1298 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1299 socket_->sslConn(this, 0,
1300 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1304 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1306 SSLHandshakeClientDoVerify(
1307 AsyncSSLSocket::UniquePtr socket,
1308 bool preverifyResult,
1309 bool verifyResult) :
1310 SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1311 socket_->sslConn(this, 0,
1312 folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1316 class SSLHandshakeServer : public SSLHandshakeBase {
1319 AsyncSSLSocket::UniquePtr socket,
1320 bool preverifyResult,
1322 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1323 socket_->sslAccept(this, 0);
1327 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1329 SSLHandshakeServerParseClientHello(
1330 AsyncSSLSocket::UniquePtr socket,
1331 bool preverifyResult,
1333 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1334 socket_->enableClientHelloParsing();
1335 socket_->sslAccept(this, 0);
1338 std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1341 void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1342 handshakeSuccess_ = true;
1343 sock->getSSLSharedCiphers(sharedCiphers_);
1344 sock->getSSLServerCiphers(serverCiphers_);
1345 sock->getSSLClientCiphers(clientCiphers_);
1346 chosenCipher_ = sock->getNegotiatedCipherName();
1351 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1353 SSLHandshakeServerNoVerify(
1354 AsyncSSLSocket::UniquePtr socket,
1355 bool preverifyResult,
1357 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1358 socket_->sslAccept(this, 0,
1359 folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1363 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1365 SSLHandshakeServerDoVerify(
1366 AsyncSSLSocket::UniquePtr socket,
1367 bool preverifyResult,
1369 : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1370 socket_->sslAccept(this, 0,
1371 folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1375 class EventBaseAborter : public AsyncTimeout {
1377 EventBaseAborter(EventBase* eventBase,
1380 eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1381 , eventBase_(eventBase) {
1382 scheduleTimeout(timeoutMS);
1385 void timeoutExpired() noexcept override {
1386 FAIL() << "test timed out";
1387 eventBase_->terminateLoopSoon();
1391 EventBase* eventBase_;