folly/io/async/tests: always detach event base in tests, fixes UBSAN tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
34
35 #include <fcntl.h>
36 #include <sys/types.h>
37 #include <condition_variable>
38 #include <iostream>
39 #include <list>
40
41 namespace folly {
42
43 enum StateEnum {
44   STATE_WAITING,
45   STATE_SUCCEEDED,
46   STATE_FAILED
47 };
48
49 // The destructors of all callback classes assert that the state is
50 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
51 // are responsible for setting the succeeded state properly before the
52 // destructors are called.
53
54 class WriteCallbackBase :
55 public AsyncTransportWrapper::WriteCallback {
56 public:
57   WriteCallbackBase()
58       : state(STATE_WAITING)
59       , bytesWritten(0)
60       , exception(AsyncSocketException::UNKNOWN, "none") {}
61
62   ~WriteCallbackBase() {
63     EXPECT_EQ(STATE_SUCCEEDED, state);
64   }
65
66   void setSocket(
67     const std::shared_ptr<AsyncSSLSocket> &socket) {
68     socket_ = socket;
69   }
70
71   void writeSuccess() noexcept override {
72     std::cerr << "writeSuccess" << std::endl;
73     state = STATE_SUCCEEDED;
74   }
75
76   void writeErr(
77     size_t nBytesWritten,
78     const AsyncSocketException& ex) noexcept override {
79     std::cerr << "writeError: bytesWritten " << nBytesWritten
80          << ", exception " << ex.what() << std::endl;
81
82     state = STATE_FAILED;
83     this->bytesWritten = nBytesWritten;
84     exception = ex;
85     socket_->close();
86   }
87
88   std::shared_ptr<AsyncSSLSocket> socket_;
89   StateEnum state;
90   size_t bytesWritten;
91   AsyncSocketException exception;
92 };
93
94 class ReadCallbackBase :
95 public AsyncTransportWrapper::ReadCallback {
96  public:
97   explicit ReadCallbackBase(WriteCallbackBase* wcb)
98       : wcb_(wcb), state(STATE_WAITING) {}
99
100   ~ReadCallbackBase() {
101     EXPECT_EQ(STATE_SUCCEEDED, state);
102   }
103
104   void setSocket(
105     const std::shared_ptr<AsyncSSLSocket> &socket) {
106     socket_ = socket;
107   }
108
109   void setState(StateEnum s) {
110     state = s;
111     if (wcb_) {
112       wcb_->state = s;
113     }
114   }
115
116   void readErr(
117     const AsyncSocketException& ex) noexcept override {
118     std::cerr << "readError " << ex.what() << std::endl;
119     state = STATE_FAILED;
120     socket_->close();
121   }
122
123   void readEOF() noexcept override {
124     std::cerr << "readEOF" << std::endl;
125
126     socket_->close();
127   }
128
129   std::shared_ptr<AsyncSSLSocket> socket_;
130   WriteCallbackBase *wcb_;
131   StateEnum state;
132 };
133
134 class ReadCallback : public ReadCallbackBase {
135 public:
136   explicit ReadCallback(WriteCallbackBase *wcb)
137       : ReadCallbackBase(wcb)
138       , buffers() {}
139
140   ~ReadCallback() {
141     for (std::vector<Buffer>::iterator it = buffers.begin();
142          it != buffers.end();
143          ++it) {
144       it->free();
145     }
146     currentBuffer.free();
147   }
148
149   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
150     if (!currentBuffer.buffer) {
151       currentBuffer.allocate(4096);
152     }
153     *bufReturn = currentBuffer.buffer;
154     *lenReturn = currentBuffer.length;
155   }
156
157   void readDataAvailable(size_t len) noexcept override {
158     std::cerr << "readDataAvailable, len " << len << std::endl;
159
160     currentBuffer.length = len;
161
162     wcb_->setSocket(socket_);
163
164     // Write back the same data.
165     socket_->write(wcb_, currentBuffer.buffer, len);
166
167     buffers.push_back(currentBuffer);
168     currentBuffer.reset();
169     state = STATE_SUCCEEDED;
170   }
171
172   class Buffer {
173   public:
174     Buffer() : buffer(nullptr), length(0) {}
175     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
176
177     void reset() {
178       buffer = nullptr;
179       length = 0;
180     }
181     void allocate(size_t len) {
182       assert(buffer == nullptr);
183       this->buffer = static_cast<char*>(malloc(len));
184       this->length = len;
185     }
186     void free() {
187       ::free(buffer);
188       reset();
189     }
190
191     char* buffer;
192     size_t length;
193   };
194
195   std::vector<Buffer> buffers;
196   Buffer currentBuffer;
197 };
198
199 class ReadErrorCallback : public ReadCallbackBase {
200 public:
201   explicit ReadErrorCallback(WriteCallbackBase *wcb)
202       : ReadCallbackBase(wcb) {}
203
204   // Return nullptr buffer to trigger readError()
205   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
206     *bufReturn = nullptr;
207     *lenReturn = 0;
208   }
209
210   void readDataAvailable(size_t /* len */) noexcept override {
211     // This should never to called.
212     FAIL();
213   }
214
215   void readErr(
216     const AsyncSocketException& ex) noexcept override {
217     ReadCallbackBase::readErr(ex);
218     std::cerr << "ReadErrorCallback::readError" << std::endl;
219     setState(STATE_SUCCEEDED);
220   }
221 };
222
223 class ReadEOFCallback : public ReadCallbackBase {
224  public:
225   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
226
227   // Return nullptr buffer to trigger readError()
228   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
229     *bufReturn = nullptr;
230     *lenReturn = 0;
231   }
232
233   void readDataAvailable(size_t /* len */) noexcept override {
234     // This should never to called.
235     FAIL();
236   }
237
238   void readEOF() noexcept override {
239     ReadCallbackBase::readEOF();
240     setState(STATE_SUCCEEDED);
241   }
242 };
243
244 class WriteErrorCallback : public ReadCallback {
245 public:
246   explicit WriteErrorCallback(WriteCallbackBase *wcb)
247       : ReadCallback(wcb) {}
248
249   void readDataAvailable(size_t len) noexcept override {
250     std::cerr << "readDataAvailable, len " << len << std::endl;
251
252     currentBuffer.length = len;
253
254     // close the socket before writing to trigger writeError().
255     ::close(socket_->getFd());
256
257     wcb_->setSocket(socket_);
258
259     // Write back the same data.
260     folly::test::msvcSuppressAbortOnInvalidParams([&] {
261       socket_->write(wcb_, currentBuffer.buffer, len);
262     });
263
264     if (wcb_->state == STATE_FAILED) {
265       setState(STATE_SUCCEEDED);
266     } else {
267       state = STATE_FAILED;
268     }
269
270     buffers.push_back(currentBuffer);
271     currentBuffer.reset();
272   }
273
274   void readErr(const AsyncSocketException& ex) noexcept override {
275     std::cerr << "readError " << ex.what() << std::endl;
276     // do nothing since this is expected
277   }
278 };
279
280 class EmptyReadCallback : public ReadCallback {
281 public:
282   explicit EmptyReadCallback()
283       : ReadCallback(nullptr) {}
284
285   void readErr(const AsyncSocketException& ex) noexcept override {
286     std::cerr << "readError " << ex.what() << std::endl;
287     state = STATE_FAILED;
288     if (tcpSocket_) {
289       tcpSocket_->close();
290     }
291   }
292
293   void readEOF() noexcept override {
294     std::cerr << "readEOF" << std::endl;
295     if (tcpSocket_) {
296       tcpSocket_->close();
297     }
298     state = STATE_SUCCEEDED;
299   }
300
301   std::shared_ptr<AsyncSocket> tcpSocket_;
302 };
303
304 class HandshakeCallback :
305 public AsyncSSLSocket::HandshakeCB {
306 public:
307   enum ExpectType {
308     EXPECT_SUCCESS,
309     EXPECT_ERROR
310   };
311
312   explicit HandshakeCallback(ReadCallbackBase *rcb,
313                              ExpectType expect = EXPECT_SUCCESS):
314       state(STATE_WAITING),
315       rcb_(rcb),
316       expect_(expect) {}
317
318   void setSocket(
319     const std::shared_ptr<AsyncSSLSocket> &socket) {
320     socket_ = socket;
321   }
322
323   void setState(StateEnum s) {
324     state = s;
325     rcb_->setState(s);
326   }
327
328   // Functions inherited from AsyncSSLSocketHandshakeCallback
329   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
330     std::lock_guard<std::mutex> g(mutex_);
331     cv_.notify_all();
332     EXPECT_EQ(sock, socket_.get());
333     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
334     rcb_->setSocket(socket_);
335     sock->setReadCB(rcb_);
336     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
337   }
338   void handshakeErr(AsyncSSLSocket* /* sock */,
339                     const AsyncSocketException& ex) noexcept override {
340     std::lock_guard<std::mutex> g(mutex_);
341     cv_.notify_all();
342     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
343     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
344     if (expect_ == EXPECT_ERROR) {
345       // rcb will never be invoked
346       rcb_->setState(STATE_SUCCEEDED);
347     }
348     errorString_ = ex.what();
349   }
350
351   void waitForHandshake() {
352     std::unique_lock<std::mutex> lock(mutex_);
353     cv_.wait(lock, [this] { return state != STATE_WAITING; });
354   }
355
356   ~HandshakeCallback() {
357     EXPECT_EQ(STATE_SUCCEEDED, state);
358   }
359
360   void closeSocket() {
361     socket_->close();
362     state = STATE_SUCCEEDED;
363   }
364
365   std::shared_ptr<AsyncSSLSocket> getSocket() {
366     return socket_;
367   }
368
369   StateEnum state;
370   std::shared_ptr<AsyncSSLSocket> socket_;
371   ReadCallbackBase *rcb_;
372   ExpectType expect_;
373   std::mutex mutex_;
374   std::condition_variable cv_;
375   std::string errorString_;
376 };
377
378 class SSLServerAcceptCallbackBase:
379 public folly::AsyncServerSocket::AcceptCallback {
380 public:
381   explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
382   state(STATE_WAITING), hcb_(hcb) {}
383
384   ~SSLServerAcceptCallbackBase() {
385     EXPECT_EQ(STATE_SUCCEEDED, state);
386   }
387
388   void acceptError(const std::exception& ex) noexcept override {
389     std::cerr << "SSLServerAcceptCallbackBase::acceptError "
390               << ex.what() << std::endl;
391     state = STATE_FAILED;
392   }
393
394   void connectionAccepted(
395       int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
396     if (socket_) {
397       socket_->detachEventBase();
398     }
399     printf("Connection accepted\n");
400     try {
401       // Create a AsyncSSLSocket object with the fd. The socket should be
402       // added to the event base and in the state of accepting SSL connection.
403       socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
404     } catch (const std::exception &e) {
405       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
406         "object with socket " << e.what() << fd;
407       ::close(fd);
408       acceptError(e);
409       return;
410     }
411
412     connAccepted(socket_);
413   }
414
415   virtual void connAccepted(
416     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
417
418   void detach() {
419     socket_->detachEventBase();
420   }
421
422   StateEnum state;
423   HandshakeCallback *hcb_;
424   std::shared_ptr<folly::SSLContext> ctx_;
425   std::shared_ptr<AsyncSSLSocket> socket_;
426   folly::EventBase* base_;
427 };
428
429 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
430 public:
431   uint32_t timeout_;
432
433   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
434                                    uint32_t timeout = 0):
435       SSLServerAcceptCallbackBase(hcb),
436       timeout_(timeout) {}
437
438   virtual ~SSLServerAcceptCallback() {
439     if (timeout_ > 0) {
440       // if we set a timeout, we expect failure
441       EXPECT_EQ(hcb_->state, STATE_FAILED);
442       hcb_->setState(STATE_SUCCEEDED);
443     }
444   }
445
446   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
447   void connAccepted(
448     const std::shared_ptr<folly::AsyncSSLSocket> &s)
449     noexcept override {
450     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
451     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
452
453     hcb_->setSocket(sock);
454     sock->sslAccept(hcb_, timeout_);
455     EXPECT_EQ(sock->getSSLState(),
456                       AsyncSSLSocket::STATE_ACCEPTING);
457
458     state = STATE_SUCCEEDED;
459   }
460 };
461
462 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
463 public:
464   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
465       SSLServerAcceptCallback(hcb) {}
466
467   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
468   void connAccepted(
469     const std::shared_ptr<folly::AsyncSSLSocket> &s)
470     noexcept override {
471
472     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
473
474     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
475               << std::endl;
476     int fd = sock->getFd();
477
478 #ifndef TCP_NOPUSH
479     {
480     // The accepted connection should already have TCP_NODELAY set
481     int value;
482     socklen_t valueLength = sizeof(value);
483     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
484     EXPECT_EQ(rc, 0);
485     EXPECT_EQ(value, 1);
486     }
487 #endif
488
489     // Unset the TCP_NODELAY option.
490     int value = 0;
491     socklen_t valueLength = sizeof(value);
492     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
493     EXPECT_EQ(rc, 0);
494
495     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
496     EXPECT_EQ(rc, 0);
497     EXPECT_EQ(value, 0);
498
499     SSLServerAcceptCallback::connAccepted(sock);
500   }
501 };
502
503 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
504 public:
505   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
506                                              uint32_t timeout = 0):
507     SSLServerAcceptCallback(hcb, timeout) {}
508
509   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
510   void connAccepted(
511     const std::shared_ptr<folly::AsyncSSLSocket> &s)
512     noexcept override {
513     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
514
515     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
516
517     hcb_->setSocket(sock);
518     sock->sslAccept(hcb_, timeout_);
519     ASSERT_TRUE((sock->getSSLState() ==
520                  AsyncSSLSocket::STATE_ACCEPTING) ||
521                 (sock->getSSLState() ==
522                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
523
524     state = STATE_SUCCEEDED;
525   }
526 };
527
528
529 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
530 public:
531   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
532   SSLServerAcceptCallbackBase(hcb)  {}
533
534   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
535   void connAccepted(
536     const std::shared_ptr<folly::AsyncSSLSocket> &s)
537     noexcept override {
538     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
539
540     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
541
542     // The first call to sslAccept() should succeed.
543     hcb_->setSocket(sock);
544     sock->sslAccept(hcb_);
545     EXPECT_EQ(sock->getSSLState(),
546                       AsyncSSLSocket::STATE_ACCEPTING);
547
548     // The second call to sslAccept() should fail.
549     HandshakeCallback callback2(hcb_->rcb_);
550     callback2.setSocket(sock);
551     sock->sslAccept(&callback2);
552     EXPECT_EQ(sock->getSSLState(),
553                       AsyncSSLSocket::STATE_ERROR);
554
555     // Both callbacks should be in the error state.
556     EXPECT_EQ(hcb_->state, STATE_FAILED);
557     EXPECT_EQ(callback2.state, STATE_FAILED);
558
559     state = STATE_SUCCEEDED;
560     hcb_->setState(STATE_SUCCEEDED);
561     callback2.setState(STATE_SUCCEEDED);
562   }
563 };
564
565 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
566 public:
567   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
568   SSLServerAcceptCallbackBase(hcb)  {}
569
570   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
571   void connAccepted(
572     const std::shared_ptr<folly::AsyncSSLSocket> &s)
573     noexcept override {
574     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
575
576     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
577
578     hcb_->setSocket(sock);
579     sock->getEventBase()->tryRunAfterDelay([=] {
580         std::cerr << "Delayed SSL accept, client will have close by now"
581                   << std::endl;
582         // SSL accept will fail
583         EXPECT_EQ(
584           sock->getSSLState(),
585           AsyncSSLSocket::STATE_UNINIT);
586         hcb_->socket_->sslAccept(hcb_);
587         // This registers for an event
588         EXPECT_EQ(
589           sock->getSSLState(),
590           AsyncSSLSocket::STATE_ACCEPTING);
591
592         state = STATE_SUCCEEDED;
593       }, 100);
594   }
595 };
596
597 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
598  public:
599   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
600     // We don't care if we get invoked or not.
601     // The client may time out and give up before connAccepted() is even
602     // called.
603     state = STATE_SUCCEEDED;
604   }
605
606   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
607   void connAccepted(
608       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
609     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
610
611     // Just wait a while before closing the socket, so the client
612     // will time out waiting for the handshake to complete.
613     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
614   }
615 };
616
617 class TestSSLServer {
618  protected:
619   EventBase evb_;
620   std::shared_ptr<folly::SSLContext> ctx_;
621   SSLServerAcceptCallbackBase *acb_;
622   std::shared_ptr<folly::AsyncServerSocket> socket_;
623   folly::SocketAddress address_;
624   pthread_t thread_;
625
626   static void *Main(void *ctx) {
627     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
628     self->evb_.loop();
629     self->acb_->detach();
630     std::cerr << "Server thread exited event loop" << std::endl;
631     return nullptr;
632   }
633
634  public:
635   // Create a TestSSLServer.
636   // This immediately starts listening on the given port.
637   explicit TestSSLServer(
638       SSLServerAcceptCallbackBase* acb,
639       bool enableTFO = false);
640
641   // Kill the thread.
642   ~TestSSLServer() {
643     evb_.runInEventBaseThread([&](){
644       socket_->stopAccepting();
645     });
646     std::cerr << "Waiting for server thread to exit" << std::endl;
647     pthread_join(thread_, nullptr);
648   }
649
650   EventBase &getEventBase() { return evb_; }
651
652   const folly::SocketAddress& getAddress() const {
653     return address_;
654   }
655 };
656
657 class TestSSLAsyncCacheServer : public TestSSLServer {
658  public:
659   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
660         int lookupDelay = 100) :
661       TestSSLServer(acb) {
662     SSL_CTX *sslCtx = ctx_->getSSLCtx();
663     SSL_CTX_sess_set_get_cb(sslCtx,
664                             TestSSLAsyncCacheServer::getSessionCallback);
665     SSL_CTX_set_session_cache_mode(
666       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
667     asyncCallbacks_ = 0;
668     asyncLookups_ = 0;
669     lookupDelay_ = lookupDelay;
670   }
671
672   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
673   uint32_t getAsyncLookups() const { return asyncLookups_; }
674
675  private:
676   static uint32_t asyncCallbacks_;
677   static uint32_t asyncLookups_;
678   static uint32_t lookupDelay_;
679
680   static SSL_SESSION* getSessionCallback(SSL* ssl,
681                                          unsigned char* /* sess_id */,
682                                          int /* id_len */,
683                                          int* copyflag) {
684     *copyflag = 0;
685     asyncCallbacks_++;
686 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
687     if (!SSL_want_sess_cache_lookup(ssl)) {
688       // libssl.so mismatch
689       std::cerr << "no async support" << std::endl;
690       return nullptr;
691     }
692
693     AsyncSSLSocket *sslSocket =
694         AsyncSSLSocket::getFromSSL(ssl);
695     assert(sslSocket != nullptr);
696     // Going to simulate an async cache by just running delaying the miss 100ms
697     if (asyncCallbacks_ % 2 == 0) {
698       // This socket is already blocked on lookup, return miss
699       std::cerr << "returning miss" << std::endl;
700     } else {
701       // fresh meat - block it
702       std::cerr << "async lookup" << std::endl;
703       sslSocket->getEventBase()->tryRunAfterDelay(
704         std::bind(&AsyncSSLSocket::restartSSLAccept,
705                   sslSocket), lookupDelay_);
706       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
707       asyncLookups_++;
708     }
709 #endif
710     return nullptr;
711   }
712 };
713
714 void getfds(int fds[2]);
715
716 void getctx(
717   std::shared_ptr<folly::SSLContext> clientCtx,
718   std::shared_ptr<folly::SSLContext> serverCtx);
719
720 void sslsocketpair(
721   EventBase* eventBase,
722   AsyncSSLSocket::UniquePtr* clientSock,
723   AsyncSSLSocket::UniquePtr* serverSock);
724
725 class BlockingWriteClient :
726   private AsyncSSLSocket::HandshakeCB,
727   private AsyncTransportWrapper::WriteCallback {
728  public:
729   explicit BlockingWriteClient(
730     AsyncSSLSocket::UniquePtr socket)
731     : socket_(std::move(socket)),
732       bufLen_(2500),
733       iovCount_(2000) {
734     // Fill buf_
735     buf_.reset(new uint8_t[bufLen_]);
736     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
737       buf_[n] = n % 0xff;
738     }
739
740     // Initialize iov_
741     iov_.reset(new struct iovec[iovCount_]);
742     for (uint32_t n = 0; n < iovCount_; ++n) {
743       iov_[n].iov_base = buf_.get() + n;
744       if (n & 0x1) {
745         iov_[n].iov_len = n % bufLen_;
746       } else {
747         iov_[n].iov_len = bufLen_ - (n % bufLen_);
748       }
749     }
750
751     socket_->sslConn(this, 100);
752   }
753
754   struct iovec* getIovec() const {
755     return iov_.get();
756   }
757   uint32_t getIovecCount() const {
758     return iovCount_;
759   }
760
761  private:
762   void handshakeSuc(AsyncSSLSocket*) noexcept override {
763     socket_->writev(this, iov_.get(), iovCount_);
764   }
765   void handshakeErr(
766     AsyncSSLSocket*,
767     const AsyncSocketException& ex) noexcept override {
768     ADD_FAILURE() << "client handshake error: " << ex.what();
769   }
770   void writeSuccess() noexcept override {
771     socket_->close();
772   }
773   void writeErr(
774     size_t bytesWritten,
775     const AsyncSocketException& ex) noexcept override {
776     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
777                   << ex.what();
778   }
779
780   AsyncSSLSocket::UniquePtr socket_;
781   uint32_t bufLen_;
782   uint32_t iovCount_;
783   std::unique_ptr<uint8_t[]> buf_;
784   std::unique_ptr<struct iovec[]> iov_;
785 };
786
787 class BlockingWriteServer :
788     private AsyncSSLSocket::HandshakeCB,
789     private AsyncTransportWrapper::ReadCallback {
790  public:
791   explicit BlockingWriteServer(
792     AsyncSSLSocket::UniquePtr socket)
793     : socket_(std::move(socket)),
794       bufSize_(2500 * 2000),
795       bytesRead_(0) {
796     buf_.reset(new uint8_t[bufSize_]);
797     socket_->sslAccept(this, 100);
798   }
799
800   void checkBuffer(struct iovec* iov, uint32_t count) const {
801     uint32_t idx = 0;
802     for (uint32_t n = 0; n < count; ++n) {
803       size_t bytesLeft = bytesRead_ - idx;
804       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
805                       std::min(iov[n].iov_len, bytesLeft));
806       if (rc != 0) {
807         FAIL() << "buffer mismatch at iovec " << n << "/" << count
808                << ": rc=" << rc;
809
810       }
811       if (iov[n].iov_len > bytesLeft) {
812         FAIL() << "server did not read enough data: "
813                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
814                << " in iovec " << n << "/" << count;
815       }
816
817       idx += iov[n].iov_len;
818     }
819     if (idx != bytesRead_) {
820       ADD_FAILURE() << "server read extra data: " << bytesRead_
821                     << " bytes read; expected " << idx;
822     }
823   }
824
825  private:
826   void handshakeSuc(AsyncSSLSocket*) noexcept override {
827     // Wait 10ms before reading, so the client's writes will initially block.
828     socket_->getEventBase()->tryRunAfterDelay(
829         [this] { socket_->setReadCB(this); }, 10);
830   }
831   void handshakeErr(
832     AsyncSSLSocket*,
833     const AsyncSocketException& ex) noexcept override {
834     ADD_FAILURE() << "server handshake error: " << ex.what();
835   }
836   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
837     *bufReturn = buf_.get() + bytesRead_;
838     *lenReturn = bufSize_ - bytesRead_;
839   }
840   void readDataAvailable(size_t len) noexcept override {
841     bytesRead_ += len;
842     socket_->setReadCB(nullptr);
843     socket_->getEventBase()->tryRunAfterDelay(
844         [this] { socket_->setReadCB(this); }, 2);
845   }
846   void readEOF() noexcept override {
847     socket_->close();
848   }
849   void readErr(
850     const AsyncSocketException& ex) noexcept override {
851     ADD_FAILURE() << "server read error: " << ex.what();
852   }
853
854   AsyncSSLSocket::UniquePtr socket_;
855   uint32_t bufSize_;
856   uint32_t bytesRead_;
857   std::unique_ptr<uint8_t[]> buf_;
858 };
859
860 class NpnClient :
861   private AsyncSSLSocket::HandshakeCB,
862   private AsyncTransportWrapper::WriteCallback {
863  public:
864   explicit NpnClient(
865     AsyncSSLSocket::UniquePtr socket)
866       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
867     socket_->sslConn(this);
868   }
869
870   const unsigned char* nextProto;
871   unsigned nextProtoLength;
872   SSLContext::NextProtocolType protocolType;
873
874  private:
875   void handshakeSuc(AsyncSSLSocket*) noexcept override {
876     socket_->getSelectedNextProtocol(
877         &nextProto, &nextProtoLength, &protocolType);
878   }
879   void handshakeErr(
880     AsyncSSLSocket*,
881     const AsyncSocketException& ex) noexcept override {
882     ADD_FAILURE() << "client handshake error: " << ex.what();
883   }
884   void writeSuccess() noexcept override {
885     socket_->close();
886   }
887   void writeErr(
888     size_t bytesWritten,
889     const AsyncSocketException& ex) noexcept override {
890     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
891                   << ex.what();
892   }
893
894   AsyncSSLSocket::UniquePtr socket_;
895 };
896
897 class NpnServer :
898     private AsyncSSLSocket::HandshakeCB,
899     private AsyncTransportWrapper::ReadCallback {
900  public:
901   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
902       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
903     socket_->sslAccept(this);
904   }
905
906   const unsigned char* nextProto;
907   unsigned nextProtoLength;
908   SSLContext::NextProtocolType protocolType;
909
910  private:
911   void handshakeSuc(AsyncSSLSocket*) noexcept override {
912     socket_->getSelectedNextProtocol(
913         &nextProto, &nextProtoLength, &protocolType);
914   }
915   void handshakeErr(
916     AsyncSSLSocket*,
917     const AsyncSocketException& ex) noexcept override {
918     ADD_FAILURE() << "server handshake error: " << ex.what();
919   }
920   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
921     *lenReturn = 0;
922   }
923   void readDataAvailable(size_t /* len */) noexcept override {}
924   void readEOF() noexcept override {
925     socket_->close();
926   }
927   void readErr(
928     const AsyncSocketException& ex) noexcept override {
929     ADD_FAILURE() << "server read error: " << ex.what();
930   }
931
932   AsyncSSLSocket::UniquePtr socket_;
933 };
934
935 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
936                             public AsyncTransportWrapper::ReadCallback {
937  public:
938   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
939       : socket_(std::move(socket)) {
940     socket_->sslAccept(this);
941   }
942
943   ~RenegotiatingServer() {
944     socket_->setReadCB(nullptr);
945   }
946
947   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
948     LOG(INFO) << "Renegotiating server handshake success";
949     socket_->setReadCB(this);
950   }
951   void handshakeErr(
952       AsyncSSLSocket*,
953       const AsyncSocketException& ex) noexcept override {
954     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
955   }
956   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
957     *lenReturn = sizeof(buf);
958     *bufReturn = buf;
959   }
960   void readDataAvailable(size_t /* len */) noexcept override {}
961   void readEOF() noexcept override {}
962   void readErr(const AsyncSocketException& ex) noexcept override {
963     LOG(INFO) << "server got read error " << ex.what();
964     auto exPtr = dynamic_cast<const SSLException*>(&ex);
965     ASSERT_NE(nullptr, exPtr);
966     std::string exStr(ex.what());
967     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
968     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
969     renegotiationError_ = true;
970   }
971
972   AsyncSSLSocket::UniquePtr socket_;
973   unsigned char buf[128];
974   bool renegotiationError_{false};
975 };
976
977 #ifndef OPENSSL_NO_TLSEXT
978 class SNIClient :
979   private AsyncSSLSocket::HandshakeCB,
980   private AsyncTransportWrapper::WriteCallback {
981  public:
982   explicit SNIClient(
983     AsyncSSLSocket::UniquePtr socket)
984       : serverNameMatch(false), socket_(std::move(socket)) {
985     socket_->sslConn(this);
986   }
987
988   bool serverNameMatch;
989
990  private:
991   void handshakeSuc(AsyncSSLSocket*) noexcept override {
992     serverNameMatch = socket_->isServerNameMatch();
993   }
994   void handshakeErr(
995     AsyncSSLSocket*,
996     const AsyncSocketException& ex) noexcept override {
997     ADD_FAILURE() << "client handshake error: " << ex.what();
998   }
999   void writeSuccess() noexcept override {
1000     socket_->close();
1001   }
1002   void writeErr(
1003     size_t bytesWritten,
1004     const AsyncSocketException& ex) noexcept override {
1005     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1006                   << ex.what();
1007   }
1008
1009   AsyncSSLSocket::UniquePtr socket_;
1010 };
1011
1012 class SNIServer :
1013     private AsyncSSLSocket::HandshakeCB,
1014     private AsyncTransportWrapper::ReadCallback {
1015  public:
1016   explicit SNIServer(
1017     AsyncSSLSocket::UniquePtr socket,
1018     const std::shared_ptr<folly::SSLContext>& ctx,
1019     const std::shared_ptr<folly::SSLContext>& sniCtx,
1020     const std::string& expectedServerName)
1021       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1022         expectedServerName_(expectedServerName) {
1023     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1024                                          std::placeholders::_1));
1025     socket_->sslAccept(this);
1026   }
1027
1028   bool serverNameMatch;
1029
1030  private:
1031   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1032   void handshakeErr(
1033     AsyncSSLSocket*,
1034     const AsyncSocketException& ex) noexcept override {
1035     ADD_FAILURE() << "server handshake error: " << ex.what();
1036   }
1037   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1038     *lenReturn = 0;
1039   }
1040   void readDataAvailable(size_t /* len */) noexcept override {}
1041   void readEOF() noexcept override {
1042     socket_->close();
1043   }
1044   void readErr(
1045     const AsyncSocketException& ex) noexcept override {
1046     ADD_FAILURE() << "server read error: " << ex.what();
1047   }
1048
1049   folly::SSLContext::ServerNameCallbackResult
1050     serverNameCallback(SSL *ssl) {
1051     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1052     if (sniCtx_ &&
1053         sn &&
1054         !strcasecmp(expectedServerName_.c_str(), sn)) {
1055       AsyncSSLSocket *sslSocket =
1056           AsyncSSLSocket::getFromSSL(ssl);
1057       sslSocket->switchServerSSLContext(sniCtx_);
1058       serverNameMatch = true;
1059       return folly::SSLContext::SERVER_NAME_FOUND;
1060     } else {
1061       serverNameMatch = false;
1062       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1063     }
1064   }
1065
1066   AsyncSSLSocket::UniquePtr socket_;
1067   std::shared_ptr<folly::SSLContext> sniCtx_;
1068   std::string expectedServerName_;
1069 };
1070 #endif
1071
1072 class SSLClient : public AsyncSocket::ConnectCallback,
1073                   public AsyncTransportWrapper::WriteCallback,
1074                   public AsyncTransportWrapper::ReadCallback
1075 {
1076  private:
1077   EventBase *eventBase_;
1078   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1079   SSL_SESSION *session_;
1080   std::shared_ptr<folly::SSLContext> ctx_;
1081   uint32_t requests_;
1082   folly::SocketAddress address_;
1083   uint32_t timeout_;
1084   char buf_[128];
1085   char readbuf_[128];
1086   uint32_t bytesRead_;
1087   uint32_t hit_;
1088   uint32_t miss_;
1089   uint32_t errors_;
1090   uint32_t writeAfterConnectErrors_;
1091
1092   // These settings test that we eventually drain the
1093   // socket, even if the maxReadsPerEvent_ is hit during
1094   // a event loop iteration.
1095   static constexpr size_t kMaxReadsPerEvent = 2;
1096   // 2 event loop iterations
1097   static constexpr size_t kMaxReadBufferSz =
1098     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1099
1100  public:
1101   SSLClient(EventBase *eventBase,
1102             const folly::SocketAddress& address,
1103             uint32_t requests,
1104             uint32_t timeout = 0)
1105       : eventBase_(eventBase),
1106         session_(nullptr),
1107         requests_(requests),
1108         address_(address),
1109         timeout_(timeout),
1110         bytesRead_(0),
1111         hit_(0),
1112         miss_(0),
1113         errors_(0),
1114         writeAfterConnectErrors_(0) {
1115     ctx_.reset(new folly::SSLContext());
1116     ctx_->setOptions(SSL_OP_NO_TICKET);
1117     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1118     memset(buf_, 'a', sizeof(buf_));
1119   }
1120
1121   ~SSLClient() {
1122     if (session_) {
1123       SSL_SESSION_free(session_);
1124     }
1125     if (errors_ == 0) {
1126       EXPECT_EQ(bytesRead_, sizeof(buf_));
1127     }
1128   }
1129
1130   uint32_t getHit() const { return hit_; }
1131
1132   uint32_t getMiss() const { return miss_; }
1133
1134   uint32_t getErrors() const { return errors_; }
1135
1136   uint32_t getWriteAfterConnectErrors() const {
1137     return writeAfterConnectErrors_;
1138   }
1139
1140   void connect(bool writeNow = false) {
1141     sslSocket_ = AsyncSSLSocket::newSocket(
1142       ctx_, eventBase_);
1143     if (session_ != nullptr) {
1144       sslSocket_->setSSLSession(session_);
1145     }
1146     requests_--;
1147     sslSocket_->connect(this, address_, timeout_);
1148     if (sslSocket_ && writeNow) {
1149       // write some junk, used in an error test
1150       sslSocket_->write(this, buf_, sizeof(buf_));
1151     }
1152   }
1153
1154   void connectSuccess() noexcept override {
1155     std::cerr << "client SSL socket connected" << std::endl;
1156     if (sslSocket_->getSSLSessionReused()) {
1157       hit_++;
1158     } else {
1159       miss_++;
1160       if (session_ != nullptr) {
1161         SSL_SESSION_free(session_);
1162       }
1163       session_ = sslSocket_->getSSLSession();
1164     }
1165
1166     // write()
1167     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1168     sslSocket_->write(this, buf_, sizeof(buf_));
1169     sslSocket_->setReadCB(this);
1170     memset(readbuf_, 'b', sizeof(readbuf_));
1171     bytesRead_ = 0;
1172   }
1173
1174   void connectErr(
1175     const AsyncSocketException& ex) noexcept override {
1176     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1177     errors_++;
1178     sslSocket_.reset();
1179   }
1180
1181   void writeSuccess() noexcept override {
1182     std::cerr << "client write success" << std::endl;
1183   }
1184
1185   void writeErr(size_t /* bytesWritten */,
1186                 const AsyncSocketException& ex) noexcept override {
1187     std::cerr << "client writeError: " << ex.what() << std::endl;
1188     if (!sslSocket_) {
1189       writeAfterConnectErrors_++;
1190     }
1191   }
1192
1193   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1194     *bufReturn = readbuf_ + bytesRead_;
1195     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1196   }
1197
1198   void readEOF() noexcept override {
1199     std::cerr << "client readEOF" << std::endl;
1200   }
1201
1202   void readErr(
1203     const AsyncSocketException& ex) noexcept override {
1204     std::cerr << "client readError: " << ex.what() << std::endl;
1205   }
1206
1207   void readDataAvailable(size_t len) noexcept override {
1208     std::cerr << "client read data: " << len << std::endl;
1209     bytesRead_ += len;
1210     if (bytesRead_ == sizeof(buf_)) {
1211       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1212       sslSocket_->closeNow();
1213       sslSocket_.reset();
1214       if (requests_ != 0) {
1215         connect();
1216       }
1217     }
1218   }
1219
1220 };
1221
1222 class SSLHandshakeBase :
1223   public AsyncSSLSocket::HandshakeCB,
1224   private AsyncTransportWrapper::WriteCallback {
1225  public:
1226   explicit SSLHandshakeBase(
1227    AsyncSSLSocket::UniquePtr socket,
1228    bool preverifyResult,
1229    bool verifyResult) :
1230     handshakeVerify_(false),
1231     handshakeSuccess_(false),
1232     handshakeError_(false),
1233     socket_(std::move(socket)),
1234     preverifyResult_(preverifyResult),
1235     verifyResult_(verifyResult) {
1236   }
1237
1238   AsyncSSLSocket::UniquePtr moveSocket() && {
1239     return std::move(socket_);
1240   }
1241
1242   bool handshakeVerify_;
1243   bool handshakeSuccess_;
1244   bool handshakeError_;
1245   std::chrono::nanoseconds handshakeTime;
1246
1247  protected:
1248   AsyncSSLSocket::UniquePtr socket_;
1249   bool preverifyResult_;
1250   bool verifyResult_;
1251
1252   // HandshakeCallback
1253   bool handshakeVer(AsyncSSLSocket* /* sock */,
1254                     bool preverifyOk,
1255                     X509_STORE_CTX* /* ctx */) noexcept override {
1256     handshakeVerify_ = true;
1257
1258     EXPECT_EQ(preverifyResult_, preverifyOk);
1259     return verifyResult_;
1260   }
1261
1262   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1263     LOG(INFO) << "Handshake success";
1264     handshakeSuccess_ = true;
1265     handshakeTime = socket_->getHandshakeTime();
1266   }
1267
1268   void handshakeErr(
1269       AsyncSSLSocket*,
1270       const AsyncSocketException& ex) noexcept override {
1271     LOG(INFO) << "Handshake error " << ex.what();
1272     handshakeError_ = true;
1273     handshakeTime = socket_->getHandshakeTime();
1274   }
1275
1276   // WriteCallback
1277   void writeSuccess() noexcept override {
1278     socket_->close();
1279   }
1280
1281   void writeErr(
1282    size_t bytesWritten,
1283    const AsyncSocketException& ex) noexcept override {
1284     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1285                   << ex.what();
1286   }
1287 };
1288
1289 class SSLHandshakeClient : public SSLHandshakeBase {
1290  public:
1291   SSLHandshakeClient(
1292    AsyncSSLSocket::UniquePtr socket,
1293    bool preverifyResult,
1294    bool verifyResult) :
1295     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1296     socket_->sslConn(this, 0);
1297   }
1298 };
1299
1300 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1301  public:
1302   SSLHandshakeClientNoVerify(
1303    AsyncSSLSocket::UniquePtr socket,
1304    bool preverifyResult,
1305    bool verifyResult) :
1306     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1307     socket_->sslConn(this, 0,
1308       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1309   }
1310 };
1311
1312 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1313  public:
1314   SSLHandshakeClientDoVerify(
1315    AsyncSSLSocket::UniquePtr socket,
1316    bool preverifyResult,
1317    bool verifyResult) :
1318     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1319     socket_->sslConn(this, 0,
1320       folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1321   }
1322 };
1323
1324 class SSLHandshakeServer : public SSLHandshakeBase {
1325  public:
1326   SSLHandshakeServer(
1327       AsyncSSLSocket::UniquePtr socket,
1328       bool preverifyResult,
1329       bool verifyResult)
1330     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1331     socket_->sslAccept(this, 0);
1332   }
1333 };
1334
1335 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1336  public:
1337   SSLHandshakeServerParseClientHello(
1338       AsyncSSLSocket::UniquePtr socket,
1339       bool preverifyResult,
1340       bool verifyResult)
1341       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1342     socket_->enableClientHelloParsing();
1343     socket_->sslAccept(this, 0);
1344   }
1345
1346   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1347
1348  protected:
1349   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1350     handshakeSuccess_ = true;
1351     sock->getSSLSharedCiphers(sharedCiphers_);
1352     sock->getSSLServerCiphers(serverCiphers_);
1353     sock->getSSLClientCiphers(clientCiphers_);
1354     chosenCipher_ = sock->getNegotiatedCipherName();
1355   }
1356 };
1357
1358
1359 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1360  public:
1361   SSLHandshakeServerNoVerify(
1362       AsyncSSLSocket::UniquePtr socket,
1363       bool preverifyResult,
1364       bool verifyResult)
1365     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1366     socket_->sslAccept(this, 0,
1367       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1368   }
1369 };
1370
1371 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1372  public:
1373   SSLHandshakeServerDoVerify(
1374       AsyncSSLSocket::UniquePtr socket,
1375       bool preverifyResult,
1376       bool verifyResult)
1377     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1378     socket_->sslAccept(this, 0,
1379       folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1380   }
1381 };
1382
1383 class EventBaseAborter : public AsyncTimeout {
1384  public:
1385   EventBaseAborter(EventBase* eventBase,
1386                    uint32_t timeoutMS)
1387     : AsyncTimeout(
1388       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1389     , eventBase_(eventBase) {
1390     scheduleTimeout(timeoutMS);
1391   }
1392
1393   void timeoutExpired() noexcept override {
1394     FAIL() << "test timed out";
1395     eventBase_->terminateLoopSoon();
1396   }
1397
1398  private:
1399   EventBase* eventBase_;
1400 };
1401
1402 }