OpenSSL 1.1.0 compatibility
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <condition_variable>
39 #include <iostream>
40 #include <list>
41
42 namespace folly {
43
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
48
49 class SendMsgParamsCallbackBase :
50       public folly::AsyncSocket::SendMsgParamsCallback {
51  public:
52   SendMsgParamsCallbackBase() {}
53
54   void setSocket(
55     const std::shared_ptr<AsyncSSLSocket> &socket) {
56     socket_ = socket;
57     oldCallback_ = socket_->getSendMsgParamsCB();
58     socket_->setSendMsgParamCB(this);
59   }
60
61   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
62                                                                      override {
63     return oldCallback_->getFlags(flags);
64   }
65
66   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67     oldCallback_->getAncillaryData(flags, data);
68   }
69
70   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71     return oldCallback_->getAncillaryDataSize(flags);
72   }
73
74   std::shared_ptr<AsyncSSLSocket> socket_;
75   folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
76 };
77
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
79  public:
80   SendMsgFlagsCallback() {}
81
82   void resetFlags(int flags) {
83     flags_ = flags;
84   }
85
86   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
87                                                                   override {
88     if (flags_) {
89       return flags_;
90     } else {
91       return oldCallback_->getFlags(flags);
92     }
93   }
94
95   int flags_{0};
96 };
97
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
99  public:
100   SendMsgDataCallback() {}
101
102   void resetData(std::vector<char>&& data) {
103     ancillaryData_.swap(data);
104   }
105
106   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107     if (ancillaryData_.size()) {
108       std::cerr << "getAncillaryData: copying data" << std::endl;
109       memcpy(data, ancillaryData_.data(), ancillaryData_.size());
110     } else {
111       oldCallback_->getAncillaryData(flags, data);
112     }
113   }
114
115   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116     if (ancillaryData_.size()) {
117       std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118       return ancillaryData_.size();
119     } else {
120       return oldCallback_->getAncillaryDataSize(flags);
121     }
122   }
123
124   std::vector<char> ancillaryData_;
125 };
126
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
129 public:
130   explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131       : state(STATE_WAITING)
132       , bytesWritten(0)
133       , exception(AsyncSocketException::UNKNOWN, "none")
134       , mcb_(mcb) {}
135
136   ~WriteCallbackBase() {
137     EXPECT_EQ(STATE_SUCCEEDED, state);
138   }
139
140   virtual void setSocket(
141     const std::shared_ptr<AsyncSSLSocket> &socket) {
142     socket_ = socket;
143     if (mcb_) {
144       mcb_->setSocket(socket);
145     }
146   }
147
148   virtual void writeSuccess() noexcept override {
149     std::cerr << "writeSuccess" << std::endl;
150     state = STATE_SUCCEEDED;
151   }
152
153   void writeErr(
154     size_t nBytesWritten,
155     const AsyncSocketException& ex) noexcept override {
156     std::cerr << "writeError: bytesWritten " << nBytesWritten
157          << ", exception " << ex.what() << std::endl;
158
159     state = STATE_FAILED;
160     this->bytesWritten = nBytesWritten;
161     exception = ex;
162     socket_->close();
163   }
164
165   std::shared_ptr<AsyncSSLSocket> socket_;
166   StateEnum state;
167   size_t bytesWritten;
168   AsyncSocketException exception;
169   SendMsgParamsCallbackBase* mcb_;
170 };
171
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
174 public:
175   explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176       : WriteCallbackBase(mcb) {}
177
178   ~ExpectWriteErrorCallback() {
179     EXPECT_EQ(STATE_FAILED, state);
180     EXPECT_EQ(exception.type_,
181              AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
182     EXPECT_EQ(exception.errno_, 22);
183     // Suppress the assert in  ~WriteCallbackBase()
184     state = STATE_SUCCEEDED;
185   }
186 };
187
188 #ifdef MSG_ERRQUEUE
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
191 enum SOF_TIMESTAMPING {
192   SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196   SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
198 };
199
200 class WriteCheckTimestampCallback :
201  public WriteCallbackBase {
202 public:
203   explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204     : WriteCallbackBase(mcb) {}
205
206   ~WriteCheckTimestampCallback() {
207     EXPECT_EQ(STATE_SUCCEEDED, state);
208     EXPECT_TRUE(gotTimestamp_);
209     EXPECT_TRUE(gotByteSeq_);
210   }
211
212   void setSocket(
213     const std::shared_ptr<AsyncSSLSocket> &socket) override {
214     WriteCallbackBase::setSocket(socket);
215
216     EXPECT_NE(socket_->getFd(), 0);
217     int flags = SOF_TIMESTAMPING_OPT_ID
218                 | SOF_TIMESTAMPING_OPT_TSONLY
219                 | SOF_TIMESTAMPING_SOFTWARE;
220     AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221     int ret = tstampingOpt.apply(socket_->getFd(), flags);
222     EXPECT_EQ(ret, 0);
223   }
224
225   void checkForTimestampNotifications() noexcept {
226     int fd = socket_->getFd();
227     std::vector<char> ctrl(1024, 0);
228     unsigned char data;
229     struct msghdr msg;
230     iovec entry;
231
232     memset(&msg, 0, sizeof(msg));
233     entry.iov_base = &data;
234     entry.iov_len = sizeof(data);
235     msg.msg_iov = &entry;
236     msg.msg_iovlen = 1;
237     msg.msg_control = ctrl.data();
238     msg.msg_controllen = ctrl.size();
239
240     int ret;
241     while (true) {
242       ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
243       if (ret < 0) {
244         if (errno != EAGAIN) {
245           auto errnoCopy = errno;
246           std::cerr << "::recvmsg exited with code " << ret
247                     << ", errno: " << errnoCopy << std::endl;
248           AsyncSocketException ex(
249             AsyncSocketException::INTERNAL_ERROR,
250             "recvmsg() failed",
251             errnoCopy);
252           exception = ex;
253         }
254         return;
255       }
256
257       for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258            cmsg != nullptr && cmsg->cmsg_len != 0;
259            cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260         if (cmsg->cmsg_level == SOL_SOCKET &&
261             cmsg->cmsg_type == SCM_TIMESTAMPING) {
262           gotTimestamp_ = true;
263           continue;
264         }
265
266         if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267             (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
268           gotByteSeq_ = true;
269           continue;
270         }
271       }
272     }
273   }
274
275   bool gotTimestamp_{false};
276   bool gotByteSeq_{false};
277 };
278 #endif // MSG_ERRQUEUE
279
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
282  public:
283   explicit ReadCallbackBase(WriteCallbackBase* wcb)
284       : wcb_(wcb), state(STATE_WAITING) {}
285
286   ~ReadCallbackBase() {
287     EXPECT_EQ(STATE_SUCCEEDED, state);
288   }
289
290   void setSocket(
291     const std::shared_ptr<AsyncSSLSocket> &socket) {
292     socket_ = socket;
293   }
294
295   void setState(StateEnum s) {
296     state = s;
297     if (wcb_) {
298       wcb_->state = s;
299     }
300   }
301
302   void readErr(
303     const AsyncSocketException& ex) noexcept override {
304     std::cerr << "readError " << ex.what() << std::endl;
305     state = STATE_FAILED;
306     socket_->close();
307   }
308
309   void readEOF() noexcept override {
310     std::cerr << "readEOF" << std::endl;
311
312     socket_->close();
313   }
314
315   std::shared_ptr<AsyncSSLSocket> socket_;
316   WriteCallbackBase *wcb_;
317   StateEnum state;
318 };
319
320 class ReadCallback : public ReadCallbackBase {
321 public:
322   explicit ReadCallback(WriteCallbackBase *wcb)
323       : ReadCallbackBase(wcb)
324       , buffers() {}
325
326   ~ReadCallback() {
327     for (std::vector<Buffer>::iterator it = buffers.begin();
328          it != buffers.end();
329          ++it) {
330       it->free();
331     }
332     currentBuffer.free();
333   }
334
335   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336     if (!currentBuffer.buffer) {
337       currentBuffer.allocate(4096);
338     }
339     *bufReturn = currentBuffer.buffer;
340     *lenReturn = currentBuffer.length;
341   }
342
343   void readDataAvailable(size_t len) noexcept override {
344     std::cerr << "readDataAvailable, len " << len << std::endl;
345
346     currentBuffer.length = len;
347
348     wcb_->setSocket(socket_);
349
350     // Write back the same data.
351     socket_->write(wcb_, currentBuffer.buffer, len);
352
353     buffers.push_back(currentBuffer);
354     currentBuffer.reset();
355     state = STATE_SUCCEEDED;
356   }
357
358   class Buffer {
359   public:
360     Buffer() : buffer(nullptr), length(0) {}
361     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
362
363     void reset() {
364       buffer = nullptr;
365       length = 0;
366     }
367     void allocate(size_t len) {
368       assert(buffer == nullptr);
369       this->buffer = static_cast<char*>(malloc(len));
370       this->length = len;
371     }
372     void free() {
373       ::free(buffer);
374       reset();
375     }
376
377     char* buffer;
378     size_t length;
379   };
380
381   std::vector<Buffer> buffers;
382   Buffer currentBuffer;
383 };
384
385 class ReadErrorCallback : public ReadCallbackBase {
386 public:
387   explicit ReadErrorCallback(WriteCallbackBase *wcb)
388       : ReadCallbackBase(wcb) {}
389
390   // Return nullptr buffer to trigger readError()
391   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392     *bufReturn = nullptr;
393     *lenReturn = 0;
394   }
395
396   void readDataAvailable(size_t /* len */) noexcept override {
397     // This should never to called.
398     FAIL();
399   }
400
401   void readErr(
402     const AsyncSocketException& ex) noexcept override {
403     ReadCallbackBase::readErr(ex);
404     std::cerr << "ReadErrorCallback::readError" << std::endl;
405     setState(STATE_SUCCEEDED);
406   }
407 };
408
409 class ReadEOFCallback : public ReadCallbackBase {
410  public:
411   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
412
413   // Return nullptr buffer to trigger readError()
414   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415     *bufReturn = nullptr;
416     *lenReturn = 0;
417   }
418
419   void readDataAvailable(size_t /* len */) noexcept override {
420     // This should never to called.
421     FAIL();
422   }
423
424   void readEOF() noexcept override {
425     ReadCallbackBase::readEOF();
426     setState(STATE_SUCCEEDED);
427   }
428 };
429
430 class WriteErrorCallback : public ReadCallback {
431 public:
432   explicit WriteErrorCallback(WriteCallbackBase *wcb)
433       : ReadCallback(wcb) {}
434
435   void readDataAvailable(size_t len) noexcept override {
436     std::cerr << "readDataAvailable, len " << len << std::endl;
437
438     currentBuffer.length = len;
439
440     // close the socket before writing to trigger writeError().
441     ::close(socket_->getFd());
442
443     wcb_->setSocket(socket_);
444
445     // Write back the same data.
446     folly::test::msvcSuppressAbortOnInvalidParams([&] {
447       socket_->write(wcb_, currentBuffer.buffer, len);
448     });
449
450     if (wcb_->state == STATE_FAILED) {
451       setState(STATE_SUCCEEDED);
452     } else {
453       state = STATE_FAILED;
454     }
455
456     buffers.push_back(currentBuffer);
457     currentBuffer.reset();
458   }
459
460   void readErr(const AsyncSocketException& ex) noexcept override {
461     std::cerr << "readError " << ex.what() << std::endl;
462     // do nothing since this is expected
463   }
464 };
465
466 class EmptyReadCallback : public ReadCallback {
467 public:
468   explicit EmptyReadCallback()
469       : ReadCallback(nullptr) {}
470
471   void readErr(const AsyncSocketException& ex) noexcept override {
472     std::cerr << "readError " << ex.what() << std::endl;
473     state = STATE_FAILED;
474     if (tcpSocket_) {
475       tcpSocket_->close();
476     }
477   }
478
479   void readEOF() noexcept override {
480     std::cerr << "readEOF" << std::endl;
481     if (tcpSocket_) {
482       tcpSocket_->close();
483     }
484     state = STATE_SUCCEEDED;
485   }
486
487   std::shared_ptr<AsyncSocket> tcpSocket_;
488 };
489
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
492 public:
493   enum ExpectType {
494     EXPECT_SUCCESS,
495     EXPECT_ERROR
496   };
497
498   explicit HandshakeCallback(ReadCallbackBase *rcb,
499                              ExpectType expect = EXPECT_SUCCESS):
500       state(STATE_WAITING),
501       rcb_(rcb),
502       expect_(expect) {}
503
504   void setSocket(
505     const std::shared_ptr<AsyncSSLSocket> &socket) {
506     socket_ = socket;
507   }
508
509   void setState(StateEnum s) {
510     state = s;
511     rcb_->setState(s);
512   }
513
514   // Functions inherited from AsyncSSLSocketHandshakeCallback
515   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516     std::lock_guard<std::mutex> g(mutex_);
517     cv_.notify_all();
518     EXPECT_EQ(sock, socket_.get());
519     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520     rcb_->setSocket(socket_);
521     sock->setReadCB(rcb_);
522     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
523   }
524   void handshakeErr(AsyncSSLSocket* /* sock */,
525                     const AsyncSocketException& ex) noexcept override {
526     std::lock_guard<std::mutex> g(mutex_);
527     cv_.notify_all();
528     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530     if (expect_ == EXPECT_ERROR) {
531       // rcb will never be invoked
532       rcb_->setState(STATE_SUCCEEDED);
533     }
534     errorString_ = ex.what();
535   }
536
537   void waitForHandshake() {
538     std::unique_lock<std::mutex> lock(mutex_);
539     cv_.wait(lock, [this] { return state != STATE_WAITING; });
540   }
541
542   ~HandshakeCallback() {
543     EXPECT_EQ(STATE_SUCCEEDED, state);
544   }
545
546   void closeSocket() {
547     socket_->close();
548     state = STATE_SUCCEEDED;
549   }
550
551   std::shared_ptr<AsyncSSLSocket> getSocket() {
552     return socket_;
553   }
554
555   StateEnum state;
556   std::shared_ptr<AsyncSSLSocket> socket_;
557   ReadCallbackBase *rcb_;
558   ExpectType expect_;
559   std::mutex mutex_;
560   std::condition_variable cv_;
561   std::string errorString_;
562 };
563
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
565 public:
566   uint32_t timeout_;
567
568   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569                                    uint32_t timeout = 0):
570       SSLServerAcceptCallbackBase(hcb),
571       timeout_(timeout) {}
572
573   virtual ~SSLServerAcceptCallback() {
574     if (timeout_ > 0) {
575       // if we set a timeout, we expect failure
576       EXPECT_EQ(hcb_->state, STATE_FAILED);
577       hcb_->setState(STATE_SUCCEEDED);
578     }
579   }
580
581   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
582   void connAccepted(
583     const std::shared_ptr<folly::AsyncSSLSocket> &s)
584     noexcept override {
585     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
587
588     hcb_->setSocket(sock);
589     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590     EXPECT_EQ(sock->getSSLState(),
591                       AsyncSSLSocket::STATE_ACCEPTING);
592
593     state = STATE_SUCCEEDED;
594   }
595 };
596
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
598 public:
599   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600       SSLServerAcceptCallback(hcb) {}
601
602   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
603   void connAccepted(
604     const std::shared_ptr<folly::AsyncSSLSocket> &s)
605     noexcept override {
606
607     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
608
609     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
610               << std::endl;
611     int fd = sock->getFd();
612
613 #ifndef TCP_NOPUSH
614     {
615     // The accepted connection should already have TCP_NODELAY set
616     int value;
617     socklen_t valueLength = sizeof(value);
618     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
619     EXPECT_EQ(rc, 0);
620     EXPECT_EQ(value, 1);
621     }
622 #endif
623
624     // Unset the TCP_NODELAY option.
625     int value = 0;
626     socklen_t valueLength = sizeof(value);
627     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
628     EXPECT_EQ(rc, 0);
629
630     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
631     EXPECT_EQ(rc, 0);
632     EXPECT_EQ(value, 0);
633
634     SSLServerAcceptCallback::connAccepted(sock);
635   }
636 };
637
638 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
639 public:
640   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
641                                              uint32_t timeout = 0):
642     SSLServerAcceptCallback(hcb, timeout) {}
643
644   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
645   void connAccepted(
646     const std::shared_ptr<folly::AsyncSSLSocket> &s)
647     noexcept override {
648     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
649
650     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
651
652     hcb_->setSocket(sock);
653     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
654     ASSERT_TRUE((sock->getSSLState() ==
655                  AsyncSSLSocket::STATE_ACCEPTING) ||
656                 (sock->getSSLState() ==
657                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
658
659     state = STATE_SUCCEEDED;
660   }
661 };
662
663
664 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
665 public:
666   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
667   SSLServerAcceptCallbackBase(hcb)  {}
668
669   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
670   void connAccepted(
671     const std::shared_ptr<folly::AsyncSSLSocket> &s)
672     noexcept override {
673     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
674
675     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
676
677     // The first call to sslAccept() should succeed.
678     hcb_->setSocket(sock);
679     sock->sslAccept(hcb_);
680     EXPECT_EQ(sock->getSSLState(),
681                       AsyncSSLSocket::STATE_ACCEPTING);
682
683     // The second call to sslAccept() should fail.
684     HandshakeCallback callback2(hcb_->rcb_);
685     callback2.setSocket(sock);
686     sock->sslAccept(&callback2);
687     EXPECT_EQ(sock->getSSLState(),
688                       AsyncSSLSocket::STATE_ERROR);
689
690     // Both callbacks should be in the error state.
691     EXPECT_EQ(hcb_->state, STATE_FAILED);
692     EXPECT_EQ(callback2.state, STATE_FAILED);
693
694     state = STATE_SUCCEEDED;
695     hcb_->setState(STATE_SUCCEEDED);
696     callback2.setState(STATE_SUCCEEDED);
697   }
698 };
699
700 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
701 public:
702   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
703   SSLServerAcceptCallbackBase(hcb)  {}
704
705   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
706   void connAccepted(
707     const std::shared_ptr<folly::AsyncSSLSocket> &s)
708     noexcept override {
709     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
710
711     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
712
713     hcb_->setSocket(sock);
714     sock->getEventBase()->tryRunAfterDelay([=] {
715         std::cerr << "Delayed SSL accept, client will have close by now"
716                   << std::endl;
717         // SSL accept will fail
718         EXPECT_EQ(
719           sock->getSSLState(),
720           AsyncSSLSocket::STATE_UNINIT);
721         hcb_->socket_->sslAccept(hcb_);
722         // This registers for an event
723         EXPECT_EQ(
724           sock->getSSLState(),
725           AsyncSSLSocket::STATE_ACCEPTING);
726
727         state = STATE_SUCCEEDED;
728       }, 100);
729   }
730 };
731
732 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
733  public:
734   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
735     // We don't care if we get invoked or not.
736     // The client may time out and give up before connAccepted() is even
737     // called.
738     state = STATE_SUCCEEDED;
739   }
740
741   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
742   void connAccepted(
743       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
744     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
745
746     // Just wait a while before closing the socket, so the client
747     // will time out waiting for the handshake to complete.
748     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
749   }
750 };
751
752 class TestSSLAsyncCacheServer : public TestSSLServer {
753  public:
754   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
755         int lookupDelay = 100) :
756       TestSSLServer(acb) {
757     SSL_CTX *sslCtx = ctx_->getSSLCtx();
758 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
759     SSL_CTX_sess_set_get_cb(sslCtx,
760                             TestSSLAsyncCacheServer::getSessionCallback);
761 #endif
762     SSL_CTX_set_session_cache_mode(
763       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
764     asyncCallbacks_ = 0;
765     asyncLookups_ = 0;
766     lookupDelay_ = lookupDelay;
767   }
768
769   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
770   uint32_t getAsyncLookups() const { return asyncLookups_; }
771
772  private:
773   static uint32_t asyncCallbacks_;
774   static uint32_t asyncLookups_;
775   static uint32_t lookupDelay_;
776
777   static SSL_SESSION* getSessionCallback(SSL* ssl,
778                                          unsigned char* /* sess_id */,
779                                          int /* id_len */,
780                                          int* copyflag) {
781     *copyflag = 0;
782     asyncCallbacks_++;
783     (void)ssl;
784 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
785     if (!SSL_want_sess_cache_lookup(ssl)) {
786       // libssl.so mismatch
787       std::cerr << "no async support" << std::endl;
788       return nullptr;
789     }
790
791     AsyncSSLSocket *sslSocket =
792         AsyncSSLSocket::getFromSSL(ssl);
793     assert(sslSocket != nullptr);
794     // Going to simulate an async cache by just running delaying the miss 100ms
795     if (asyncCallbacks_ % 2 == 0) {
796       // This socket is already blocked on lookup, return miss
797       std::cerr << "returning miss" << std::endl;
798     } else {
799       // fresh meat - block it
800       std::cerr << "async lookup" << std::endl;
801       sslSocket->getEventBase()->tryRunAfterDelay(
802         std::bind(&AsyncSSLSocket::restartSSLAccept,
803                   sslSocket), lookupDelay_);
804       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
805       asyncLookups_++;
806     }
807 #endif
808     return nullptr;
809   }
810 };
811
812 void getfds(int fds[2]);
813
814 void getctx(
815   std::shared_ptr<folly::SSLContext> clientCtx,
816   std::shared_ptr<folly::SSLContext> serverCtx);
817
818 void sslsocketpair(
819   EventBase* eventBase,
820   AsyncSSLSocket::UniquePtr* clientSock,
821   AsyncSSLSocket::UniquePtr* serverSock);
822
823 class BlockingWriteClient :
824   private AsyncSSLSocket::HandshakeCB,
825   private AsyncTransportWrapper::WriteCallback {
826  public:
827   explicit BlockingWriteClient(
828     AsyncSSLSocket::UniquePtr socket)
829     : socket_(std::move(socket)),
830       bufLen_(2500),
831       iovCount_(2000) {
832     // Fill buf_
833     buf_.reset(new uint8_t[bufLen_]);
834     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
835       buf_[n] = n % 0xff;
836     }
837
838     // Initialize iov_
839     iov_.reset(new struct iovec[iovCount_]);
840     for (uint32_t n = 0; n < iovCount_; ++n) {
841       iov_[n].iov_base = buf_.get() + n;
842       if (n & 0x1) {
843         iov_[n].iov_len = n % bufLen_;
844       } else {
845         iov_[n].iov_len = bufLen_ - (n % bufLen_);
846       }
847     }
848
849     socket_->sslConn(this, std::chrono::milliseconds(100));
850   }
851
852   struct iovec* getIovec() const {
853     return iov_.get();
854   }
855   uint32_t getIovecCount() const {
856     return iovCount_;
857   }
858
859  private:
860   void handshakeSuc(AsyncSSLSocket*) noexcept override {
861     socket_->writev(this, iov_.get(), iovCount_);
862   }
863   void handshakeErr(
864     AsyncSSLSocket*,
865     const AsyncSocketException& ex) noexcept override {
866     ADD_FAILURE() << "client handshake error: " << ex.what();
867   }
868   void writeSuccess() noexcept override {
869     socket_->close();
870   }
871   void writeErr(
872     size_t bytesWritten,
873     const AsyncSocketException& ex) noexcept override {
874     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
875                   << ex.what();
876   }
877
878   AsyncSSLSocket::UniquePtr socket_;
879   uint32_t bufLen_;
880   uint32_t iovCount_;
881   std::unique_ptr<uint8_t[]> buf_;
882   std::unique_ptr<struct iovec[]> iov_;
883 };
884
885 class BlockingWriteServer :
886     private AsyncSSLSocket::HandshakeCB,
887     private AsyncTransportWrapper::ReadCallback {
888  public:
889   explicit BlockingWriteServer(
890     AsyncSSLSocket::UniquePtr socket)
891     : socket_(std::move(socket)),
892       bufSize_(2500 * 2000),
893       bytesRead_(0) {
894     buf_.reset(new uint8_t[bufSize_]);
895     socket_->sslAccept(this, std::chrono::milliseconds(100));
896   }
897
898   void checkBuffer(struct iovec* iov, uint32_t count) const {
899     uint32_t idx = 0;
900     for (uint32_t n = 0; n < count; ++n) {
901       size_t bytesLeft = bytesRead_ - idx;
902       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
903                       std::min(iov[n].iov_len, bytesLeft));
904       if (rc != 0) {
905         FAIL() << "buffer mismatch at iovec " << n << "/" << count
906                << ": rc=" << rc;
907
908       }
909       if (iov[n].iov_len > bytesLeft) {
910         FAIL() << "server did not read enough data: "
911                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
912                << " in iovec " << n << "/" << count;
913       }
914
915       idx += iov[n].iov_len;
916     }
917     if (idx != bytesRead_) {
918       ADD_FAILURE() << "server read extra data: " << bytesRead_
919                     << " bytes read; expected " << idx;
920     }
921   }
922
923  private:
924   void handshakeSuc(AsyncSSLSocket*) noexcept override {
925     // Wait 10ms before reading, so the client's writes will initially block.
926     socket_->getEventBase()->tryRunAfterDelay(
927         [this] { socket_->setReadCB(this); }, 10);
928   }
929   void handshakeErr(
930     AsyncSSLSocket*,
931     const AsyncSocketException& ex) noexcept override {
932     ADD_FAILURE() << "server handshake error: " << ex.what();
933   }
934   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
935     *bufReturn = buf_.get() + bytesRead_;
936     *lenReturn = bufSize_ - bytesRead_;
937   }
938   void readDataAvailable(size_t len) noexcept override {
939     bytesRead_ += len;
940     socket_->setReadCB(nullptr);
941     socket_->getEventBase()->tryRunAfterDelay(
942         [this] { socket_->setReadCB(this); }, 2);
943   }
944   void readEOF() noexcept override {
945     socket_->close();
946   }
947   void readErr(
948     const AsyncSocketException& ex) noexcept override {
949     ADD_FAILURE() << "server read error: " << ex.what();
950   }
951
952   AsyncSSLSocket::UniquePtr socket_;
953   uint32_t bufSize_;
954   uint32_t bytesRead_;
955   std::unique_ptr<uint8_t[]> buf_;
956 };
957
958 class NpnClient :
959   private AsyncSSLSocket::HandshakeCB,
960   private AsyncTransportWrapper::WriteCallback {
961  public:
962   explicit NpnClient(
963     AsyncSSLSocket::UniquePtr socket)
964       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
965     socket_->sslConn(this);
966   }
967
968   const unsigned char* nextProto;
969   unsigned nextProtoLength;
970   SSLContext::NextProtocolType protocolType;
971
972  private:
973   void handshakeSuc(AsyncSSLSocket*) noexcept override {
974     socket_->getSelectedNextProtocol(
975         &nextProto, &nextProtoLength, &protocolType);
976   }
977   void handshakeErr(
978     AsyncSSLSocket*,
979     const AsyncSocketException& ex) noexcept override {
980     ADD_FAILURE() << "client handshake error: " << ex.what();
981   }
982   void writeSuccess() noexcept override {
983     socket_->close();
984   }
985   void writeErr(
986     size_t bytesWritten,
987     const AsyncSocketException& ex) noexcept override {
988     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
989                   << ex.what();
990   }
991
992   AsyncSSLSocket::UniquePtr socket_;
993 };
994
995 class NpnServer :
996     private AsyncSSLSocket::HandshakeCB,
997     private AsyncTransportWrapper::ReadCallback {
998  public:
999   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
1000       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
1001     socket_->sslAccept(this);
1002   }
1003
1004   const unsigned char* nextProto;
1005   unsigned nextProtoLength;
1006   SSLContext::NextProtocolType protocolType;
1007
1008  private:
1009   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1010     socket_->getSelectedNextProtocol(
1011         &nextProto, &nextProtoLength, &protocolType);
1012   }
1013   void handshakeErr(
1014     AsyncSSLSocket*,
1015     const AsyncSocketException& ex) noexcept override {
1016     ADD_FAILURE() << "server handshake error: " << ex.what();
1017   }
1018   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1019     *lenReturn = 0;
1020   }
1021   void readDataAvailable(size_t /* len */) noexcept override {}
1022   void readEOF() noexcept override {
1023     socket_->close();
1024   }
1025   void readErr(
1026     const AsyncSocketException& ex) noexcept override {
1027     ADD_FAILURE() << "server read error: " << ex.what();
1028   }
1029
1030   AsyncSSLSocket::UniquePtr socket_;
1031 };
1032
1033 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1034                             public AsyncTransportWrapper::ReadCallback {
1035  public:
1036   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1037       : socket_(std::move(socket)) {
1038     socket_->sslAccept(this);
1039   }
1040
1041   ~RenegotiatingServer() {
1042     socket_->setReadCB(nullptr);
1043   }
1044
1045   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1046     LOG(INFO) << "Renegotiating server handshake success";
1047     socket_->setReadCB(this);
1048   }
1049   void handshakeErr(
1050       AsyncSSLSocket*,
1051       const AsyncSocketException& ex) noexcept override {
1052     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1053   }
1054   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1055     *lenReturn = sizeof(buf);
1056     *bufReturn = buf;
1057   }
1058   void readDataAvailable(size_t /* len */) noexcept override {}
1059   void readEOF() noexcept override {}
1060   void readErr(const AsyncSocketException& ex) noexcept override {
1061     LOG(INFO) << "server got read error " << ex.what();
1062     auto exPtr = dynamic_cast<const SSLException*>(&ex);
1063     ASSERT_NE(nullptr, exPtr);
1064     std::string exStr(ex.what());
1065     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1066     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1067     renegotiationError_ = true;
1068   }
1069
1070   AsyncSSLSocket::UniquePtr socket_;
1071   unsigned char buf[128];
1072   bool renegotiationError_{false};
1073 };
1074
1075 #ifndef OPENSSL_NO_TLSEXT
1076 class SNIClient :
1077   private AsyncSSLSocket::HandshakeCB,
1078   private AsyncTransportWrapper::WriteCallback {
1079  public:
1080   explicit SNIClient(
1081     AsyncSSLSocket::UniquePtr socket)
1082       : serverNameMatch(false), socket_(std::move(socket)) {
1083     socket_->sslConn(this);
1084   }
1085
1086   bool serverNameMatch;
1087
1088  private:
1089   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1090     serverNameMatch = socket_->isServerNameMatch();
1091   }
1092   void handshakeErr(
1093     AsyncSSLSocket*,
1094     const AsyncSocketException& ex) noexcept override {
1095     ADD_FAILURE() << "client handshake error: " << ex.what();
1096   }
1097   void writeSuccess() noexcept override {
1098     socket_->close();
1099   }
1100   void writeErr(
1101     size_t bytesWritten,
1102     const AsyncSocketException& ex) noexcept override {
1103     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1104                   << ex.what();
1105   }
1106
1107   AsyncSSLSocket::UniquePtr socket_;
1108 };
1109
1110 class SNIServer :
1111     private AsyncSSLSocket::HandshakeCB,
1112     private AsyncTransportWrapper::ReadCallback {
1113  public:
1114   explicit SNIServer(
1115     AsyncSSLSocket::UniquePtr socket,
1116     const std::shared_ptr<folly::SSLContext>& ctx,
1117     const std::shared_ptr<folly::SSLContext>& sniCtx,
1118     const std::string& expectedServerName)
1119       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1120         expectedServerName_(expectedServerName) {
1121     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1122                                          std::placeholders::_1));
1123     socket_->sslAccept(this);
1124   }
1125
1126   bool serverNameMatch;
1127
1128  private:
1129   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1130   void handshakeErr(
1131     AsyncSSLSocket*,
1132     const AsyncSocketException& ex) noexcept override {
1133     ADD_FAILURE() << "server handshake error: " << ex.what();
1134   }
1135   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1136     *lenReturn = 0;
1137   }
1138   void readDataAvailable(size_t /* len */) noexcept override {}
1139   void readEOF() noexcept override {
1140     socket_->close();
1141   }
1142   void readErr(
1143     const AsyncSocketException& ex) noexcept override {
1144     ADD_FAILURE() << "server read error: " << ex.what();
1145   }
1146
1147   folly::SSLContext::ServerNameCallbackResult
1148     serverNameCallback(SSL *ssl) {
1149     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1150     if (sniCtx_ &&
1151         sn &&
1152         !strcasecmp(expectedServerName_.c_str(), sn)) {
1153       AsyncSSLSocket *sslSocket =
1154           AsyncSSLSocket::getFromSSL(ssl);
1155       sslSocket->switchServerSSLContext(sniCtx_);
1156       serverNameMatch = true;
1157       return folly::SSLContext::SERVER_NAME_FOUND;
1158     } else {
1159       serverNameMatch = false;
1160       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1161     }
1162   }
1163
1164   AsyncSSLSocket::UniquePtr socket_;
1165   std::shared_ptr<folly::SSLContext> sniCtx_;
1166   std::string expectedServerName_;
1167 };
1168 #endif
1169
1170 class SSLClient : public AsyncSocket::ConnectCallback,
1171                   public AsyncTransportWrapper::WriteCallback,
1172                   public AsyncTransportWrapper::ReadCallback
1173 {
1174  private:
1175   EventBase *eventBase_;
1176   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1177   SSL_SESSION *session_;
1178   std::shared_ptr<folly::SSLContext> ctx_;
1179   uint32_t requests_;
1180   folly::SocketAddress address_;
1181   uint32_t timeout_;
1182   char buf_[128];
1183   char readbuf_[128];
1184   uint32_t bytesRead_;
1185   uint32_t hit_;
1186   uint32_t miss_;
1187   uint32_t errors_;
1188   uint32_t writeAfterConnectErrors_;
1189
1190   // These settings test that we eventually drain the
1191   // socket, even if the maxReadsPerEvent_ is hit during
1192   // a event loop iteration.
1193   static constexpr size_t kMaxReadsPerEvent = 2;
1194   // 2 event loop iterations
1195   static constexpr size_t kMaxReadBufferSz =
1196     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1197
1198  public:
1199   SSLClient(EventBase *eventBase,
1200             const folly::SocketAddress& address,
1201             uint32_t requests,
1202             uint32_t timeout = 0)
1203       : eventBase_(eventBase),
1204         session_(nullptr),
1205         requests_(requests),
1206         address_(address),
1207         timeout_(timeout),
1208         bytesRead_(0),
1209         hit_(0),
1210         miss_(0),
1211         errors_(0),
1212         writeAfterConnectErrors_(0) {
1213     ctx_.reset(new folly::SSLContext());
1214     ctx_->setOptions(SSL_OP_NO_TICKET);
1215     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1216     memset(buf_, 'a', sizeof(buf_));
1217   }
1218
1219   ~SSLClient() {
1220     if (session_) {
1221       SSL_SESSION_free(session_);
1222     }
1223     if (errors_ == 0) {
1224       EXPECT_EQ(bytesRead_, sizeof(buf_));
1225     }
1226   }
1227
1228   uint32_t getHit() const { return hit_; }
1229
1230   uint32_t getMiss() const { return miss_; }
1231
1232   uint32_t getErrors() const { return errors_; }
1233
1234   uint32_t getWriteAfterConnectErrors() const {
1235     return writeAfterConnectErrors_;
1236   }
1237
1238   void connect(bool writeNow = false) {
1239     sslSocket_ = AsyncSSLSocket::newSocket(
1240       ctx_, eventBase_);
1241     if (session_ != nullptr) {
1242       sslSocket_->setSSLSession(session_);
1243     }
1244     requests_--;
1245     sslSocket_->connect(this, address_, timeout_);
1246     if (sslSocket_ && writeNow) {
1247       // write some junk, used in an error test
1248       sslSocket_->write(this, buf_, sizeof(buf_));
1249     }
1250   }
1251
1252   void connectSuccess() noexcept override {
1253     std::cerr << "client SSL socket connected" << std::endl;
1254     if (sslSocket_->getSSLSessionReused()) {
1255       hit_++;
1256     } else {
1257       miss_++;
1258       if (session_ != nullptr) {
1259         SSL_SESSION_free(session_);
1260       }
1261       session_ = sslSocket_->getSSLSession();
1262     }
1263
1264     // write()
1265     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1266     sslSocket_->write(this, buf_, sizeof(buf_));
1267     sslSocket_->setReadCB(this);
1268     memset(readbuf_, 'b', sizeof(readbuf_));
1269     bytesRead_ = 0;
1270   }
1271
1272   void connectErr(
1273     const AsyncSocketException& ex) noexcept override {
1274     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1275     errors_++;
1276     sslSocket_.reset();
1277   }
1278
1279   void writeSuccess() noexcept override {
1280     std::cerr << "client write success" << std::endl;
1281   }
1282
1283   void writeErr(size_t /* bytesWritten */,
1284                 const AsyncSocketException& ex) noexcept override {
1285     std::cerr << "client writeError: " << ex.what() << std::endl;
1286     if (!sslSocket_) {
1287       writeAfterConnectErrors_++;
1288     }
1289   }
1290
1291   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1292     *bufReturn = readbuf_ + bytesRead_;
1293     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1294   }
1295
1296   void readEOF() noexcept override {
1297     std::cerr << "client readEOF" << std::endl;
1298   }
1299
1300   void readErr(
1301     const AsyncSocketException& ex) noexcept override {
1302     std::cerr << "client readError: " << ex.what() << std::endl;
1303   }
1304
1305   void readDataAvailable(size_t len) noexcept override {
1306     std::cerr << "client read data: " << len << std::endl;
1307     bytesRead_ += len;
1308     if (bytesRead_ == sizeof(buf_)) {
1309       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1310       sslSocket_->closeNow();
1311       sslSocket_.reset();
1312       if (requests_ != 0) {
1313         connect();
1314       }
1315     }
1316   }
1317
1318 };
1319
1320 class SSLHandshakeBase :
1321   public AsyncSSLSocket::HandshakeCB,
1322   private AsyncTransportWrapper::WriteCallback {
1323  public:
1324   explicit SSLHandshakeBase(
1325    AsyncSSLSocket::UniquePtr socket,
1326    bool preverifyResult,
1327    bool verifyResult) :
1328     handshakeVerify_(false),
1329     handshakeSuccess_(false),
1330     handshakeError_(false),
1331     socket_(std::move(socket)),
1332     preverifyResult_(preverifyResult),
1333     verifyResult_(verifyResult) {
1334   }
1335
1336   AsyncSSLSocket::UniquePtr moveSocket() && {
1337     return std::move(socket_);
1338   }
1339
1340   bool handshakeVerify_;
1341   bool handshakeSuccess_;
1342   bool handshakeError_;
1343   std::chrono::nanoseconds handshakeTime;
1344
1345  protected:
1346   AsyncSSLSocket::UniquePtr socket_;
1347   bool preverifyResult_;
1348   bool verifyResult_;
1349
1350   // HandshakeCallback
1351   bool handshakeVer(AsyncSSLSocket* /* sock */,
1352                     bool preverifyOk,
1353                     X509_STORE_CTX* /* ctx */) noexcept override {
1354     handshakeVerify_ = true;
1355
1356     EXPECT_EQ(preverifyResult_, preverifyOk);
1357     return verifyResult_;
1358   }
1359
1360   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1361     LOG(INFO) << "Handshake success";
1362     handshakeSuccess_ = true;
1363     handshakeTime = socket_->getHandshakeTime();
1364   }
1365
1366   void handshakeErr(
1367       AsyncSSLSocket*,
1368       const AsyncSocketException& ex) noexcept override {
1369     LOG(INFO) << "Handshake error " << ex.what();
1370     handshakeError_ = true;
1371     handshakeTime = socket_->getHandshakeTime();
1372   }
1373
1374   // WriteCallback
1375   void writeSuccess() noexcept override {
1376     socket_->close();
1377   }
1378
1379   void writeErr(
1380    size_t bytesWritten,
1381    const AsyncSocketException& ex) noexcept override {
1382     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1383                   << ex.what();
1384   }
1385 };
1386
1387 class SSLHandshakeClient : public SSLHandshakeBase {
1388  public:
1389   SSLHandshakeClient(
1390    AsyncSSLSocket::UniquePtr socket,
1391    bool preverifyResult,
1392    bool verifyResult) :
1393     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1394     socket_->sslConn(this, std::chrono::milliseconds::zero());
1395   }
1396 };
1397
1398 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1399  public:
1400   SSLHandshakeClientNoVerify(
1401    AsyncSSLSocket::UniquePtr socket,
1402    bool preverifyResult,
1403    bool verifyResult) :
1404     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1405     socket_->sslConn(
1406         this,
1407         std::chrono::milliseconds::zero(),
1408         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1409   }
1410 };
1411
1412 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1413  public:
1414   SSLHandshakeClientDoVerify(
1415    AsyncSSLSocket::UniquePtr socket,
1416    bool preverifyResult,
1417    bool verifyResult) :
1418     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1419     socket_->sslConn(
1420         this,
1421         std::chrono::milliseconds::zero(),
1422         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1423   }
1424 };
1425
1426 class SSLHandshakeServer : public SSLHandshakeBase {
1427  public:
1428   SSLHandshakeServer(
1429       AsyncSSLSocket::UniquePtr socket,
1430       bool preverifyResult,
1431       bool verifyResult)
1432     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1433     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1434   }
1435 };
1436
1437 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1438  public:
1439   SSLHandshakeServerParseClientHello(
1440       AsyncSSLSocket::UniquePtr socket,
1441       bool preverifyResult,
1442       bool verifyResult)
1443       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1444     socket_->enableClientHelloParsing();
1445     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1446   }
1447
1448   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1449
1450  protected:
1451   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1452     handshakeSuccess_ = true;
1453     sock->getSSLSharedCiphers(sharedCiphers_);
1454     sock->getSSLServerCiphers(serverCiphers_);
1455     sock->getSSLClientCiphers(clientCiphers_);
1456     chosenCipher_ = sock->getNegotiatedCipherName();
1457   }
1458 };
1459
1460
1461 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1462  public:
1463   SSLHandshakeServerNoVerify(
1464       AsyncSSLSocket::UniquePtr socket,
1465       bool preverifyResult,
1466       bool verifyResult)
1467     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1468     socket_->sslAccept(
1469         this,
1470         std::chrono::milliseconds::zero(),
1471         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1472   }
1473 };
1474
1475 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1476  public:
1477   SSLHandshakeServerDoVerify(
1478       AsyncSSLSocket::UniquePtr socket,
1479       bool preverifyResult,
1480       bool verifyResult)
1481     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1482     socket_->sslAccept(
1483         this,
1484         std::chrono::milliseconds::zero(),
1485         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1486   }
1487 };
1488
1489 class EventBaseAborter : public AsyncTimeout {
1490  public:
1491   EventBaseAborter(EventBase* eventBase,
1492                    uint32_t timeoutMS)
1493     : AsyncTimeout(
1494       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1495     , eventBase_(eventBase) {
1496     scheduleTimeout(timeoutMS);
1497   }
1498
1499   void timeoutExpired() noexcept override {
1500     FAIL() << "test timed out";
1501     eventBase_->terminateLoopSoon();
1502   }
1503
1504  private:
1505   EventBase* eventBase_;
1506 };
1507
1508 }