apply clang-tidy modernize-use-override
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <condition_variable>
39 #include <iostream>
40 #include <list>
41
42 namespace folly {
43
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
48
49 class SendMsgParamsCallbackBase :
50       public folly::AsyncSocket::SendMsgParamsCallback {
51  public:
52   SendMsgParamsCallbackBase() {}
53
54   void setSocket(
55     const std::shared_ptr<AsyncSSLSocket> &socket) {
56     socket_ = socket;
57     oldCallback_ = socket_->getSendMsgParamsCB();
58     socket_->setSendMsgParamCB(this);
59   }
60
61   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
62                                                                      override {
63     return oldCallback_->getFlags(flags);
64   }
65
66   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67     oldCallback_->getAncillaryData(flags, data);
68   }
69
70   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71     return oldCallback_->getAncillaryDataSize(flags);
72   }
73
74   std::shared_ptr<AsyncSSLSocket> socket_;
75   folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
76 };
77
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
79  public:
80   SendMsgFlagsCallback() {}
81
82   void resetFlags(int flags) {
83     flags_ = flags;
84   }
85
86   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
87                                                                   override {
88     if (flags_) {
89       return flags_;
90     } else {
91       return oldCallback_->getFlags(flags);
92     }
93   }
94
95   int flags_{0};
96 };
97
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
99  public:
100   SendMsgDataCallback() {}
101
102   void resetData(std::vector<char>&& data) {
103     ancillaryData_.swap(data);
104   }
105
106   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107     if (ancillaryData_.size()) {
108       std::cerr << "getAncillaryData: copying data" << std::endl;
109       memcpy(data, ancillaryData_.data(), ancillaryData_.size());
110     } else {
111       oldCallback_->getAncillaryData(flags, data);
112     }
113   }
114
115   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116     if (ancillaryData_.size()) {
117       std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118       return ancillaryData_.size();
119     } else {
120       return oldCallback_->getAncillaryDataSize(flags);
121     }
122   }
123
124   std::vector<char> ancillaryData_;
125 };
126
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
129 public:
130   explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131       : state(STATE_WAITING)
132       , bytesWritten(0)
133       , exception(AsyncSocketException::UNKNOWN, "none")
134       , mcb_(mcb) {}
135
136   ~WriteCallbackBase() override {
137     EXPECT_EQ(STATE_SUCCEEDED, state);
138   }
139
140   virtual void setSocket(
141     const std::shared_ptr<AsyncSSLSocket> &socket) {
142     socket_ = socket;
143     if (mcb_) {
144       mcb_->setSocket(socket);
145     }
146   }
147
148   void writeSuccess() noexcept override {
149     std::cerr << "writeSuccess" << std::endl;
150     state = STATE_SUCCEEDED;
151   }
152
153   void writeErr(
154     size_t nBytesWritten,
155     const AsyncSocketException& ex) noexcept override {
156     std::cerr << "writeError: bytesWritten " << nBytesWritten
157          << ", exception " << ex.what() << std::endl;
158
159     state = STATE_FAILED;
160     this->bytesWritten = nBytesWritten;
161     exception = ex;
162     socket_->close();
163   }
164
165   std::shared_ptr<AsyncSSLSocket> socket_;
166   StateEnum state;
167   size_t bytesWritten;
168   AsyncSocketException exception;
169   SendMsgParamsCallbackBase* mcb_;
170 };
171
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
174 public:
175   explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176       : WriteCallbackBase(mcb) {}
177
178   ~ExpectWriteErrorCallback() override {
179     EXPECT_EQ(STATE_FAILED, state);
180     EXPECT_EQ(exception.type_,
181              AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
182     EXPECT_EQ(exception.errno_, 22);
183     // Suppress the assert in  ~WriteCallbackBase()
184     state = STATE_SUCCEEDED;
185   }
186 };
187
188 #ifdef MSG_ERRQUEUE
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
191 enum SOF_TIMESTAMPING {
192   SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196   SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
198 };
199
200 class WriteCheckTimestampCallback :
201  public WriteCallbackBase {
202 public:
203   explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204     : WriteCallbackBase(mcb) {}
205
206   ~WriteCheckTimestampCallback() override {
207     EXPECT_EQ(STATE_SUCCEEDED, state);
208     EXPECT_TRUE(gotTimestamp_);
209     EXPECT_TRUE(gotByteSeq_);
210   }
211
212   void setSocket(
213     const std::shared_ptr<AsyncSSLSocket> &socket) override {
214     WriteCallbackBase::setSocket(socket);
215
216     EXPECT_NE(socket_->getFd(), 0);
217     int flags = SOF_TIMESTAMPING_OPT_ID
218                 | SOF_TIMESTAMPING_OPT_TSONLY
219                 | SOF_TIMESTAMPING_SOFTWARE;
220     AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221     int ret = tstampingOpt.apply(socket_->getFd(), flags);
222     EXPECT_EQ(ret, 0);
223   }
224
225   void checkForTimestampNotifications() noexcept {
226     int fd = socket_->getFd();
227     std::vector<char> ctrl(1024, 0);
228     unsigned char data;
229     struct msghdr msg;
230     iovec entry;
231
232     memset(&msg, 0, sizeof(msg));
233     entry.iov_base = &data;
234     entry.iov_len = sizeof(data);
235     msg.msg_iov = &entry;
236     msg.msg_iovlen = 1;
237     msg.msg_control = ctrl.data();
238     msg.msg_controllen = ctrl.size();
239
240     int ret;
241     while (true) {
242       ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
243       if (ret < 0) {
244         if (errno != EAGAIN) {
245           auto errnoCopy = errno;
246           std::cerr << "::recvmsg exited with code " << ret
247                     << ", errno: " << errnoCopy << std::endl;
248           AsyncSocketException ex(
249             AsyncSocketException::INTERNAL_ERROR,
250             "recvmsg() failed",
251             errnoCopy);
252           exception = ex;
253         }
254         return;
255       }
256
257       for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258            cmsg != nullptr && cmsg->cmsg_len != 0;
259            cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260         if (cmsg->cmsg_level == SOL_SOCKET &&
261             cmsg->cmsg_type == SCM_TIMESTAMPING) {
262           gotTimestamp_ = true;
263           continue;
264         }
265
266         if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267             (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
268           gotByteSeq_ = true;
269           continue;
270         }
271       }
272     }
273   }
274
275   bool gotTimestamp_{false};
276   bool gotByteSeq_{false};
277 };
278 #endif // MSG_ERRQUEUE
279
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
282  public:
283   explicit ReadCallbackBase(WriteCallbackBase* wcb)
284       : wcb_(wcb), state(STATE_WAITING) {}
285
286   ~ReadCallbackBase() override {
287     EXPECT_EQ(STATE_SUCCEEDED, state);
288   }
289
290   void setSocket(
291     const std::shared_ptr<AsyncSSLSocket> &socket) {
292     socket_ = socket;
293   }
294
295   void setState(StateEnum s) {
296     state = s;
297     if (wcb_) {
298       wcb_->state = s;
299     }
300   }
301
302   void readErr(
303     const AsyncSocketException& ex) noexcept override {
304     std::cerr << "readError " << ex.what() << std::endl;
305     state = STATE_FAILED;
306     socket_->close();
307   }
308
309   void readEOF() noexcept override {
310     std::cerr << "readEOF" << std::endl;
311
312     socket_->close();
313   }
314
315   std::shared_ptr<AsyncSSLSocket> socket_;
316   WriteCallbackBase *wcb_;
317   StateEnum state;
318 };
319
320 class ReadCallback : public ReadCallbackBase {
321 public:
322   explicit ReadCallback(WriteCallbackBase *wcb)
323       : ReadCallbackBase(wcb)
324       , buffers() {}
325
326   ~ReadCallback() override {
327     for (std::vector<Buffer>::iterator it = buffers.begin();
328          it != buffers.end();
329          ++it) {
330       it->free();
331     }
332     currentBuffer.free();
333   }
334
335   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336     if (!currentBuffer.buffer) {
337       currentBuffer.allocate(4096);
338     }
339     *bufReturn = currentBuffer.buffer;
340     *lenReturn = currentBuffer.length;
341   }
342
343   void readDataAvailable(size_t len) noexcept override {
344     std::cerr << "readDataAvailable, len " << len << std::endl;
345
346     currentBuffer.length = len;
347
348     wcb_->setSocket(socket_);
349
350     // Write back the same data.
351     socket_->write(wcb_, currentBuffer.buffer, len);
352
353     buffers.push_back(currentBuffer);
354     currentBuffer.reset();
355     state = STATE_SUCCEEDED;
356   }
357
358   class Buffer {
359   public:
360     Buffer() : buffer(nullptr), length(0) {}
361     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
362
363     void reset() {
364       buffer = nullptr;
365       length = 0;
366     }
367     void allocate(size_t len) {
368       assert(buffer == nullptr);
369       this->buffer = static_cast<char*>(malloc(len));
370       this->length = len;
371     }
372     void free() {
373       ::free(buffer);
374       reset();
375     }
376
377     char* buffer;
378     size_t length;
379   };
380
381   std::vector<Buffer> buffers;
382   Buffer currentBuffer;
383 };
384
385 class ReadErrorCallback : public ReadCallbackBase {
386 public:
387   explicit ReadErrorCallback(WriteCallbackBase *wcb)
388       : ReadCallbackBase(wcb) {}
389
390   // Return nullptr buffer to trigger readError()
391   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392     *bufReturn = nullptr;
393     *lenReturn = 0;
394   }
395
396   void readDataAvailable(size_t /* len */) noexcept override {
397     // This should never to called.
398     FAIL();
399   }
400
401   void readErr(
402     const AsyncSocketException& ex) noexcept override {
403     ReadCallbackBase::readErr(ex);
404     std::cerr << "ReadErrorCallback::readError" << std::endl;
405     setState(STATE_SUCCEEDED);
406   }
407 };
408
409 class ReadEOFCallback : public ReadCallbackBase {
410  public:
411   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
412
413   // Return nullptr buffer to trigger readError()
414   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415     *bufReturn = nullptr;
416     *lenReturn = 0;
417   }
418
419   void readDataAvailable(size_t /* len */) noexcept override {
420     // This should never to called.
421     FAIL();
422   }
423
424   void readEOF() noexcept override {
425     ReadCallbackBase::readEOF();
426     setState(STATE_SUCCEEDED);
427   }
428 };
429
430 class WriteErrorCallback : public ReadCallback {
431 public:
432   explicit WriteErrorCallback(WriteCallbackBase *wcb)
433       : ReadCallback(wcb) {}
434
435   void readDataAvailable(size_t len) noexcept override {
436     std::cerr << "readDataAvailable, len " << len << std::endl;
437
438     currentBuffer.length = len;
439
440     // close the socket before writing to trigger writeError().
441     ::close(socket_->getFd());
442
443     wcb_->setSocket(socket_);
444
445     // Write back the same data.
446     folly::test::msvcSuppressAbortOnInvalidParams([&] {
447       socket_->write(wcb_, currentBuffer.buffer, len);
448     });
449
450     if (wcb_->state == STATE_FAILED) {
451       setState(STATE_SUCCEEDED);
452     } else {
453       state = STATE_FAILED;
454     }
455
456     buffers.push_back(currentBuffer);
457     currentBuffer.reset();
458   }
459
460   void readErr(const AsyncSocketException& ex) noexcept override {
461     std::cerr << "readError " << ex.what() << std::endl;
462     // do nothing since this is expected
463   }
464 };
465
466 class EmptyReadCallback : public ReadCallback {
467 public:
468   explicit EmptyReadCallback()
469       : ReadCallback(nullptr) {}
470
471   void readErr(const AsyncSocketException& ex) noexcept override {
472     std::cerr << "readError " << ex.what() << std::endl;
473     state = STATE_FAILED;
474     if (tcpSocket_) {
475       tcpSocket_->close();
476     }
477   }
478
479   void readEOF() noexcept override {
480     std::cerr << "readEOF" << std::endl;
481     if (tcpSocket_) {
482       tcpSocket_->close();
483     }
484     state = STATE_SUCCEEDED;
485   }
486
487   std::shared_ptr<AsyncSocket> tcpSocket_;
488 };
489
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
492 public:
493   enum ExpectType {
494     EXPECT_SUCCESS,
495     EXPECT_ERROR
496   };
497
498   explicit HandshakeCallback(ReadCallbackBase *rcb,
499                              ExpectType expect = EXPECT_SUCCESS):
500       state(STATE_WAITING),
501       rcb_(rcb),
502       expect_(expect) {}
503
504   void setSocket(
505     const std::shared_ptr<AsyncSSLSocket> &socket) {
506     socket_ = socket;
507   }
508
509   void setState(StateEnum s) {
510     state = s;
511     rcb_->setState(s);
512   }
513
514   // Functions inherited from AsyncSSLSocketHandshakeCallback
515   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516     std::lock_guard<std::mutex> g(mutex_);
517     cv_.notify_all();
518     EXPECT_EQ(sock, socket_.get());
519     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520     rcb_->setSocket(socket_);
521     sock->setReadCB(rcb_);
522     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
523   }
524   void handshakeErr(AsyncSSLSocket* /* sock */,
525                     const AsyncSocketException& ex) noexcept override {
526     std::lock_guard<std::mutex> g(mutex_);
527     cv_.notify_all();
528     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530     if (expect_ == EXPECT_ERROR) {
531       // rcb will never be invoked
532       rcb_->setState(STATE_SUCCEEDED);
533     }
534     errorString_ = ex.what();
535   }
536
537   void waitForHandshake() {
538     std::unique_lock<std::mutex> lock(mutex_);
539     cv_.wait(lock, [this] { return state != STATE_WAITING; });
540   }
541
542   ~HandshakeCallback() override {
543     EXPECT_EQ(STATE_SUCCEEDED, state);
544   }
545
546   void closeSocket() {
547     socket_->close();
548     state = STATE_SUCCEEDED;
549   }
550
551   std::shared_ptr<AsyncSSLSocket> getSocket() {
552     return socket_;
553   }
554
555   StateEnum state;
556   std::shared_ptr<AsyncSSLSocket> socket_;
557   ReadCallbackBase *rcb_;
558   ExpectType expect_;
559   std::mutex mutex_;
560   std::condition_variable cv_;
561   std::string errorString_;
562 };
563
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
565 public:
566   uint32_t timeout_;
567
568   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569                                    uint32_t timeout = 0):
570       SSLServerAcceptCallbackBase(hcb),
571       timeout_(timeout) {}
572
573   ~SSLServerAcceptCallback() override {
574     if (timeout_ > 0) {
575       // if we set a timeout, we expect failure
576       EXPECT_EQ(hcb_->state, STATE_FAILED);
577       hcb_->setState(STATE_SUCCEEDED);
578     }
579   }
580
581   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
582   void connAccepted(
583     const std::shared_ptr<folly::AsyncSSLSocket> &s)
584     noexcept override {
585     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
587
588     hcb_->setSocket(sock);
589     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590     EXPECT_EQ(sock->getSSLState(),
591                       AsyncSSLSocket::STATE_ACCEPTING);
592
593     state = STATE_SUCCEEDED;
594   }
595 };
596
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
598 public:
599   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600       SSLServerAcceptCallback(hcb) {}
601
602   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
603   void connAccepted(
604     const std::shared_ptr<folly::AsyncSSLSocket> &s)
605     noexcept override {
606
607     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
608
609     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
610               << std::endl;
611     int fd = sock->getFd();
612
613 #ifndef TCP_NOPUSH
614     {
615     // The accepted connection should already have TCP_NODELAY set
616     int value;
617     socklen_t valueLength = sizeof(value);
618     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
619     EXPECT_EQ(rc, 0);
620     EXPECT_EQ(value, 1);
621     }
622 #endif
623
624     // Unset the TCP_NODELAY option.
625     int value = 0;
626     socklen_t valueLength = sizeof(value);
627     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
628     EXPECT_EQ(rc, 0);
629
630     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
631     EXPECT_EQ(rc, 0);
632     EXPECT_EQ(value, 0);
633
634     SSLServerAcceptCallback::connAccepted(sock);
635   }
636 };
637
638 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
639 public:
640   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
641                                              uint32_t timeout = 0):
642     SSLServerAcceptCallback(hcb, timeout) {}
643
644   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
645   void connAccepted(
646     const std::shared_ptr<folly::AsyncSSLSocket> &s)
647     noexcept override {
648     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
649
650     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
651
652     hcb_->setSocket(sock);
653     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
654     ASSERT_TRUE((sock->getSSLState() ==
655                  AsyncSSLSocket::STATE_ACCEPTING) ||
656                 (sock->getSSLState() ==
657                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
658
659     state = STATE_SUCCEEDED;
660   }
661 };
662
663
664 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
665 public:
666   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
667   SSLServerAcceptCallbackBase(hcb)  {}
668
669   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
670   void connAccepted(
671     const std::shared_ptr<folly::AsyncSSLSocket> &s)
672     noexcept override {
673     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
674
675     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
676
677     // The first call to sslAccept() should succeed.
678     hcb_->setSocket(sock);
679     sock->sslAccept(hcb_);
680     EXPECT_EQ(sock->getSSLState(),
681                       AsyncSSLSocket::STATE_ACCEPTING);
682
683     // The second call to sslAccept() should fail.
684     HandshakeCallback callback2(hcb_->rcb_);
685     callback2.setSocket(sock);
686     sock->sslAccept(&callback2);
687     EXPECT_EQ(sock->getSSLState(),
688                       AsyncSSLSocket::STATE_ERROR);
689
690     // Both callbacks should be in the error state.
691     EXPECT_EQ(hcb_->state, STATE_FAILED);
692     EXPECT_EQ(callback2.state, STATE_FAILED);
693
694     state = STATE_SUCCEEDED;
695     hcb_->setState(STATE_SUCCEEDED);
696     callback2.setState(STATE_SUCCEEDED);
697   }
698 };
699
700 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
701 public:
702   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
703   SSLServerAcceptCallbackBase(hcb)  {}
704
705   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
706   void connAccepted(
707     const std::shared_ptr<folly::AsyncSSLSocket> &s)
708     noexcept override {
709     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
710
711     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
712
713     hcb_->setSocket(sock);
714     sock->getEventBase()->tryRunAfterDelay([=] {
715         std::cerr << "Delayed SSL accept, client will have close by now"
716                   << std::endl;
717         // SSL accept will fail
718         EXPECT_EQ(
719           sock->getSSLState(),
720           AsyncSSLSocket::STATE_UNINIT);
721         hcb_->socket_->sslAccept(hcb_);
722         // This registers for an event
723         EXPECT_EQ(
724           sock->getSSLState(),
725           AsyncSSLSocket::STATE_ACCEPTING);
726
727         state = STATE_SUCCEEDED;
728       }, 100);
729   }
730 };
731
732 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
733  public:
734   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
735     // We don't care if we get invoked or not.
736     // The client may time out and give up before connAccepted() is even
737     // called.
738     state = STATE_SUCCEEDED;
739   }
740
741   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
742   void connAccepted(
743       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
744     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
745
746     // Just wait a while before closing the socket, so the client
747     // will time out waiting for the handshake to complete.
748     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
749   }
750 };
751
752 class TestSSLAsyncCacheServer : public TestSSLServer {
753  public:
754   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
755         int lookupDelay = 100) :
756       TestSSLServer(acb) {
757     SSL_CTX *sslCtx = ctx_->getSSLCtx();
758 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
759     SSL_CTX_sess_set_get_cb(sslCtx,
760                             TestSSLAsyncCacheServer::getSessionCallback);
761 #endif
762     SSL_CTX_set_session_cache_mode(
763       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
764     asyncCallbacks_ = 0;
765     asyncLookups_ = 0;
766     lookupDelay_ = lookupDelay;
767   }
768
769   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
770   uint32_t getAsyncLookups() const { return asyncLookups_; }
771
772  private:
773   static uint32_t asyncCallbacks_;
774   static uint32_t asyncLookups_;
775   static uint32_t lookupDelay_;
776
777   static SSL_SESSION* getSessionCallback(SSL* ssl,
778                                          unsigned char* /* sess_id */,
779                                          int /* id_len */,
780                                          int* copyflag) {
781     *copyflag = 0;
782     asyncCallbacks_++;
783     (void)ssl;
784 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
785     if (!SSL_want_sess_cache_lookup(ssl)) {
786       // libssl.so mismatch
787       std::cerr << "no async support" << std::endl;
788       return nullptr;
789     }
790
791     AsyncSSLSocket *sslSocket =
792         AsyncSSLSocket::getFromSSL(ssl);
793     assert(sslSocket != nullptr);
794     // Going to simulate an async cache by just running delaying the miss 100ms
795     if (asyncCallbacks_ % 2 == 0) {
796       // This socket is already blocked on lookup, return miss
797       std::cerr << "returning miss" << std::endl;
798     } else {
799       // fresh meat - block it
800       std::cerr << "async lookup" << std::endl;
801       sslSocket->getEventBase()->tryRunAfterDelay(
802         std::bind(&AsyncSSLSocket::restartSSLAccept,
803                   sslSocket), lookupDelay_);
804       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
805       asyncLookups_++;
806     }
807 #endif
808     return nullptr;
809   }
810 };
811
812 void getfds(int fds[2]);
813
814 void getctx(
815   std::shared_ptr<folly::SSLContext> clientCtx,
816   std::shared_ptr<folly::SSLContext> serverCtx);
817
818 void sslsocketpair(
819   EventBase* eventBase,
820   AsyncSSLSocket::UniquePtr* clientSock,
821   AsyncSSLSocket::UniquePtr* serverSock);
822
823 class BlockingWriteClient :
824   private AsyncSSLSocket::HandshakeCB,
825   private AsyncTransportWrapper::WriteCallback {
826  public:
827   explicit BlockingWriteClient(
828     AsyncSSLSocket::UniquePtr socket)
829     : socket_(std::move(socket)),
830       bufLen_(2500),
831       iovCount_(2000) {
832     // Fill buf_
833     buf_.reset(new uint8_t[bufLen_]);
834     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
835       buf_[n] = n % 0xff;
836     }
837
838     // Initialize iov_
839     iov_.reset(new struct iovec[iovCount_]);
840     for (uint32_t n = 0; n < iovCount_; ++n) {
841       iov_[n].iov_base = buf_.get() + n;
842       if (n & 0x1) {
843         iov_[n].iov_len = n % bufLen_;
844       } else {
845         iov_[n].iov_len = bufLen_ - (n % bufLen_);
846       }
847     }
848
849     socket_->sslConn(this, std::chrono::milliseconds(100));
850   }
851
852   struct iovec* getIovec() const {
853     return iov_.get();
854   }
855   uint32_t getIovecCount() const {
856     return iovCount_;
857   }
858
859  private:
860   void handshakeSuc(AsyncSSLSocket*) noexcept override {
861     socket_->writev(this, iov_.get(), iovCount_);
862   }
863   void handshakeErr(
864     AsyncSSLSocket*,
865     const AsyncSocketException& ex) noexcept override {
866     ADD_FAILURE() << "client handshake error: " << ex.what();
867   }
868   void writeSuccess() noexcept override {
869     socket_->close();
870   }
871   void writeErr(
872     size_t bytesWritten,
873     const AsyncSocketException& ex) noexcept override {
874     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
875                   << ex.what();
876   }
877
878   AsyncSSLSocket::UniquePtr socket_;
879   uint32_t bufLen_;
880   uint32_t iovCount_;
881   std::unique_ptr<uint8_t[]> buf_;
882   std::unique_ptr<struct iovec[]> iov_;
883 };
884
885 class BlockingWriteServer :
886     private AsyncSSLSocket::HandshakeCB,
887     private AsyncTransportWrapper::ReadCallback {
888  public:
889   explicit BlockingWriteServer(
890     AsyncSSLSocket::UniquePtr socket)
891     : socket_(std::move(socket)),
892       bufSize_(2500 * 2000),
893       bytesRead_(0) {
894     buf_.reset(new uint8_t[bufSize_]);
895     socket_->sslAccept(this, std::chrono::milliseconds(100));
896   }
897
898   void checkBuffer(struct iovec* iov, uint32_t count) const {
899     uint32_t idx = 0;
900     for (uint32_t n = 0; n < count; ++n) {
901       size_t bytesLeft = bytesRead_ - idx;
902       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
903                       std::min(iov[n].iov_len, bytesLeft));
904       if (rc != 0) {
905         FAIL() << "buffer mismatch at iovec " << n << "/" << count
906                << ": rc=" << rc;
907
908       }
909       if (iov[n].iov_len > bytesLeft) {
910         FAIL() << "server did not read enough data: "
911                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
912                << " in iovec " << n << "/" << count;
913       }
914
915       idx += iov[n].iov_len;
916     }
917     if (idx != bytesRead_) {
918       ADD_FAILURE() << "server read extra data: " << bytesRead_
919                     << " bytes read; expected " << idx;
920     }
921   }
922
923  private:
924   void handshakeSuc(AsyncSSLSocket*) noexcept override {
925     // Wait 10ms before reading, so the client's writes will initially block.
926     socket_->getEventBase()->tryRunAfterDelay(
927         [this] { socket_->setReadCB(this); }, 10);
928   }
929   void handshakeErr(
930     AsyncSSLSocket*,
931     const AsyncSocketException& ex) noexcept override {
932     ADD_FAILURE() << "server handshake error: " << ex.what();
933   }
934   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
935     *bufReturn = buf_.get() + bytesRead_;
936     *lenReturn = bufSize_ - bytesRead_;
937   }
938   void readDataAvailable(size_t len) noexcept override {
939     bytesRead_ += len;
940     socket_->setReadCB(nullptr);
941     socket_->getEventBase()->tryRunAfterDelay(
942         [this] { socket_->setReadCB(this); }, 2);
943   }
944   void readEOF() noexcept override {
945     socket_->close();
946   }
947   void readErr(
948     const AsyncSocketException& ex) noexcept override {
949     ADD_FAILURE() << "server read error: " << ex.what();
950   }
951
952   AsyncSSLSocket::UniquePtr socket_;
953   uint32_t bufSize_;
954   uint32_t bytesRead_;
955   std::unique_ptr<uint8_t[]> buf_;
956 };
957
958 class NpnClient :
959   private AsyncSSLSocket::HandshakeCB,
960   private AsyncTransportWrapper::WriteCallback {
961  public:
962   explicit NpnClient(
963     AsyncSSLSocket::UniquePtr socket)
964       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
965     socket_->sslConn(this);
966   }
967
968   const unsigned char* nextProto;
969   unsigned nextProtoLength;
970   SSLContext::NextProtocolType protocolType;
971   folly::Optional<AsyncSocketException> except;
972
973  private:
974   void handshakeSuc(AsyncSSLSocket*) noexcept override {
975     socket_->getSelectedNextProtocol(
976         &nextProto, &nextProtoLength, &protocolType);
977   }
978   void handshakeErr(
979     AsyncSSLSocket*,
980     const AsyncSocketException& ex) noexcept override {
981     except = ex;
982   }
983   void writeSuccess() noexcept override {
984     socket_->close();
985   }
986   void writeErr(
987     size_t bytesWritten,
988     const AsyncSocketException& ex) noexcept override {
989     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
990                   << ex.what();
991   }
992
993   AsyncSSLSocket::UniquePtr socket_;
994 };
995
996 class NpnServer :
997     private AsyncSSLSocket::HandshakeCB,
998     private AsyncTransportWrapper::ReadCallback {
999  public:
1000   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
1001       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
1002     socket_->sslAccept(this);
1003   }
1004
1005   const unsigned char* nextProto;
1006   unsigned nextProtoLength;
1007   SSLContext::NextProtocolType protocolType;
1008   folly::Optional<AsyncSocketException> except;
1009
1010  private:
1011   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1012     socket_->getSelectedNextProtocol(
1013         &nextProto, &nextProtoLength, &protocolType);
1014   }
1015   void handshakeErr(
1016     AsyncSSLSocket*,
1017     const AsyncSocketException& ex) noexcept override {
1018     except = ex;
1019   }
1020   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1021     *lenReturn = 0;
1022   }
1023   void readDataAvailable(size_t /* len */) noexcept override {}
1024   void readEOF() noexcept override {
1025     socket_->close();
1026   }
1027   void readErr(
1028     const AsyncSocketException& ex) noexcept override {
1029     ADD_FAILURE() << "server read error: " << ex.what();
1030   }
1031
1032   AsyncSSLSocket::UniquePtr socket_;
1033 };
1034
1035 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1036                             public AsyncTransportWrapper::ReadCallback {
1037  public:
1038   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1039       : socket_(std::move(socket)) {
1040     socket_->sslAccept(this);
1041   }
1042
1043   ~RenegotiatingServer() override {
1044     socket_->setReadCB(nullptr);
1045   }
1046
1047   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1048     LOG(INFO) << "Renegotiating server handshake success";
1049     socket_->setReadCB(this);
1050   }
1051   void handshakeErr(
1052       AsyncSSLSocket*,
1053       const AsyncSocketException& ex) noexcept override {
1054     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1055   }
1056   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1057     *lenReturn = sizeof(buf);
1058     *bufReturn = buf;
1059   }
1060   void readDataAvailable(size_t /* len */) noexcept override {}
1061   void readEOF() noexcept override {}
1062   void readErr(const AsyncSocketException& ex) noexcept override {
1063     LOG(INFO) << "server got read error " << ex.what();
1064     auto exPtr = dynamic_cast<const SSLException*>(&ex);
1065     ASSERT_NE(nullptr, exPtr);
1066     std::string exStr(ex.what());
1067     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1068     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1069     renegotiationError_ = true;
1070   }
1071
1072   AsyncSSLSocket::UniquePtr socket_;
1073   unsigned char buf[128];
1074   bool renegotiationError_{false};
1075 };
1076
1077 #ifndef OPENSSL_NO_TLSEXT
1078 class SNIClient :
1079   private AsyncSSLSocket::HandshakeCB,
1080   private AsyncTransportWrapper::WriteCallback {
1081  public:
1082   explicit SNIClient(
1083     AsyncSSLSocket::UniquePtr socket)
1084       : serverNameMatch(false), socket_(std::move(socket)) {
1085     socket_->sslConn(this);
1086   }
1087
1088   bool serverNameMatch;
1089
1090  private:
1091   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1092     serverNameMatch = socket_->isServerNameMatch();
1093   }
1094   void handshakeErr(
1095     AsyncSSLSocket*,
1096     const AsyncSocketException& ex) noexcept override {
1097     ADD_FAILURE() << "client handshake error: " << ex.what();
1098   }
1099   void writeSuccess() noexcept override {
1100     socket_->close();
1101   }
1102   void writeErr(
1103     size_t bytesWritten,
1104     const AsyncSocketException& ex) noexcept override {
1105     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1106                   << ex.what();
1107   }
1108
1109   AsyncSSLSocket::UniquePtr socket_;
1110 };
1111
1112 class SNIServer :
1113     private AsyncSSLSocket::HandshakeCB,
1114     private AsyncTransportWrapper::ReadCallback {
1115  public:
1116   explicit SNIServer(
1117     AsyncSSLSocket::UniquePtr socket,
1118     const std::shared_ptr<folly::SSLContext>& ctx,
1119     const std::shared_ptr<folly::SSLContext>& sniCtx,
1120     const std::string& expectedServerName)
1121       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1122         expectedServerName_(expectedServerName) {
1123     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1124                                          std::placeholders::_1));
1125     socket_->sslAccept(this);
1126   }
1127
1128   bool serverNameMatch;
1129
1130  private:
1131   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1132   void handshakeErr(
1133     AsyncSSLSocket*,
1134     const AsyncSocketException& ex) noexcept override {
1135     ADD_FAILURE() << "server handshake error: " << ex.what();
1136   }
1137   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1138     *lenReturn = 0;
1139   }
1140   void readDataAvailable(size_t /* len */) noexcept override {}
1141   void readEOF() noexcept override {
1142     socket_->close();
1143   }
1144   void readErr(
1145     const AsyncSocketException& ex) noexcept override {
1146     ADD_FAILURE() << "server read error: " << ex.what();
1147   }
1148
1149   folly::SSLContext::ServerNameCallbackResult
1150     serverNameCallback(SSL *ssl) {
1151     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1152     if (sniCtx_ &&
1153         sn &&
1154         !strcasecmp(expectedServerName_.c_str(), sn)) {
1155       AsyncSSLSocket *sslSocket =
1156           AsyncSSLSocket::getFromSSL(ssl);
1157       sslSocket->switchServerSSLContext(sniCtx_);
1158       serverNameMatch = true;
1159       return folly::SSLContext::SERVER_NAME_FOUND;
1160     } else {
1161       serverNameMatch = false;
1162       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1163     }
1164   }
1165
1166   AsyncSSLSocket::UniquePtr socket_;
1167   std::shared_ptr<folly::SSLContext> sniCtx_;
1168   std::string expectedServerName_;
1169 };
1170 #endif
1171
1172 class SSLClient : public AsyncSocket::ConnectCallback,
1173                   public AsyncTransportWrapper::WriteCallback,
1174                   public AsyncTransportWrapper::ReadCallback
1175 {
1176  private:
1177   EventBase *eventBase_;
1178   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1179   SSL_SESSION *session_;
1180   std::shared_ptr<folly::SSLContext> ctx_;
1181   uint32_t requests_;
1182   folly::SocketAddress address_;
1183   uint32_t timeout_;
1184   char buf_[128];
1185   char readbuf_[128];
1186   uint32_t bytesRead_;
1187   uint32_t hit_;
1188   uint32_t miss_;
1189   uint32_t errors_;
1190   uint32_t writeAfterConnectErrors_;
1191
1192   // These settings test that we eventually drain the
1193   // socket, even if the maxReadsPerEvent_ is hit during
1194   // a event loop iteration.
1195   static constexpr size_t kMaxReadsPerEvent = 2;
1196   // 2 event loop iterations
1197   static constexpr size_t kMaxReadBufferSz =
1198     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1199
1200  public:
1201   SSLClient(EventBase *eventBase,
1202             const folly::SocketAddress& address,
1203             uint32_t requests,
1204             uint32_t timeout = 0)
1205       : eventBase_(eventBase),
1206         session_(nullptr),
1207         requests_(requests),
1208         address_(address),
1209         timeout_(timeout),
1210         bytesRead_(0),
1211         hit_(0),
1212         miss_(0),
1213         errors_(0),
1214         writeAfterConnectErrors_(0) {
1215     ctx_.reset(new folly::SSLContext());
1216     ctx_->setOptions(SSL_OP_NO_TICKET);
1217     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1218     memset(buf_, 'a', sizeof(buf_));
1219   }
1220
1221   ~SSLClient() override {
1222     if (session_) {
1223       SSL_SESSION_free(session_);
1224     }
1225     if (errors_ == 0) {
1226       EXPECT_EQ(bytesRead_, sizeof(buf_));
1227     }
1228   }
1229
1230   uint32_t getHit() const { return hit_; }
1231
1232   uint32_t getMiss() const { return miss_; }
1233
1234   uint32_t getErrors() const { return errors_; }
1235
1236   uint32_t getWriteAfterConnectErrors() const {
1237     return writeAfterConnectErrors_;
1238   }
1239
1240   void connect(bool writeNow = false) {
1241     sslSocket_ = AsyncSSLSocket::newSocket(
1242       ctx_, eventBase_);
1243     if (session_ != nullptr) {
1244       sslSocket_->setSSLSession(session_);
1245     }
1246     requests_--;
1247     sslSocket_->connect(this, address_, timeout_);
1248     if (sslSocket_ && writeNow) {
1249       // write some junk, used in an error test
1250       sslSocket_->write(this, buf_, sizeof(buf_));
1251     }
1252   }
1253
1254   void connectSuccess() noexcept override {
1255     std::cerr << "client SSL socket connected" << std::endl;
1256     if (sslSocket_->getSSLSessionReused()) {
1257       hit_++;
1258     } else {
1259       miss_++;
1260       if (session_ != nullptr) {
1261         SSL_SESSION_free(session_);
1262       }
1263       session_ = sslSocket_->getSSLSession();
1264     }
1265
1266     // write()
1267     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1268     sslSocket_->write(this, buf_, sizeof(buf_));
1269     sslSocket_->setReadCB(this);
1270     memset(readbuf_, 'b', sizeof(readbuf_));
1271     bytesRead_ = 0;
1272   }
1273
1274   void connectErr(
1275     const AsyncSocketException& ex) noexcept override {
1276     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1277     errors_++;
1278     sslSocket_.reset();
1279   }
1280
1281   void writeSuccess() noexcept override {
1282     std::cerr << "client write success" << std::endl;
1283   }
1284
1285   void writeErr(size_t /* bytesWritten */,
1286                 const AsyncSocketException& ex) noexcept override {
1287     std::cerr << "client writeError: " << ex.what() << std::endl;
1288     if (!sslSocket_) {
1289       writeAfterConnectErrors_++;
1290     }
1291   }
1292
1293   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1294     *bufReturn = readbuf_ + bytesRead_;
1295     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1296   }
1297
1298   void readEOF() noexcept override {
1299     std::cerr << "client readEOF" << std::endl;
1300   }
1301
1302   void readErr(
1303     const AsyncSocketException& ex) noexcept override {
1304     std::cerr << "client readError: " << ex.what() << std::endl;
1305   }
1306
1307   void readDataAvailable(size_t len) noexcept override {
1308     std::cerr << "client read data: " << len << std::endl;
1309     bytesRead_ += len;
1310     if (bytesRead_ == sizeof(buf_)) {
1311       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1312       sslSocket_->closeNow();
1313       sslSocket_.reset();
1314       if (requests_ != 0) {
1315         connect();
1316       }
1317     }
1318   }
1319
1320 };
1321
1322 class SSLHandshakeBase :
1323   public AsyncSSLSocket::HandshakeCB,
1324   private AsyncTransportWrapper::WriteCallback {
1325  public:
1326   explicit SSLHandshakeBase(
1327    AsyncSSLSocket::UniquePtr socket,
1328    bool preverifyResult,
1329    bool verifyResult) :
1330     handshakeVerify_(false),
1331     handshakeSuccess_(false),
1332     handshakeError_(false),
1333     socket_(std::move(socket)),
1334     preverifyResult_(preverifyResult),
1335     verifyResult_(verifyResult) {
1336   }
1337
1338   AsyncSSLSocket::UniquePtr moveSocket() && {
1339     return std::move(socket_);
1340   }
1341
1342   bool handshakeVerify_;
1343   bool handshakeSuccess_;
1344   bool handshakeError_;
1345   std::chrono::nanoseconds handshakeTime;
1346
1347  protected:
1348   AsyncSSLSocket::UniquePtr socket_;
1349   bool preverifyResult_;
1350   bool verifyResult_;
1351
1352   // HandshakeCallback
1353   bool handshakeVer(AsyncSSLSocket* /* sock */,
1354                     bool preverifyOk,
1355                     X509_STORE_CTX* /* ctx */) noexcept override {
1356     handshakeVerify_ = true;
1357
1358     EXPECT_EQ(preverifyResult_, preverifyOk);
1359     return verifyResult_;
1360   }
1361
1362   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1363     LOG(INFO) << "Handshake success";
1364     handshakeSuccess_ = true;
1365     handshakeTime = socket_->getHandshakeTime();
1366   }
1367
1368   void handshakeErr(
1369       AsyncSSLSocket*,
1370       const AsyncSocketException& ex) noexcept override {
1371     LOG(INFO) << "Handshake error " << ex.what();
1372     handshakeError_ = true;
1373     handshakeTime = socket_->getHandshakeTime();
1374   }
1375
1376   // WriteCallback
1377   void writeSuccess() noexcept override {
1378     socket_->close();
1379   }
1380
1381   void writeErr(
1382    size_t bytesWritten,
1383    const AsyncSocketException& ex) noexcept override {
1384     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1385                   << ex.what();
1386   }
1387 };
1388
1389 class SSLHandshakeClient : public SSLHandshakeBase {
1390  public:
1391   SSLHandshakeClient(
1392    AsyncSSLSocket::UniquePtr socket,
1393    bool preverifyResult,
1394    bool verifyResult) :
1395     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1396     socket_->sslConn(this, std::chrono::milliseconds::zero());
1397   }
1398 };
1399
1400 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1401  public:
1402   SSLHandshakeClientNoVerify(
1403    AsyncSSLSocket::UniquePtr socket,
1404    bool preverifyResult,
1405    bool verifyResult) :
1406     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1407     socket_->sslConn(
1408         this,
1409         std::chrono::milliseconds::zero(),
1410         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1411   }
1412 };
1413
1414 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1415  public:
1416   SSLHandshakeClientDoVerify(
1417    AsyncSSLSocket::UniquePtr socket,
1418    bool preverifyResult,
1419    bool verifyResult) :
1420     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1421     socket_->sslConn(
1422         this,
1423         std::chrono::milliseconds::zero(),
1424         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1425   }
1426 };
1427
1428 class SSLHandshakeServer : public SSLHandshakeBase {
1429  public:
1430   SSLHandshakeServer(
1431       AsyncSSLSocket::UniquePtr socket,
1432       bool preverifyResult,
1433       bool verifyResult)
1434     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1435     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1436   }
1437 };
1438
1439 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1440  public:
1441   SSLHandshakeServerParseClientHello(
1442       AsyncSSLSocket::UniquePtr socket,
1443       bool preverifyResult,
1444       bool verifyResult)
1445       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1446     socket_->enableClientHelloParsing();
1447     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1448   }
1449
1450   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1451
1452  protected:
1453   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1454     handshakeSuccess_ = true;
1455     sock->getSSLSharedCiphers(sharedCiphers_);
1456     sock->getSSLServerCiphers(serverCiphers_);
1457     sock->getSSLClientCiphers(clientCiphers_);
1458     chosenCipher_ = sock->getNegotiatedCipherName();
1459   }
1460 };
1461
1462
1463 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1464  public:
1465   SSLHandshakeServerNoVerify(
1466       AsyncSSLSocket::UniquePtr socket,
1467       bool preverifyResult,
1468       bool verifyResult)
1469     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1470     socket_->sslAccept(
1471         this,
1472         std::chrono::milliseconds::zero(),
1473         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1474   }
1475 };
1476
1477 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1478  public:
1479   SSLHandshakeServerDoVerify(
1480       AsyncSSLSocket::UniquePtr socket,
1481       bool preverifyResult,
1482       bool verifyResult)
1483     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1484     socket_->sslAccept(
1485         this,
1486         std::chrono::milliseconds::zero(),
1487         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1488   }
1489 };
1490
1491 class EventBaseAborter : public AsyncTimeout {
1492  public:
1493   EventBaseAborter(EventBase* eventBase,
1494                    uint32_t timeoutMS)
1495     : AsyncTimeout(
1496       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1497     , eventBase_(eventBase) {
1498     scheduleTimeout(timeoutMS);
1499   }
1500
1501   void timeoutExpired() noexcept override {
1502     FAIL() << "test timed out";
1503     eventBase_->terminateLoopSoon();
1504   }
1505
1506  private:
1507   EventBase* eventBase_;
1508 };
1509
1510 }