2017
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
34
35 #include <fcntl.h>
36 #include <sys/types.h>
37 #include <condition_variable>
38 #include <iostream>
39 #include <list>
40
41 namespace folly {
42
43 enum StateEnum {
44   STATE_WAITING,
45   STATE_SUCCEEDED,
46   STATE_FAILED
47 };
48
49 // The destructors of all callback classes assert that the state is
50 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
51 // are responsible for setting the succeeded state properly before the
52 // destructors are called.
53
54 class WriteCallbackBase :
55 public AsyncTransportWrapper::WriteCallback {
56 public:
57   WriteCallbackBase()
58       : state(STATE_WAITING)
59       , bytesWritten(0)
60       , exception(AsyncSocketException::UNKNOWN, "none") {}
61
62   ~WriteCallbackBase() {
63     EXPECT_EQ(STATE_SUCCEEDED, state);
64   }
65
66   void setSocket(
67     const std::shared_ptr<AsyncSSLSocket> &socket) {
68     socket_ = socket;
69   }
70
71   void writeSuccess() noexcept override {
72     std::cerr << "writeSuccess" << std::endl;
73     state = STATE_SUCCEEDED;
74   }
75
76   void writeErr(
77     size_t nBytesWritten,
78     const AsyncSocketException& ex) noexcept override {
79     std::cerr << "writeError: bytesWritten " << nBytesWritten
80          << ", exception " << ex.what() << std::endl;
81
82     state = STATE_FAILED;
83     this->bytesWritten = nBytesWritten;
84     exception = ex;
85     socket_->close();
86   }
87
88   std::shared_ptr<AsyncSSLSocket> socket_;
89   StateEnum state;
90   size_t bytesWritten;
91   AsyncSocketException exception;
92 };
93
94 class ReadCallbackBase :
95 public AsyncTransportWrapper::ReadCallback {
96  public:
97   explicit ReadCallbackBase(WriteCallbackBase* wcb)
98       : wcb_(wcb), state(STATE_WAITING) {}
99
100   ~ReadCallbackBase() {
101     EXPECT_EQ(STATE_SUCCEEDED, state);
102   }
103
104   void setSocket(
105     const std::shared_ptr<AsyncSSLSocket> &socket) {
106     socket_ = socket;
107   }
108
109   void setState(StateEnum s) {
110     state = s;
111     if (wcb_) {
112       wcb_->state = s;
113     }
114   }
115
116   void readErr(
117     const AsyncSocketException& ex) noexcept override {
118     std::cerr << "readError " << ex.what() << std::endl;
119     state = STATE_FAILED;
120     socket_->close();
121   }
122
123   void readEOF() noexcept override {
124     std::cerr << "readEOF" << std::endl;
125
126     socket_->close();
127   }
128
129   std::shared_ptr<AsyncSSLSocket> socket_;
130   WriteCallbackBase *wcb_;
131   StateEnum state;
132 };
133
134 class ReadCallback : public ReadCallbackBase {
135 public:
136   explicit ReadCallback(WriteCallbackBase *wcb)
137       : ReadCallbackBase(wcb)
138       , buffers() {}
139
140   ~ReadCallback() {
141     for (std::vector<Buffer>::iterator it = buffers.begin();
142          it != buffers.end();
143          ++it) {
144       it->free();
145     }
146     currentBuffer.free();
147   }
148
149   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
150     if (!currentBuffer.buffer) {
151       currentBuffer.allocate(4096);
152     }
153     *bufReturn = currentBuffer.buffer;
154     *lenReturn = currentBuffer.length;
155   }
156
157   void readDataAvailable(size_t len) noexcept override {
158     std::cerr << "readDataAvailable, len " << len << std::endl;
159
160     currentBuffer.length = len;
161
162     wcb_->setSocket(socket_);
163
164     // Write back the same data.
165     socket_->write(wcb_, currentBuffer.buffer, len);
166
167     buffers.push_back(currentBuffer);
168     currentBuffer.reset();
169     state = STATE_SUCCEEDED;
170   }
171
172   class Buffer {
173   public:
174     Buffer() : buffer(nullptr), length(0) {}
175     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
176
177     void reset() {
178       buffer = nullptr;
179       length = 0;
180     }
181     void allocate(size_t len) {
182       assert(buffer == nullptr);
183       this->buffer = static_cast<char*>(malloc(len));
184       this->length = len;
185     }
186     void free() {
187       ::free(buffer);
188       reset();
189     }
190
191     char* buffer;
192     size_t length;
193   };
194
195   std::vector<Buffer> buffers;
196   Buffer currentBuffer;
197 };
198
199 class ReadErrorCallback : public ReadCallbackBase {
200 public:
201   explicit ReadErrorCallback(WriteCallbackBase *wcb)
202       : ReadCallbackBase(wcb) {}
203
204   // Return nullptr buffer to trigger readError()
205   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
206     *bufReturn = nullptr;
207     *lenReturn = 0;
208   }
209
210   void readDataAvailable(size_t /* len */) noexcept override {
211     // This should never to called.
212     FAIL();
213   }
214
215   void readErr(
216     const AsyncSocketException& ex) noexcept override {
217     ReadCallbackBase::readErr(ex);
218     std::cerr << "ReadErrorCallback::readError" << std::endl;
219     setState(STATE_SUCCEEDED);
220   }
221 };
222
223 class ReadEOFCallback : public ReadCallbackBase {
224  public:
225   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
226
227   // Return nullptr buffer to trigger readError()
228   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
229     *bufReturn = nullptr;
230     *lenReturn = 0;
231   }
232
233   void readDataAvailable(size_t /* len */) noexcept override {
234     // This should never to called.
235     FAIL();
236   }
237
238   void readEOF() noexcept override {
239     ReadCallbackBase::readEOF();
240     setState(STATE_SUCCEEDED);
241   }
242 };
243
244 class WriteErrorCallback : public ReadCallback {
245 public:
246   explicit WriteErrorCallback(WriteCallbackBase *wcb)
247       : ReadCallback(wcb) {}
248
249   void readDataAvailable(size_t len) noexcept override {
250     std::cerr << "readDataAvailable, len " << len << std::endl;
251
252     currentBuffer.length = len;
253
254     // close the socket before writing to trigger writeError().
255     ::close(socket_->getFd());
256
257     wcb_->setSocket(socket_);
258
259     // Write back the same data.
260     folly::test::msvcSuppressAbortOnInvalidParams([&] {
261       socket_->write(wcb_, currentBuffer.buffer, len);
262     });
263
264     if (wcb_->state == STATE_FAILED) {
265       setState(STATE_SUCCEEDED);
266     } else {
267       state = STATE_FAILED;
268     }
269
270     buffers.push_back(currentBuffer);
271     currentBuffer.reset();
272   }
273
274   void readErr(const AsyncSocketException& ex) noexcept override {
275     std::cerr << "readError " << ex.what() << std::endl;
276     // do nothing since this is expected
277   }
278 };
279
280 class EmptyReadCallback : public ReadCallback {
281 public:
282   explicit EmptyReadCallback()
283       : ReadCallback(nullptr) {}
284
285   void readErr(const AsyncSocketException& ex) noexcept override {
286     std::cerr << "readError " << ex.what() << std::endl;
287     state = STATE_FAILED;
288     if (tcpSocket_) {
289       tcpSocket_->close();
290     }
291   }
292
293   void readEOF() noexcept override {
294     std::cerr << "readEOF" << std::endl;
295     if (tcpSocket_) {
296       tcpSocket_->close();
297     }
298     state = STATE_SUCCEEDED;
299   }
300
301   std::shared_ptr<AsyncSocket> tcpSocket_;
302 };
303
304 class HandshakeCallback :
305 public AsyncSSLSocket::HandshakeCB {
306 public:
307   enum ExpectType {
308     EXPECT_SUCCESS,
309     EXPECT_ERROR
310   };
311
312   explicit HandshakeCallback(ReadCallbackBase *rcb,
313                              ExpectType expect = EXPECT_SUCCESS):
314       state(STATE_WAITING),
315       rcb_(rcb),
316       expect_(expect) {}
317
318   void setSocket(
319     const std::shared_ptr<AsyncSSLSocket> &socket) {
320     socket_ = socket;
321   }
322
323   void setState(StateEnum s) {
324     state = s;
325     rcb_->setState(s);
326   }
327
328   // Functions inherited from AsyncSSLSocketHandshakeCallback
329   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
330     std::lock_guard<std::mutex> g(mutex_);
331     cv_.notify_all();
332     EXPECT_EQ(sock, socket_.get());
333     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
334     rcb_->setSocket(socket_);
335     sock->setReadCB(rcb_);
336     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
337   }
338   void handshakeErr(AsyncSSLSocket* /* sock */,
339                     const AsyncSocketException& ex) noexcept override {
340     std::lock_guard<std::mutex> g(mutex_);
341     cv_.notify_all();
342     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
343     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
344     if (expect_ == EXPECT_ERROR) {
345       // rcb will never be invoked
346       rcb_->setState(STATE_SUCCEEDED);
347     }
348     errorString_ = ex.what();
349   }
350
351   void waitForHandshake() {
352     std::unique_lock<std::mutex> lock(mutex_);
353     cv_.wait(lock, [this] { return state != STATE_WAITING; });
354   }
355
356   ~HandshakeCallback() {
357     EXPECT_EQ(STATE_SUCCEEDED, state);
358   }
359
360   void closeSocket() {
361     socket_->close();
362     state = STATE_SUCCEEDED;
363   }
364
365   std::shared_ptr<AsyncSSLSocket> getSocket() {
366     return socket_;
367   }
368
369   StateEnum state;
370   std::shared_ptr<AsyncSSLSocket> socket_;
371   ReadCallbackBase *rcb_;
372   ExpectType expect_;
373   std::mutex mutex_;
374   std::condition_variable cv_;
375   std::string errorString_;
376 };
377
378 class SSLServerAcceptCallbackBase:
379 public folly::AsyncServerSocket::AcceptCallback {
380 public:
381   explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
382   state(STATE_WAITING), hcb_(hcb) {}
383
384   ~SSLServerAcceptCallbackBase() {
385     EXPECT_EQ(STATE_SUCCEEDED, state);
386   }
387
388   void acceptError(const std::exception& ex) noexcept override {
389     std::cerr << "SSLServerAcceptCallbackBase::acceptError "
390               << ex.what() << std::endl;
391     state = STATE_FAILED;
392   }
393
394   void connectionAccepted(
395       int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
396     if (socket_) {
397       socket_->detachEventBase();
398     }
399     printf("Connection accepted\n");
400     try {
401       // Create a AsyncSSLSocket object with the fd. The socket should be
402       // added to the event base and in the state of accepting SSL connection.
403       socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
404     } catch (const std::exception &e) {
405       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
406         "object with socket " << e.what() << fd;
407       ::close(fd);
408       acceptError(e);
409       return;
410     }
411
412     connAccepted(socket_);
413   }
414
415   virtual void connAccepted(
416     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
417
418   void detach() {
419     socket_->detachEventBase();
420   }
421
422   StateEnum state;
423   HandshakeCallback *hcb_;
424   std::shared_ptr<folly::SSLContext> ctx_;
425   std::shared_ptr<AsyncSSLSocket> socket_;
426   folly::EventBase* base_;
427 };
428
429 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
430 public:
431   uint32_t timeout_;
432
433   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
434                                    uint32_t timeout = 0):
435       SSLServerAcceptCallbackBase(hcb),
436       timeout_(timeout) {}
437
438   virtual ~SSLServerAcceptCallback() {
439     if (timeout_ > 0) {
440       // if we set a timeout, we expect failure
441       EXPECT_EQ(hcb_->state, STATE_FAILED);
442       hcb_->setState(STATE_SUCCEEDED);
443     }
444   }
445
446   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
447   void connAccepted(
448     const std::shared_ptr<folly::AsyncSSLSocket> &s)
449     noexcept override {
450     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
451     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
452
453     hcb_->setSocket(sock);
454     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
455     EXPECT_EQ(sock->getSSLState(),
456                       AsyncSSLSocket::STATE_ACCEPTING);
457
458     state = STATE_SUCCEEDED;
459   }
460 };
461
462 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
463 public:
464   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
465       SSLServerAcceptCallback(hcb) {}
466
467   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
468   void connAccepted(
469     const std::shared_ptr<folly::AsyncSSLSocket> &s)
470     noexcept override {
471
472     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
473
474     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
475               << std::endl;
476     int fd = sock->getFd();
477
478 #ifndef TCP_NOPUSH
479     {
480     // The accepted connection should already have TCP_NODELAY set
481     int value;
482     socklen_t valueLength = sizeof(value);
483     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
484     EXPECT_EQ(rc, 0);
485     EXPECT_EQ(value, 1);
486     }
487 #endif
488
489     // Unset the TCP_NODELAY option.
490     int value = 0;
491     socklen_t valueLength = sizeof(value);
492     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
493     EXPECT_EQ(rc, 0);
494
495     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
496     EXPECT_EQ(rc, 0);
497     EXPECT_EQ(value, 0);
498
499     SSLServerAcceptCallback::connAccepted(sock);
500   }
501 };
502
503 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
504 public:
505   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
506                                              uint32_t timeout = 0):
507     SSLServerAcceptCallback(hcb, timeout) {}
508
509   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
510   void connAccepted(
511     const std::shared_ptr<folly::AsyncSSLSocket> &s)
512     noexcept override {
513     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
514
515     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
516
517     hcb_->setSocket(sock);
518     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
519     ASSERT_TRUE((sock->getSSLState() ==
520                  AsyncSSLSocket::STATE_ACCEPTING) ||
521                 (sock->getSSLState() ==
522                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
523
524     state = STATE_SUCCEEDED;
525   }
526 };
527
528
529 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
530 public:
531   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
532   SSLServerAcceptCallbackBase(hcb)  {}
533
534   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
535   void connAccepted(
536     const std::shared_ptr<folly::AsyncSSLSocket> &s)
537     noexcept override {
538     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
539
540     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
541
542     // The first call to sslAccept() should succeed.
543     hcb_->setSocket(sock);
544     sock->sslAccept(hcb_);
545     EXPECT_EQ(sock->getSSLState(),
546                       AsyncSSLSocket::STATE_ACCEPTING);
547
548     // The second call to sslAccept() should fail.
549     HandshakeCallback callback2(hcb_->rcb_);
550     callback2.setSocket(sock);
551     sock->sslAccept(&callback2);
552     EXPECT_EQ(sock->getSSLState(),
553                       AsyncSSLSocket::STATE_ERROR);
554
555     // Both callbacks should be in the error state.
556     EXPECT_EQ(hcb_->state, STATE_FAILED);
557     EXPECT_EQ(callback2.state, STATE_FAILED);
558
559     state = STATE_SUCCEEDED;
560     hcb_->setState(STATE_SUCCEEDED);
561     callback2.setState(STATE_SUCCEEDED);
562   }
563 };
564
565 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
566 public:
567   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
568   SSLServerAcceptCallbackBase(hcb)  {}
569
570   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
571   void connAccepted(
572     const std::shared_ptr<folly::AsyncSSLSocket> &s)
573     noexcept override {
574     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
575
576     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
577
578     hcb_->setSocket(sock);
579     sock->getEventBase()->tryRunAfterDelay([=] {
580         std::cerr << "Delayed SSL accept, client will have close by now"
581                   << std::endl;
582         // SSL accept will fail
583         EXPECT_EQ(
584           sock->getSSLState(),
585           AsyncSSLSocket::STATE_UNINIT);
586         hcb_->socket_->sslAccept(hcb_);
587         // This registers for an event
588         EXPECT_EQ(
589           sock->getSSLState(),
590           AsyncSSLSocket::STATE_ACCEPTING);
591
592         state = STATE_SUCCEEDED;
593       }, 100);
594   }
595 };
596
597 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
598  public:
599   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
600     // We don't care if we get invoked or not.
601     // The client may time out and give up before connAccepted() is even
602     // called.
603     state = STATE_SUCCEEDED;
604   }
605
606   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
607   void connAccepted(
608       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
609     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
610
611     // Just wait a while before closing the socket, so the client
612     // will time out waiting for the handshake to complete.
613     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
614   }
615 };
616
617 class TestSSLServer {
618  protected:
619   EventBase evb_;
620   std::shared_ptr<folly::SSLContext> ctx_;
621   SSLServerAcceptCallbackBase *acb_;
622   std::shared_ptr<folly::AsyncServerSocket> socket_;
623   folly::SocketAddress address_;
624   pthread_t thread_;
625
626   static void *Main(void *ctx) {
627     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
628     self->evb_.loop();
629     self->acb_->detach();
630     std::cerr << "Server thread exited event loop" << std::endl;
631     return nullptr;
632   }
633
634  public:
635   // Create a TestSSLServer.
636   // This immediately starts listening on the given port.
637   explicit TestSSLServer(
638       SSLServerAcceptCallbackBase* acb,
639       bool enableTFO = false);
640
641   // Kill the thread.
642   ~TestSSLServer() {
643     evb_.runInEventBaseThread([&](){
644       socket_->stopAccepting();
645     });
646     std::cerr << "Waiting for server thread to exit" << std::endl;
647     pthread_join(thread_, nullptr);
648   }
649
650   EventBase &getEventBase() { return evb_; }
651
652   const folly::SocketAddress& getAddress() const {
653     return address_;
654   }
655 };
656
657 class TestSSLAsyncCacheServer : public TestSSLServer {
658  public:
659   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
660         int lookupDelay = 100) :
661       TestSSLServer(acb) {
662     SSL_CTX *sslCtx = ctx_->getSSLCtx();
663     SSL_CTX_sess_set_get_cb(sslCtx,
664                             TestSSLAsyncCacheServer::getSessionCallback);
665     SSL_CTX_set_session_cache_mode(
666       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
667     asyncCallbacks_ = 0;
668     asyncLookups_ = 0;
669     lookupDelay_ = lookupDelay;
670   }
671
672   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
673   uint32_t getAsyncLookups() const { return asyncLookups_; }
674
675  private:
676   static uint32_t asyncCallbacks_;
677   static uint32_t asyncLookups_;
678   static uint32_t lookupDelay_;
679
680   static SSL_SESSION* getSessionCallback(SSL* ssl,
681                                          unsigned char* /* sess_id */,
682                                          int /* id_len */,
683                                          int* copyflag) {
684     *copyflag = 0;
685     asyncCallbacks_++;
686     (void)ssl;
687 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
688     if (!SSL_want_sess_cache_lookup(ssl)) {
689       // libssl.so mismatch
690       std::cerr << "no async support" << std::endl;
691       return nullptr;
692     }
693
694     AsyncSSLSocket *sslSocket =
695         AsyncSSLSocket::getFromSSL(ssl);
696     assert(sslSocket != nullptr);
697     // Going to simulate an async cache by just running delaying the miss 100ms
698     if (asyncCallbacks_ % 2 == 0) {
699       // This socket is already blocked on lookup, return miss
700       std::cerr << "returning miss" << std::endl;
701     } else {
702       // fresh meat - block it
703       std::cerr << "async lookup" << std::endl;
704       sslSocket->getEventBase()->tryRunAfterDelay(
705         std::bind(&AsyncSSLSocket::restartSSLAccept,
706                   sslSocket), lookupDelay_);
707       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
708       asyncLookups_++;
709     }
710 #endif
711     return nullptr;
712   }
713 };
714
715 void getfds(int fds[2]);
716
717 void getctx(
718   std::shared_ptr<folly::SSLContext> clientCtx,
719   std::shared_ptr<folly::SSLContext> serverCtx);
720
721 void sslsocketpair(
722   EventBase* eventBase,
723   AsyncSSLSocket::UniquePtr* clientSock,
724   AsyncSSLSocket::UniquePtr* serverSock);
725
726 class BlockingWriteClient :
727   private AsyncSSLSocket::HandshakeCB,
728   private AsyncTransportWrapper::WriteCallback {
729  public:
730   explicit BlockingWriteClient(
731     AsyncSSLSocket::UniquePtr socket)
732     : socket_(std::move(socket)),
733       bufLen_(2500),
734       iovCount_(2000) {
735     // Fill buf_
736     buf_.reset(new uint8_t[bufLen_]);
737     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
738       buf_[n] = n % 0xff;
739     }
740
741     // Initialize iov_
742     iov_.reset(new struct iovec[iovCount_]);
743     for (uint32_t n = 0; n < iovCount_; ++n) {
744       iov_[n].iov_base = buf_.get() + n;
745       if (n & 0x1) {
746         iov_[n].iov_len = n % bufLen_;
747       } else {
748         iov_[n].iov_len = bufLen_ - (n % bufLen_);
749       }
750     }
751
752     socket_->sslConn(this, std::chrono::milliseconds(100));
753   }
754
755   struct iovec* getIovec() const {
756     return iov_.get();
757   }
758   uint32_t getIovecCount() const {
759     return iovCount_;
760   }
761
762  private:
763   void handshakeSuc(AsyncSSLSocket*) noexcept override {
764     socket_->writev(this, iov_.get(), iovCount_);
765   }
766   void handshakeErr(
767     AsyncSSLSocket*,
768     const AsyncSocketException& ex) noexcept override {
769     ADD_FAILURE() << "client handshake error: " << ex.what();
770   }
771   void writeSuccess() noexcept override {
772     socket_->close();
773   }
774   void writeErr(
775     size_t bytesWritten,
776     const AsyncSocketException& ex) noexcept override {
777     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
778                   << ex.what();
779   }
780
781   AsyncSSLSocket::UniquePtr socket_;
782   uint32_t bufLen_;
783   uint32_t iovCount_;
784   std::unique_ptr<uint8_t[]> buf_;
785   std::unique_ptr<struct iovec[]> iov_;
786 };
787
788 class BlockingWriteServer :
789     private AsyncSSLSocket::HandshakeCB,
790     private AsyncTransportWrapper::ReadCallback {
791  public:
792   explicit BlockingWriteServer(
793     AsyncSSLSocket::UniquePtr socket)
794     : socket_(std::move(socket)),
795       bufSize_(2500 * 2000),
796       bytesRead_(0) {
797     buf_.reset(new uint8_t[bufSize_]);
798     socket_->sslAccept(this, std::chrono::milliseconds(100));
799   }
800
801   void checkBuffer(struct iovec* iov, uint32_t count) const {
802     uint32_t idx = 0;
803     for (uint32_t n = 0; n < count; ++n) {
804       size_t bytesLeft = bytesRead_ - idx;
805       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
806                       std::min(iov[n].iov_len, bytesLeft));
807       if (rc != 0) {
808         FAIL() << "buffer mismatch at iovec " << n << "/" << count
809                << ": rc=" << rc;
810
811       }
812       if (iov[n].iov_len > bytesLeft) {
813         FAIL() << "server did not read enough data: "
814                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
815                << " in iovec " << n << "/" << count;
816       }
817
818       idx += iov[n].iov_len;
819     }
820     if (idx != bytesRead_) {
821       ADD_FAILURE() << "server read extra data: " << bytesRead_
822                     << " bytes read; expected " << idx;
823     }
824   }
825
826  private:
827   void handshakeSuc(AsyncSSLSocket*) noexcept override {
828     // Wait 10ms before reading, so the client's writes will initially block.
829     socket_->getEventBase()->tryRunAfterDelay(
830         [this] { socket_->setReadCB(this); }, 10);
831   }
832   void handshakeErr(
833     AsyncSSLSocket*,
834     const AsyncSocketException& ex) noexcept override {
835     ADD_FAILURE() << "server handshake error: " << ex.what();
836   }
837   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
838     *bufReturn = buf_.get() + bytesRead_;
839     *lenReturn = bufSize_ - bytesRead_;
840   }
841   void readDataAvailable(size_t len) noexcept override {
842     bytesRead_ += len;
843     socket_->setReadCB(nullptr);
844     socket_->getEventBase()->tryRunAfterDelay(
845         [this] { socket_->setReadCB(this); }, 2);
846   }
847   void readEOF() noexcept override {
848     socket_->close();
849   }
850   void readErr(
851     const AsyncSocketException& ex) noexcept override {
852     ADD_FAILURE() << "server read error: " << ex.what();
853   }
854
855   AsyncSSLSocket::UniquePtr socket_;
856   uint32_t bufSize_;
857   uint32_t bytesRead_;
858   std::unique_ptr<uint8_t[]> buf_;
859 };
860
861 class NpnClient :
862   private AsyncSSLSocket::HandshakeCB,
863   private AsyncTransportWrapper::WriteCallback {
864  public:
865   explicit NpnClient(
866     AsyncSSLSocket::UniquePtr socket)
867       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
868     socket_->sslConn(this);
869   }
870
871   const unsigned char* nextProto;
872   unsigned nextProtoLength;
873   SSLContext::NextProtocolType protocolType;
874
875  private:
876   void handshakeSuc(AsyncSSLSocket*) noexcept override {
877     socket_->getSelectedNextProtocol(
878         &nextProto, &nextProtoLength, &protocolType);
879   }
880   void handshakeErr(
881     AsyncSSLSocket*,
882     const AsyncSocketException& ex) noexcept override {
883     ADD_FAILURE() << "client handshake error: " << ex.what();
884   }
885   void writeSuccess() noexcept override {
886     socket_->close();
887   }
888   void writeErr(
889     size_t bytesWritten,
890     const AsyncSocketException& ex) noexcept override {
891     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
892                   << ex.what();
893   }
894
895   AsyncSSLSocket::UniquePtr socket_;
896 };
897
898 class NpnServer :
899     private AsyncSSLSocket::HandshakeCB,
900     private AsyncTransportWrapper::ReadCallback {
901  public:
902   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
903       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
904     socket_->sslAccept(this);
905   }
906
907   const unsigned char* nextProto;
908   unsigned nextProtoLength;
909   SSLContext::NextProtocolType protocolType;
910
911  private:
912   void handshakeSuc(AsyncSSLSocket*) noexcept override {
913     socket_->getSelectedNextProtocol(
914         &nextProto, &nextProtoLength, &protocolType);
915   }
916   void handshakeErr(
917     AsyncSSLSocket*,
918     const AsyncSocketException& ex) noexcept override {
919     ADD_FAILURE() << "server handshake error: " << ex.what();
920   }
921   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
922     *lenReturn = 0;
923   }
924   void readDataAvailable(size_t /* len */) noexcept override {}
925   void readEOF() noexcept override {
926     socket_->close();
927   }
928   void readErr(
929     const AsyncSocketException& ex) noexcept override {
930     ADD_FAILURE() << "server read error: " << ex.what();
931   }
932
933   AsyncSSLSocket::UniquePtr socket_;
934 };
935
936 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
937                             public AsyncTransportWrapper::ReadCallback {
938  public:
939   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
940       : socket_(std::move(socket)) {
941     socket_->sslAccept(this);
942   }
943
944   ~RenegotiatingServer() {
945     socket_->setReadCB(nullptr);
946   }
947
948   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
949     LOG(INFO) << "Renegotiating server handshake success";
950     socket_->setReadCB(this);
951   }
952   void handshakeErr(
953       AsyncSSLSocket*,
954       const AsyncSocketException& ex) noexcept override {
955     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
956   }
957   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
958     *lenReturn = sizeof(buf);
959     *bufReturn = buf;
960   }
961   void readDataAvailable(size_t /* len */) noexcept override {}
962   void readEOF() noexcept override {}
963   void readErr(const AsyncSocketException& ex) noexcept override {
964     LOG(INFO) << "server got read error " << ex.what();
965     auto exPtr = dynamic_cast<const SSLException*>(&ex);
966     ASSERT_NE(nullptr, exPtr);
967     std::string exStr(ex.what());
968     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
969     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
970     renegotiationError_ = true;
971   }
972
973   AsyncSSLSocket::UniquePtr socket_;
974   unsigned char buf[128];
975   bool renegotiationError_{false};
976 };
977
978 #ifndef OPENSSL_NO_TLSEXT
979 class SNIClient :
980   private AsyncSSLSocket::HandshakeCB,
981   private AsyncTransportWrapper::WriteCallback {
982  public:
983   explicit SNIClient(
984     AsyncSSLSocket::UniquePtr socket)
985       : serverNameMatch(false), socket_(std::move(socket)) {
986     socket_->sslConn(this);
987   }
988
989   bool serverNameMatch;
990
991  private:
992   void handshakeSuc(AsyncSSLSocket*) noexcept override {
993     serverNameMatch = socket_->isServerNameMatch();
994   }
995   void handshakeErr(
996     AsyncSSLSocket*,
997     const AsyncSocketException& ex) noexcept override {
998     ADD_FAILURE() << "client handshake error: " << ex.what();
999   }
1000   void writeSuccess() noexcept override {
1001     socket_->close();
1002   }
1003   void writeErr(
1004     size_t bytesWritten,
1005     const AsyncSocketException& ex) noexcept override {
1006     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1007                   << ex.what();
1008   }
1009
1010   AsyncSSLSocket::UniquePtr socket_;
1011 };
1012
1013 class SNIServer :
1014     private AsyncSSLSocket::HandshakeCB,
1015     private AsyncTransportWrapper::ReadCallback {
1016  public:
1017   explicit SNIServer(
1018     AsyncSSLSocket::UniquePtr socket,
1019     const std::shared_ptr<folly::SSLContext>& ctx,
1020     const std::shared_ptr<folly::SSLContext>& sniCtx,
1021     const std::string& expectedServerName)
1022       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1023         expectedServerName_(expectedServerName) {
1024     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1025                                          std::placeholders::_1));
1026     socket_->sslAccept(this);
1027   }
1028
1029   bool serverNameMatch;
1030
1031  private:
1032   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1033   void handshakeErr(
1034     AsyncSSLSocket*,
1035     const AsyncSocketException& ex) noexcept override {
1036     ADD_FAILURE() << "server handshake error: " << ex.what();
1037   }
1038   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1039     *lenReturn = 0;
1040   }
1041   void readDataAvailable(size_t /* len */) noexcept override {}
1042   void readEOF() noexcept override {
1043     socket_->close();
1044   }
1045   void readErr(
1046     const AsyncSocketException& ex) noexcept override {
1047     ADD_FAILURE() << "server read error: " << ex.what();
1048   }
1049
1050   folly::SSLContext::ServerNameCallbackResult
1051     serverNameCallback(SSL *ssl) {
1052     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1053     if (sniCtx_ &&
1054         sn &&
1055         !strcasecmp(expectedServerName_.c_str(), sn)) {
1056       AsyncSSLSocket *sslSocket =
1057           AsyncSSLSocket::getFromSSL(ssl);
1058       sslSocket->switchServerSSLContext(sniCtx_);
1059       serverNameMatch = true;
1060       return folly::SSLContext::SERVER_NAME_FOUND;
1061     } else {
1062       serverNameMatch = false;
1063       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1064     }
1065   }
1066
1067   AsyncSSLSocket::UniquePtr socket_;
1068   std::shared_ptr<folly::SSLContext> sniCtx_;
1069   std::string expectedServerName_;
1070 };
1071 #endif
1072
1073 class SSLClient : public AsyncSocket::ConnectCallback,
1074                   public AsyncTransportWrapper::WriteCallback,
1075                   public AsyncTransportWrapper::ReadCallback
1076 {
1077  private:
1078   EventBase *eventBase_;
1079   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1080   SSL_SESSION *session_;
1081   std::shared_ptr<folly::SSLContext> ctx_;
1082   uint32_t requests_;
1083   folly::SocketAddress address_;
1084   uint32_t timeout_;
1085   char buf_[128];
1086   char readbuf_[128];
1087   uint32_t bytesRead_;
1088   uint32_t hit_;
1089   uint32_t miss_;
1090   uint32_t errors_;
1091   uint32_t writeAfterConnectErrors_;
1092
1093   // These settings test that we eventually drain the
1094   // socket, even if the maxReadsPerEvent_ is hit during
1095   // a event loop iteration.
1096   static constexpr size_t kMaxReadsPerEvent = 2;
1097   // 2 event loop iterations
1098   static constexpr size_t kMaxReadBufferSz =
1099     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1100
1101  public:
1102   SSLClient(EventBase *eventBase,
1103             const folly::SocketAddress& address,
1104             uint32_t requests,
1105             uint32_t timeout = 0)
1106       : eventBase_(eventBase),
1107         session_(nullptr),
1108         requests_(requests),
1109         address_(address),
1110         timeout_(timeout),
1111         bytesRead_(0),
1112         hit_(0),
1113         miss_(0),
1114         errors_(0),
1115         writeAfterConnectErrors_(0) {
1116     ctx_.reset(new folly::SSLContext());
1117     ctx_->setOptions(SSL_OP_NO_TICKET);
1118     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1119     memset(buf_, 'a', sizeof(buf_));
1120   }
1121
1122   ~SSLClient() {
1123     if (session_) {
1124       SSL_SESSION_free(session_);
1125     }
1126     if (errors_ == 0) {
1127       EXPECT_EQ(bytesRead_, sizeof(buf_));
1128     }
1129   }
1130
1131   uint32_t getHit() const { return hit_; }
1132
1133   uint32_t getMiss() const { return miss_; }
1134
1135   uint32_t getErrors() const { return errors_; }
1136
1137   uint32_t getWriteAfterConnectErrors() const {
1138     return writeAfterConnectErrors_;
1139   }
1140
1141   void connect(bool writeNow = false) {
1142     sslSocket_ = AsyncSSLSocket::newSocket(
1143       ctx_, eventBase_);
1144     if (session_ != nullptr) {
1145       sslSocket_->setSSLSession(session_);
1146     }
1147     requests_--;
1148     sslSocket_->connect(this, address_, timeout_);
1149     if (sslSocket_ && writeNow) {
1150       // write some junk, used in an error test
1151       sslSocket_->write(this, buf_, sizeof(buf_));
1152     }
1153   }
1154
1155   void connectSuccess() noexcept override {
1156     std::cerr << "client SSL socket connected" << std::endl;
1157     if (sslSocket_->getSSLSessionReused()) {
1158       hit_++;
1159     } else {
1160       miss_++;
1161       if (session_ != nullptr) {
1162         SSL_SESSION_free(session_);
1163       }
1164       session_ = sslSocket_->getSSLSession();
1165     }
1166
1167     // write()
1168     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1169     sslSocket_->write(this, buf_, sizeof(buf_));
1170     sslSocket_->setReadCB(this);
1171     memset(readbuf_, 'b', sizeof(readbuf_));
1172     bytesRead_ = 0;
1173   }
1174
1175   void connectErr(
1176     const AsyncSocketException& ex) noexcept override {
1177     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1178     errors_++;
1179     sslSocket_.reset();
1180   }
1181
1182   void writeSuccess() noexcept override {
1183     std::cerr << "client write success" << std::endl;
1184   }
1185
1186   void writeErr(size_t /* bytesWritten */,
1187                 const AsyncSocketException& ex) noexcept override {
1188     std::cerr << "client writeError: " << ex.what() << std::endl;
1189     if (!sslSocket_) {
1190       writeAfterConnectErrors_++;
1191     }
1192   }
1193
1194   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1195     *bufReturn = readbuf_ + bytesRead_;
1196     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1197   }
1198
1199   void readEOF() noexcept override {
1200     std::cerr << "client readEOF" << std::endl;
1201   }
1202
1203   void readErr(
1204     const AsyncSocketException& ex) noexcept override {
1205     std::cerr << "client readError: " << ex.what() << std::endl;
1206   }
1207
1208   void readDataAvailable(size_t len) noexcept override {
1209     std::cerr << "client read data: " << len << std::endl;
1210     bytesRead_ += len;
1211     if (bytesRead_ == sizeof(buf_)) {
1212       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1213       sslSocket_->closeNow();
1214       sslSocket_.reset();
1215       if (requests_ != 0) {
1216         connect();
1217       }
1218     }
1219   }
1220
1221 };
1222
1223 class SSLHandshakeBase :
1224   public AsyncSSLSocket::HandshakeCB,
1225   private AsyncTransportWrapper::WriteCallback {
1226  public:
1227   explicit SSLHandshakeBase(
1228    AsyncSSLSocket::UniquePtr socket,
1229    bool preverifyResult,
1230    bool verifyResult) :
1231     handshakeVerify_(false),
1232     handshakeSuccess_(false),
1233     handshakeError_(false),
1234     socket_(std::move(socket)),
1235     preverifyResult_(preverifyResult),
1236     verifyResult_(verifyResult) {
1237   }
1238
1239   AsyncSSLSocket::UniquePtr moveSocket() && {
1240     return std::move(socket_);
1241   }
1242
1243   bool handshakeVerify_;
1244   bool handshakeSuccess_;
1245   bool handshakeError_;
1246   std::chrono::nanoseconds handshakeTime;
1247
1248  protected:
1249   AsyncSSLSocket::UniquePtr socket_;
1250   bool preverifyResult_;
1251   bool verifyResult_;
1252
1253   // HandshakeCallback
1254   bool handshakeVer(AsyncSSLSocket* /* sock */,
1255                     bool preverifyOk,
1256                     X509_STORE_CTX* /* ctx */) noexcept override {
1257     handshakeVerify_ = true;
1258
1259     EXPECT_EQ(preverifyResult_, preverifyOk);
1260     return verifyResult_;
1261   }
1262
1263   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1264     LOG(INFO) << "Handshake success";
1265     handshakeSuccess_ = true;
1266     handshakeTime = socket_->getHandshakeTime();
1267   }
1268
1269   void handshakeErr(
1270       AsyncSSLSocket*,
1271       const AsyncSocketException& ex) noexcept override {
1272     LOG(INFO) << "Handshake error " << ex.what();
1273     handshakeError_ = true;
1274     handshakeTime = socket_->getHandshakeTime();
1275   }
1276
1277   // WriteCallback
1278   void writeSuccess() noexcept override {
1279     socket_->close();
1280   }
1281
1282   void writeErr(
1283    size_t bytesWritten,
1284    const AsyncSocketException& ex) noexcept override {
1285     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1286                   << ex.what();
1287   }
1288 };
1289
1290 class SSLHandshakeClient : public SSLHandshakeBase {
1291  public:
1292   SSLHandshakeClient(
1293    AsyncSSLSocket::UniquePtr socket,
1294    bool preverifyResult,
1295    bool verifyResult) :
1296     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1297     socket_->sslConn(this, std::chrono::milliseconds::zero());
1298   }
1299 };
1300
1301 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1302  public:
1303   SSLHandshakeClientNoVerify(
1304    AsyncSSLSocket::UniquePtr socket,
1305    bool preverifyResult,
1306    bool verifyResult) :
1307     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1308     socket_->sslConn(
1309         this,
1310         std::chrono::milliseconds::zero(),
1311         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1312   }
1313 };
1314
1315 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1316  public:
1317   SSLHandshakeClientDoVerify(
1318    AsyncSSLSocket::UniquePtr socket,
1319    bool preverifyResult,
1320    bool verifyResult) :
1321     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1322     socket_->sslConn(
1323         this,
1324         std::chrono::milliseconds::zero(),
1325         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1326   }
1327 };
1328
1329 class SSLHandshakeServer : public SSLHandshakeBase {
1330  public:
1331   SSLHandshakeServer(
1332       AsyncSSLSocket::UniquePtr socket,
1333       bool preverifyResult,
1334       bool verifyResult)
1335     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1336     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1337   }
1338 };
1339
1340 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1341  public:
1342   SSLHandshakeServerParseClientHello(
1343       AsyncSSLSocket::UniquePtr socket,
1344       bool preverifyResult,
1345       bool verifyResult)
1346       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1347     socket_->enableClientHelloParsing();
1348     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1349   }
1350
1351   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1352
1353  protected:
1354   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1355     handshakeSuccess_ = true;
1356     sock->getSSLSharedCiphers(sharedCiphers_);
1357     sock->getSSLServerCiphers(serverCiphers_);
1358     sock->getSSLClientCiphers(clientCiphers_);
1359     chosenCipher_ = sock->getNegotiatedCipherName();
1360   }
1361 };
1362
1363
1364 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1365  public:
1366   SSLHandshakeServerNoVerify(
1367       AsyncSSLSocket::UniquePtr socket,
1368       bool preverifyResult,
1369       bool verifyResult)
1370     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1371     socket_->sslAccept(
1372         this,
1373         std::chrono::milliseconds::zero(),
1374         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1375   }
1376 };
1377
1378 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1379  public:
1380   SSLHandshakeServerDoVerify(
1381       AsyncSSLSocket::UniquePtr socket,
1382       bool preverifyResult,
1383       bool verifyResult)
1384     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1385     socket_->sslAccept(
1386         this,
1387         std::chrono::milliseconds::zero(),
1388         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1389   }
1390 };
1391
1392 class EventBaseAborter : public AsyncTimeout {
1393  public:
1394   EventBaseAborter(EventBase* eventBase,
1395                    uint32_t timeoutMS)
1396     : AsyncTimeout(
1397       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1398     , eventBase_(eventBase) {
1399     scheduleTimeout(timeoutMS);
1400   }
1401
1402   void timeoutExpired() noexcept override {
1403     FAIL() << "test timed out";
1404     eventBase_->terminateLoopSoon();
1405   }
1406
1407  private:
1408   EventBase* eventBase_;
1409 };
1410
1411 }