Fix copyright lines
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2012-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19
20 #include <folly/ExceptionWrapper.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/experimental/TestUtil.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/io/async/test/TestSSLServer.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/PThread.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <condition_variable>
39 #include <iostream>
40 #include <list>
41 #include <memory>
42
43 namespace folly {
44
45 // The destructors of all callback classes assert that the state is
46 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
47 // are responsible for setting the succeeded state properly before the
48 // destructors are called.
49
50 class SendMsgParamsCallbackBase :
51       public folly::AsyncSocket::SendMsgParamsCallback {
52  public:
53   SendMsgParamsCallbackBase() {}
54
55   void setSocket(
56     const std::shared_ptr<AsyncSSLSocket> &socket) {
57     socket_ = socket;
58     oldCallback_ = socket_->getSendMsgParamsCB();
59     socket_->setSendMsgParamCB(this);
60   }
61
62   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
63                                                                      override {
64     return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
65   }
66
67   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
68     oldCallback_->getAncillaryData(flags, data);
69   }
70
71   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
72     return oldCallback_->getAncillaryDataSize(flags);
73   }
74
75   std::shared_ptr<AsyncSSLSocket> socket_;
76   folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
77 };
78
79 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
80  public:
81   SendMsgFlagsCallback() {}
82
83   void resetFlags(int flags) {
84     flags_ = flags;
85   }
86
87   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
88                                                                   override {
89     if (flags_) {
90       return flags_;
91     } else {
92       return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
93     }
94   }
95
96   int flags_{0};
97 };
98
99 class SendMsgDataCallback : public SendMsgFlagsCallback {
100  public:
101   SendMsgDataCallback() {}
102
103   void resetData(std::vector<char>&& data) {
104     ancillaryData_.swap(data);
105   }
106
107   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
108     if (ancillaryData_.size()) {
109       std::cerr << "getAncillaryData: copying data" << std::endl;
110       memcpy(data, ancillaryData_.data(), ancillaryData_.size());
111     } else {
112       oldCallback_->getAncillaryData(flags, data);
113     }
114   }
115
116   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
117     if (ancillaryData_.size()) {
118       std::cerr << "getAncillaryDataSize: returning size" << std::endl;
119       return ancillaryData_.size();
120     } else {
121       return oldCallback_->getAncillaryDataSize(flags);
122     }
123   }
124
125   std::vector<char> ancillaryData_;
126 };
127
128 class WriteCallbackBase :
129 public AsyncTransportWrapper::WriteCallback {
130  public:
131   explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
132       : state(STATE_WAITING)
133       , bytesWritten(0)
134       , exception(AsyncSocketException::UNKNOWN, "none")
135       , mcb_(mcb) {}
136
137   ~WriteCallbackBase() override {
138     EXPECT_EQ(STATE_SUCCEEDED, state);
139   }
140
141   virtual void setSocket(
142     const std::shared_ptr<AsyncSSLSocket> &socket) {
143     socket_ = socket;
144     if (mcb_) {
145       mcb_->setSocket(socket);
146     }
147   }
148
149   void writeSuccess() noexcept override {
150     std::cerr << "writeSuccess" << std::endl;
151     state = STATE_SUCCEEDED;
152   }
153
154   void writeErr(
155     size_t nBytesWritten,
156     const AsyncSocketException& ex) noexcept override {
157     std::cerr << "writeError: bytesWritten " << nBytesWritten
158          << ", exception " << ex.what() << std::endl;
159
160     state = STATE_FAILED;
161     this->bytesWritten = nBytesWritten;
162     exception = ex;
163     socket_->close();
164   }
165
166   std::shared_ptr<AsyncSSLSocket> socket_;
167   StateEnum state;
168   size_t bytesWritten;
169   AsyncSocketException exception;
170   SendMsgParamsCallbackBase* mcb_;
171 };
172
173 class ExpectWriteErrorCallback :
174 public WriteCallbackBase {
175  public:
176   explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
177       : WriteCallbackBase(mcb) {}
178
179   ~ExpectWriteErrorCallback() override {
180     EXPECT_EQ(STATE_FAILED, state);
181     EXPECT_EQ(exception.type_,
182              AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
183     EXPECT_EQ(exception.errno_, 22);
184     // Suppress the assert in  ~WriteCallbackBase()
185     state = STATE_SUCCEEDED;
186   }
187 };
188
189 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
190 /* copied from include/uapi/linux/net_tstamp.h */
191 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
192 enum SOF_TIMESTAMPING {
193   SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
194   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
195   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
196   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
197   SOF_TIMESTAMPING_TX_ACK = (1 << 9),
198   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
199 };
200
201 class WriteCheckTimestampCallback :
202   public WriteCallbackBase {
203  public:
204   explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
205     : WriteCallbackBase(mcb) {}
206
207   ~WriteCheckTimestampCallback() override {
208     EXPECT_EQ(STATE_SUCCEEDED, state);
209     EXPECT_TRUE(gotTimestamp_);
210     EXPECT_TRUE(gotByteSeq_);
211   }
212
213   void setSocket(
214     const std::shared_ptr<AsyncSSLSocket> &socket) override {
215     WriteCallbackBase::setSocket(socket);
216
217     EXPECT_NE(socket_->getFd(), 0);
218     int flags = SOF_TIMESTAMPING_OPT_ID
219                 | SOF_TIMESTAMPING_OPT_TSONLY
220                 | SOF_TIMESTAMPING_SOFTWARE;
221     AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
222     int ret = tstampingOpt.apply(socket_->getFd(), flags);
223     EXPECT_EQ(ret, 0);
224   }
225
226   void checkForTimestampNotifications() noexcept {
227     int fd = socket_->getFd();
228     std::vector<char> ctrl(1024, 0);
229     unsigned char data;
230     struct msghdr msg;
231     iovec entry;
232
233     memset(&msg, 0, sizeof(msg));
234     entry.iov_base = &data;
235     entry.iov_len = sizeof(data);
236     msg.msg_iov = &entry;
237     msg.msg_iovlen = 1;
238     msg.msg_control = ctrl.data();
239     msg.msg_controllen = ctrl.size();
240
241     int ret;
242     while (true) {
243       ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
244       if (ret < 0) {
245         if (errno != EAGAIN) {
246           auto errnoCopy = errno;
247           std::cerr << "::recvmsg exited with code " << ret
248                     << ", errno: " << errnoCopy << std::endl;
249           AsyncSocketException ex(
250             AsyncSocketException::INTERNAL_ERROR,
251             "recvmsg() failed",
252             errnoCopy);
253           exception = ex;
254         }
255         return;
256       }
257
258       for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
259            cmsg != nullptr && cmsg->cmsg_len != 0;
260            cmsg = CMSG_NXTHDR(&msg, cmsg)) {
261         if (cmsg->cmsg_level == SOL_SOCKET &&
262             cmsg->cmsg_type == SCM_TIMESTAMPING) {
263           gotTimestamp_ = true;
264           continue;
265         }
266
267         if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
268             (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
269           gotByteSeq_ = true;
270           continue;
271         }
272       }
273     }
274   }
275
276   bool gotTimestamp_{false};
277   bool gotByteSeq_{false};
278 };
279 #endif // FOLLY_HAVE_MSG_ERRQUEUE
280
281 class ReadCallbackBase :
282 public AsyncTransportWrapper::ReadCallback {
283  public:
284   explicit ReadCallbackBase(WriteCallbackBase* wcb)
285       : wcb_(wcb), state(STATE_WAITING) {}
286
287   ~ReadCallbackBase() override {
288     EXPECT_EQ(STATE_SUCCEEDED, state);
289   }
290
291   void setSocket(
292     const std::shared_ptr<AsyncSSLSocket> &socket) {
293     socket_ = socket;
294   }
295
296   void setState(StateEnum s) {
297     state = s;
298     if (wcb_) {
299       wcb_->state = s;
300     }
301   }
302
303   void readErr(
304     const AsyncSocketException& ex) noexcept override {
305     std::cerr << "readError " << ex.what() << std::endl;
306     state = STATE_FAILED;
307     socket_->close();
308   }
309
310   void readEOF() noexcept override {
311     std::cerr << "readEOF" << std::endl;
312
313     socket_->close();
314   }
315
316   std::shared_ptr<AsyncSSLSocket> socket_;
317   WriteCallbackBase *wcb_;
318   StateEnum state;
319 };
320
321 class ReadCallback : public ReadCallbackBase {
322  public:
323   explicit ReadCallback(WriteCallbackBase *wcb)
324       : ReadCallbackBase(wcb)
325       , buffers() {}
326
327   ~ReadCallback() override {
328     for (std::vector<Buffer>::iterator it = buffers.begin();
329          it != buffers.end();
330          ++it) {
331       it->free();
332     }
333     currentBuffer.free();
334   }
335
336   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
337     if (!currentBuffer.buffer) {
338       currentBuffer.allocate(4096);
339     }
340     *bufReturn = currentBuffer.buffer;
341     *lenReturn = currentBuffer.length;
342   }
343
344   void readDataAvailable(size_t len) noexcept override {
345     std::cerr << "readDataAvailable, len " << len << std::endl;
346
347     currentBuffer.length = len;
348
349     wcb_->setSocket(socket_);
350
351     // Write back the same data.
352     socket_->write(wcb_, currentBuffer.buffer, len);
353
354     buffers.push_back(currentBuffer);
355     currentBuffer.reset();
356     state = STATE_SUCCEEDED;
357   }
358
359   class Buffer {
360    public:
361     Buffer() : buffer(nullptr), length(0) {}
362     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
363
364     void reset() {
365       buffer = nullptr;
366       length = 0;
367     }
368     void allocate(size_t len) {
369       assert(buffer == nullptr);
370       this->buffer = static_cast<char*>(malloc(len));
371       this->length = len;
372     }
373     void free() {
374       ::free(buffer);
375       reset();
376     }
377
378     char* buffer;
379     size_t length;
380   };
381
382   std::vector<Buffer> buffers;
383   Buffer currentBuffer;
384 };
385
386 class ReadErrorCallback : public ReadCallbackBase {
387  public:
388   explicit ReadErrorCallback(WriteCallbackBase *wcb)
389       : ReadCallbackBase(wcb) {}
390
391   // Return nullptr buffer to trigger readError()
392   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
393     *bufReturn = nullptr;
394     *lenReturn = 0;
395   }
396
397   void readDataAvailable(size_t /* len */) noexcept override {
398     // This should never to called.
399     FAIL();
400   }
401
402   void readErr(
403     const AsyncSocketException& ex) noexcept override {
404     ReadCallbackBase::readErr(ex);
405     std::cerr << "ReadErrorCallback::readError" << std::endl;
406     setState(STATE_SUCCEEDED);
407   }
408 };
409
410 class ReadEOFCallback : public ReadCallbackBase {
411  public:
412   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
413
414   // Return nullptr buffer to trigger readError()
415   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
416     *bufReturn = nullptr;
417     *lenReturn = 0;
418   }
419
420   void readDataAvailable(size_t /* len */) noexcept override {
421     // This should never to called.
422     FAIL();
423   }
424
425   void readEOF() noexcept override {
426     ReadCallbackBase::readEOF();
427     setState(STATE_SUCCEEDED);
428   }
429 };
430
431 class WriteErrorCallback : public ReadCallback {
432  public:
433   explicit WriteErrorCallback(WriteCallbackBase *wcb)
434       : ReadCallback(wcb) {}
435
436   void readDataAvailable(size_t len) noexcept override {
437     std::cerr << "readDataAvailable, len " << len << std::endl;
438
439     currentBuffer.length = len;
440
441     // close the socket before writing to trigger writeError().
442     ::close(socket_->getFd());
443
444     wcb_->setSocket(socket_);
445
446     // Write back the same data.
447     folly::test::msvcSuppressAbortOnInvalidParams([&] {
448       socket_->write(wcb_, currentBuffer.buffer, len);
449     });
450
451     if (wcb_->state == STATE_FAILED) {
452       setState(STATE_SUCCEEDED);
453     } else {
454       state = STATE_FAILED;
455     }
456
457     buffers.push_back(currentBuffer);
458     currentBuffer.reset();
459   }
460
461   void readErr(const AsyncSocketException& ex) noexcept override {
462     std::cerr << "readError " << ex.what() << std::endl;
463     // do nothing since this is expected
464   }
465 };
466
467 class EmptyReadCallback : public ReadCallback {
468  public:
469   explicit EmptyReadCallback()
470       : ReadCallback(nullptr) {}
471
472   void readErr(const AsyncSocketException& ex) noexcept override {
473     std::cerr << "readError " << ex.what() << std::endl;
474     state = STATE_FAILED;
475     if (tcpSocket_) {
476       tcpSocket_->close();
477     }
478   }
479
480   void readEOF() noexcept override {
481     std::cerr << "readEOF" << std::endl;
482     if (tcpSocket_) {
483       tcpSocket_->close();
484     }
485     state = STATE_SUCCEEDED;
486   }
487
488   std::shared_ptr<AsyncSocket> tcpSocket_;
489 };
490
491 class HandshakeCallback :
492 public AsyncSSLSocket::HandshakeCB {
493  public:
494   enum ExpectType {
495     EXPECT_SUCCESS,
496     EXPECT_ERROR
497   };
498
499   explicit HandshakeCallback(ReadCallbackBase *rcb,
500                              ExpectType expect = EXPECT_SUCCESS):
501       state(STATE_WAITING),
502       rcb_(rcb),
503       expect_(expect) {}
504
505   void setSocket(
506     const std::shared_ptr<AsyncSSLSocket> &socket) {
507     socket_ = socket;
508   }
509
510   void setState(StateEnum s) {
511     state = s;
512     rcb_->setState(s);
513   }
514
515   // Functions inherited from AsyncSSLSocketHandshakeCallback
516   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
517     std::lock_guard<std::mutex> g(mutex_);
518     cv_.notify_all();
519     EXPECT_EQ(sock, socket_.get());
520     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
521     rcb_->setSocket(socket_);
522     sock->setReadCB(rcb_);
523     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
524   }
525   void handshakeErr(AsyncSSLSocket* /* sock */,
526                     const AsyncSocketException& ex) noexcept override {
527     std::lock_guard<std::mutex> g(mutex_);
528     cv_.notify_all();
529     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
530     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
531     if (expect_ == EXPECT_ERROR) {
532       // rcb will never be invoked
533       rcb_->setState(STATE_SUCCEEDED);
534     }
535     errorString_ = ex.what();
536   }
537
538   void waitForHandshake() {
539     std::unique_lock<std::mutex> lock(mutex_);
540     cv_.wait(lock, [this] { return state != STATE_WAITING; });
541   }
542
543   ~HandshakeCallback() override {
544     EXPECT_EQ(STATE_SUCCEEDED, state);
545   }
546
547   void closeSocket() {
548     socket_->close();
549     state = STATE_SUCCEEDED;
550   }
551
552   std::shared_ptr<AsyncSSLSocket> getSocket() {
553     return socket_;
554   }
555
556   StateEnum state;
557   std::shared_ptr<AsyncSSLSocket> socket_;
558   ReadCallbackBase *rcb_;
559   ExpectType expect_;
560   std::mutex mutex_;
561   std::condition_variable cv_;
562   std::string errorString_;
563 };
564
565 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
566  public:
567   uint32_t timeout_;
568
569   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
570                                    uint32_t timeout = 0):
571       SSLServerAcceptCallbackBase(hcb),
572       timeout_(timeout) {}
573
574   ~SSLServerAcceptCallback() override {
575     if (timeout_ > 0) {
576       // if we set a timeout, we expect failure
577       EXPECT_EQ(hcb_->state, STATE_FAILED);
578       hcb_->setState(STATE_SUCCEEDED);
579     }
580   }
581
582   void connAccepted(
583     const std::shared_ptr<folly::AsyncSSLSocket> &s)
584     noexcept override {
585     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
587
588     hcb_->setSocket(sock);
589     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590     EXPECT_EQ(sock->getSSLState(),
591                       AsyncSSLSocket::STATE_ACCEPTING);
592
593     state = STATE_SUCCEEDED;
594   }
595 };
596
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
598  public:
599   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600       SSLServerAcceptCallback(hcb) {}
601
602   void connAccepted(
603     const std::shared_ptr<folly::AsyncSSLSocket> &s)
604     noexcept override {
605
606     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
607
608     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
609               << std::endl;
610     int fd = sock->getFd();
611
612 #ifndef TCP_NOPUSH
613     {
614     // The accepted connection should already have TCP_NODELAY set
615     int value;
616     socklen_t valueLength = sizeof(value);
617     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
618     EXPECT_EQ(rc, 0);
619     EXPECT_EQ(value, 1);
620     }
621 #endif
622
623     // Unset the TCP_NODELAY option.
624     int value = 0;
625     socklen_t valueLength = sizeof(value);
626     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
627     EXPECT_EQ(rc, 0);
628
629     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
630     EXPECT_EQ(rc, 0);
631     EXPECT_EQ(value, 0);
632
633     SSLServerAcceptCallback::connAccepted(sock);
634   }
635 };
636
637 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
638  public:
639   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
640                                              uint32_t timeout = 0):
641     SSLServerAcceptCallback(hcb, timeout) {}
642
643   void connAccepted(
644     const std::shared_ptr<folly::AsyncSSLSocket> &s)
645     noexcept override {
646     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
647
648     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
649
650     hcb_->setSocket(sock);
651     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
652     ASSERT_TRUE((sock->getSSLState() ==
653                  AsyncSSLSocket::STATE_ACCEPTING) ||
654                 (sock->getSSLState() ==
655                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
656
657     state = STATE_SUCCEEDED;
658   }
659 };
660
661
662 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
663  public:
664   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
665   SSLServerAcceptCallbackBase(hcb)  {}
666
667   void connAccepted(
668     const std::shared_ptr<folly::AsyncSSLSocket> &s)
669     noexcept override {
670     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
671
672     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
673
674     // The first call to sslAccept() should succeed.
675     hcb_->setSocket(sock);
676     sock->sslAccept(hcb_);
677     EXPECT_EQ(sock->getSSLState(),
678                       AsyncSSLSocket::STATE_ACCEPTING);
679
680     // The second call to sslAccept() should fail.
681     HandshakeCallback callback2(hcb_->rcb_);
682     callback2.setSocket(sock);
683     sock->sslAccept(&callback2);
684     EXPECT_EQ(sock->getSSLState(),
685                       AsyncSSLSocket::STATE_ERROR);
686
687     // Both callbacks should be in the error state.
688     EXPECT_EQ(hcb_->state, STATE_FAILED);
689     EXPECT_EQ(callback2.state, STATE_FAILED);
690
691     state = STATE_SUCCEEDED;
692     hcb_->setState(STATE_SUCCEEDED);
693     callback2.setState(STATE_SUCCEEDED);
694   }
695 };
696
697 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
698  public:
699   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
700   SSLServerAcceptCallbackBase(hcb)  {}
701
702   void connAccepted(
703     const std::shared_ptr<folly::AsyncSSLSocket> &s)
704     noexcept override {
705     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
706
707     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
708
709     hcb_->setSocket(sock);
710     sock->getEventBase()->tryRunAfterDelay([=] {
711         std::cerr << "Delayed SSL accept, client will have close by now"
712                   << std::endl;
713         // SSL accept will fail
714         EXPECT_EQ(
715           sock->getSSLState(),
716           AsyncSSLSocket::STATE_UNINIT);
717         hcb_->socket_->sslAccept(hcb_);
718         // This registers for an event
719         EXPECT_EQ(
720           sock->getSSLState(),
721           AsyncSSLSocket::STATE_ACCEPTING);
722
723         state = STATE_SUCCEEDED;
724       }, 100);
725   }
726 };
727
728 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
729  public:
730   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
731     // We don't care if we get invoked or not.
732     // The client may time out and give up before connAccepted() is even
733     // called.
734     state = STATE_SUCCEEDED;
735   }
736
737   void connAccepted(
738       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
739     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
740
741     // Just wait a while before closing the socket, so the client
742     // will time out waiting for the handshake to complete.
743     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
744   }
745 };
746
747 class TestSSLAsyncCacheServer : public TestSSLServer {
748  public:
749   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
750         int lookupDelay = 100) :
751       TestSSLServer(acb) {
752     SSL_CTX *sslCtx = ctx_->getSSLCtx();
753 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
754     SSL_CTX_sess_set_get_cb(sslCtx,
755                             TestSSLAsyncCacheServer::getSessionCallback);
756 #endif
757     SSL_CTX_set_session_cache_mode(
758       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
759     asyncCallbacks_ = 0;
760     asyncLookups_ = 0;
761     lookupDelay_ = lookupDelay;
762   }
763
764   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
765   uint32_t getAsyncLookups() const { return asyncLookups_; }
766
767  private:
768   static uint32_t asyncCallbacks_;
769   static uint32_t asyncLookups_;
770   static uint32_t lookupDelay_;
771
772   static SSL_SESSION* getSessionCallback(SSL* ssl,
773                                          unsigned char* /* sess_id */,
774                                          int /* id_len */,
775                                          int* copyflag) {
776     *copyflag = 0;
777     asyncCallbacks_++;
778     (void)ssl;
779 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
780     if (!SSL_want_sess_cache_lookup(ssl)) {
781       // libssl.so mismatch
782       std::cerr << "no async support" << std::endl;
783       return nullptr;
784     }
785
786     AsyncSSLSocket *sslSocket =
787         AsyncSSLSocket::getFromSSL(ssl);
788     assert(sslSocket != nullptr);
789     // Going to simulate an async cache by just running delaying the miss 100ms
790     if (asyncCallbacks_ % 2 == 0) {
791       // This socket is already blocked on lookup, return miss
792       std::cerr << "returning miss" << std::endl;
793     } else {
794       // fresh meat - block it
795       std::cerr << "async lookup" << std::endl;
796       sslSocket->getEventBase()->tryRunAfterDelay(
797         std::bind(&AsyncSSLSocket::restartSSLAccept,
798                   sslSocket), lookupDelay_);
799       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
800       asyncLookups_++;
801     }
802 #endif
803     return nullptr;
804   }
805 };
806
807 void getfds(int fds[2]);
808
809 void getctx(
810   std::shared_ptr<folly::SSLContext> clientCtx,
811   std::shared_ptr<folly::SSLContext> serverCtx);
812
813 void sslsocketpair(
814   EventBase* eventBase,
815   AsyncSSLSocket::UniquePtr* clientSock,
816   AsyncSSLSocket::UniquePtr* serverSock);
817
818 class BlockingWriteClient :
819   private AsyncSSLSocket::HandshakeCB,
820   private AsyncTransportWrapper::WriteCallback {
821  public:
822   explicit BlockingWriteClient(
823     AsyncSSLSocket::UniquePtr socket)
824     : socket_(std::move(socket)),
825       bufLen_(2500),
826       iovCount_(2000) {
827     // Fill buf_
828     buf_ = std::make_unique<uint8_t[]>(bufLen_);
829     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
830       buf_[n] = n % 0xff;
831     }
832
833     // Initialize iov_
834     iov_ = std::make_unique<struct iovec[]>(iovCount_);
835     for (uint32_t n = 0; n < iovCount_; ++n) {
836       iov_[n].iov_base = buf_.get() + n;
837       if (n & 0x1) {
838         iov_[n].iov_len = n % bufLen_;
839       } else {
840         iov_[n].iov_len = bufLen_ - (n % bufLen_);
841       }
842     }
843
844     socket_->sslConn(this, std::chrono::milliseconds(100));
845   }
846
847   struct iovec* getIovec() const {
848     return iov_.get();
849   }
850   uint32_t getIovecCount() const {
851     return iovCount_;
852   }
853
854  private:
855   void handshakeSuc(AsyncSSLSocket*) noexcept override {
856     socket_->writev(this, iov_.get(), iovCount_);
857   }
858   void handshakeErr(
859     AsyncSSLSocket*,
860     const AsyncSocketException& ex) noexcept override {
861     ADD_FAILURE() << "client handshake error: " << ex.what();
862   }
863   void writeSuccess() noexcept override {
864     socket_->close();
865   }
866   void writeErr(
867     size_t bytesWritten,
868     const AsyncSocketException& ex) noexcept override {
869     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
870                   << ex.what();
871   }
872
873   AsyncSSLSocket::UniquePtr socket_;
874   uint32_t bufLen_;
875   uint32_t iovCount_;
876   std::unique_ptr<uint8_t[]> buf_;
877   std::unique_ptr<struct iovec[]> iov_;
878 };
879
880 class BlockingWriteServer :
881     private AsyncSSLSocket::HandshakeCB,
882     private AsyncTransportWrapper::ReadCallback {
883  public:
884   explicit BlockingWriteServer(
885     AsyncSSLSocket::UniquePtr socket)
886     : socket_(std::move(socket)),
887       bufSize_(2500 * 2000),
888       bytesRead_(0) {
889     buf_ = std::make_unique<uint8_t[]>(bufSize_);
890     socket_->sslAccept(this, std::chrono::milliseconds(100));
891   }
892
893   void checkBuffer(struct iovec* iov, uint32_t count) const {
894     uint32_t idx = 0;
895     for (uint32_t n = 0; n < count; ++n) {
896       size_t bytesLeft = bytesRead_ - idx;
897       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
898                       std::min(iov[n].iov_len, bytesLeft));
899       if (rc != 0) {
900         FAIL() << "buffer mismatch at iovec " << n << "/" << count
901                << ": rc=" << rc;
902
903       }
904       if (iov[n].iov_len > bytesLeft) {
905         FAIL() << "server did not read enough data: "
906                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
907                << " in iovec " << n << "/" << count;
908       }
909
910       idx += iov[n].iov_len;
911     }
912     if (idx != bytesRead_) {
913       ADD_FAILURE() << "server read extra data: " << bytesRead_
914                     << " bytes read; expected " << idx;
915     }
916   }
917
918  private:
919   void handshakeSuc(AsyncSSLSocket*) noexcept override {
920     // Wait 10ms before reading, so the client's writes will initially block.
921     socket_->getEventBase()->tryRunAfterDelay(
922         [this] { socket_->setReadCB(this); }, 10);
923   }
924   void handshakeErr(
925     AsyncSSLSocket*,
926     const AsyncSocketException& ex) noexcept override {
927     ADD_FAILURE() << "server handshake error: " << ex.what();
928   }
929   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
930     *bufReturn = buf_.get() + bytesRead_;
931     *lenReturn = bufSize_ - bytesRead_;
932   }
933   void readDataAvailable(size_t len) noexcept override {
934     bytesRead_ += len;
935     socket_->setReadCB(nullptr);
936     socket_->getEventBase()->tryRunAfterDelay(
937         [this] { socket_->setReadCB(this); }, 2);
938   }
939   void readEOF() noexcept override {
940     socket_->close();
941   }
942   void readErr(
943     const AsyncSocketException& ex) noexcept override {
944     ADD_FAILURE() << "server read error: " << ex.what();
945   }
946
947   AsyncSSLSocket::UniquePtr socket_;
948   uint32_t bufSize_;
949   uint32_t bytesRead_;
950   std::unique_ptr<uint8_t[]> buf_;
951 };
952
953 class NpnClient :
954   private AsyncSSLSocket::HandshakeCB,
955   private AsyncTransportWrapper::WriteCallback {
956  public:
957   explicit NpnClient(
958     AsyncSSLSocket::UniquePtr socket)
959       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
960     socket_->sslConn(this);
961   }
962
963   const unsigned char* nextProto;
964   unsigned nextProtoLength;
965   SSLContext::NextProtocolType protocolType;
966   folly::Optional<AsyncSocketException> except;
967
968  private:
969   void handshakeSuc(AsyncSSLSocket*) noexcept override {
970     socket_->getSelectedNextProtocol(
971         &nextProto, &nextProtoLength, &protocolType);
972   }
973   void handshakeErr(
974     AsyncSSLSocket*,
975     const AsyncSocketException& ex) noexcept override {
976     except = ex;
977   }
978   void writeSuccess() noexcept override {
979     socket_->close();
980   }
981   void writeErr(
982     size_t bytesWritten,
983     const AsyncSocketException& ex) noexcept override {
984     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
985                   << ex.what();
986   }
987
988   AsyncSSLSocket::UniquePtr socket_;
989 };
990
991 class NpnServer :
992     private AsyncSSLSocket::HandshakeCB,
993     private AsyncTransportWrapper::ReadCallback {
994  public:
995   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
996       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
997     socket_->sslAccept(this);
998   }
999
1000   const unsigned char* nextProto;
1001   unsigned nextProtoLength;
1002   SSLContext::NextProtocolType protocolType;
1003   folly::Optional<AsyncSocketException> except;
1004
1005  private:
1006   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1007     socket_->getSelectedNextProtocol(
1008         &nextProto, &nextProtoLength, &protocolType);
1009   }
1010   void handshakeErr(
1011     AsyncSSLSocket*,
1012     const AsyncSocketException& ex) noexcept override {
1013     except = ex;
1014   }
1015   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1016     *lenReturn = 0;
1017   }
1018   void readDataAvailable(size_t /* len */) noexcept override {}
1019   void readEOF() noexcept override {
1020     socket_->close();
1021   }
1022   void readErr(
1023     const AsyncSocketException& ex) noexcept override {
1024     ADD_FAILURE() << "server read error: " << ex.what();
1025   }
1026
1027   AsyncSSLSocket::UniquePtr socket_;
1028 };
1029
1030 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1031                             public AsyncTransportWrapper::ReadCallback {
1032  public:
1033   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1034       : socket_(std::move(socket)) {
1035     socket_->sslAccept(this);
1036   }
1037
1038   ~RenegotiatingServer() override {
1039     socket_->setReadCB(nullptr);
1040   }
1041
1042   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1043     LOG(INFO) << "Renegotiating server handshake success";
1044     socket_->setReadCB(this);
1045   }
1046   void handshakeErr(
1047       AsyncSSLSocket*,
1048       const AsyncSocketException& ex) noexcept override {
1049     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1050   }
1051   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1052     *lenReturn = sizeof(buf);
1053     *bufReturn = buf;
1054   }
1055   void readDataAvailable(size_t /* len */) noexcept override {}
1056   void readEOF() noexcept override {}
1057   void readErr(const AsyncSocketException& ex) noexcept override {
1058     LOG(INFO) << "server got read error " << ex.what();
1059     auto exPtr = dynamic_cast<const SSLException*>(&ex);
1060     ASSERT_NE(nullptr, exPtr);
1061     std::string exStr(ex.what());
1062     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1063     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1064     renegotiationError_ = true;
1065   }
1066
1067   AsyncSSLSocket::UniquePtr socket_;
1068   unsigned char buf[128];
1069   bool renegotiationError_{false};
1070 };
1071
1072 #ifndef OPENSSL_NO_TLSEXT
1073 class SNIClient :
1074   private AsyncSSLSocket::HandshakeCB,
1075   private AsyncTransportWrapper::WriteCallback {
1076  public:
1077   explicit SNIClient(
1078     AsyncSSLSocket::UniquePtr socket)
1079       : serverNameMatch(false), socket_(std::move(socket)) {
1080     socket_->sslConn(this);
1081   }
1082
1083   bool serverNameMatch;
1084
1085  private:
1086   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1087     serverNameMatch = socket_->isServerNameMatch();
1088   }
1089   void handshakeErr(
1090     AsyncSSLSocket*,
1091     const AsyncSocketException& ex) noexcept override {
1092     ADD_FAILURE() << "client handshake error: " << ex.what();
1093   }
1094   void writeSuccess() noexcept override {
1095     socket_->close();
1096   }
1097   void writeErr(
1098     size_t bytesWritten,
1099     const AsyncSocketException& ex) noexcept override {
1100     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1101                   << ex.what();
1102   }
1103
1104   AsyncSSLSocket::UniquePtr socket_;
1105 };
1106
1107 class SNIServer :
1108     private AsyncSSLSocket::HandshakeCB,
1109     private AsyncTransportWrapper::ReadCallback {
1110  public:
1111   explicit SNIServer(
1112     AsyncSSLSocket::UniquePtr socket,
1113     const std::shared_ptr<folly::SSLContext>& ctx,
1114     const std::shared_ptr<folly::SSLContext>& sniCtx,
1115     const std::string& expectedServerName)
1116       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1117         expectedServerName_(expectedServerName) {
1118     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1119                                          std::placeholders::_1));
1120     socket_->sslAccept(this);
1121   }
1122
1123   bool serverNameMatch;
1124
1125  private:
1126   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1127   void handshakeErr(
1128     AsyncSSLSocket*,
1129     const AsyncSocketException& ex) noexcept override {
1130     ADD_FAILURE() << "server handshake error: " << ex.what();
1131   }
1132   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1133     *lenReturn = 0;
1134   }
1135   void readDataAvailable(size_t /* len */) noexcept override {}
1136   void readEOF() noexcept override {
1137     socket_->close();
1138   }
1139   void readErr(
1140     const AsyncSocketException& ex) noexcept override {
1141     ADD_FAILURE() << "server read error: " << ex.what();
1142   }
1143
1144   folly::SSLContext::ServerNameCallbackResult
1145     serverNameCallback(SSL *ssl) {
1146     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1147     if (sniCtx_ &&
1148         sn &&
1149         !strcasecmp(expectedServerName_.c_str(), sn)) {
1150       AsyncSSLSocket *sslSocket =
1151           AsyncSSLSocket::getFromSSL(ssl);
1152       sslSocket->switchServerSSLContext(sniCtx_);
1153       serverNameMatch = true;
1154       return folly::SSLContext::SERVER_NAME_FOUND;
1155     } else {
1156       serverNameMatch = false;
1157       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1158     }
1159   }
1160
1161   AsyncSSLSocket::UniquePtr socket_;
1162   std::shared_ptr<folly::SSLContext> sniCtx_;
1163   std::string expectedServerName_;
1164 };
1165 #endif
1166
1167 class SSLClient : public AsyncSocket::ConnectCallback,
1168                   public AsyncTransportWrapper::WriteCallback,
1169                   public AsyncTransportWrapper::ReadCallback
1170 {
1171  private:
1172   EventBase *eventBase_;
1173   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1174   SSL_SESSION *session_;
1175   std::shared_ptr<folly::SSLContext> ctx_;
1176   uint32_t requests_;
1177   folly::SocketAddress address_;
1178   uint32_t timeout_;
1179   char buf_[128];
1180   char readbuf_[128];
1181   uint32_t bytesRead_;
1182   uint32_t hit_;
1183   uint32_t miss_;
1184   uint32_t errors_;
1185   uint32_t writeAfterConnectErrors_;
1186
1187   // These settings test that we eventually drain the
1188   // socket, even if the maxReadsPerEvent_ is hit during
1189   // a event loop iteration.
1190   static constexpr size_t kMaxReadsPerEvent = 2;
1191   // 2 event loop iterations
1192   static constexpr size_t kMaxReadBufferSz =
1193     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1194
1195  public:
1196   SSLClient(EventBase *eventBase,
1197             const folly::SocketAddress& address,
1198             uint32_t requests,
1199             uint32_t timeout = 0)
1200       : eventBase_(eventBase),
1201         session_(nullptr),
1202         requests_(requests),
1203         address_(address),
1204         timeout_(timeout),
1205         bytesRead_(0),
1206         hit_(0),
1207         miss_(0),
1208         errors_(0),
1209         writeAfterConnectErrors_(0) {
1210     ctx_.reset(new folly::SSLContext());
1211     ctx_->setOptions(SSL_OP_NO_TICKET);
1212     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1213     memset(buf_, 'a', sizeof(buf_));
1214   }
1215
1216   ~SSLClient() override {
1217     if (session_) {
1218       SSL_SESSION_free(session_);
1219     }
1220     if (errors_ == 0) {
1221       EXPECT_EQ(bytesRead_, sizeof(buf_));
1222     }
1223   }
1224
1225   uint32_t getHit() const { return hit_; }
1226
1227   uint32_t getMiss() const { return miss_; }
1228
1229   uint32_t getErrors() const { return errors_; }
1230
1231   uint32_t getWriteAfterConnectErrors() const {
1232     return writeAfterConnectErrors_;
1233   }
1234
1235   void connect(bool writeNow = false) {
1236     sslSocket_ = AsyncSSLSocket::newSocket(
1237       ctx_, eventBase_);
1238     if (session_ != nullptr) {
1239       sslSocket_->setSSLSession(session_);
1240     }
1241     requests_--;
1242     sslSocket_->connect(this, address_, timeout_);
1243     if (sslSocket_ && writeNow) {
1244       // write some junk, used in an error test
1245       sslSocket_->write(this, buf_, sizeof(buf_));
1246     }
1247   }
1248
1249   void connectSuccess() noexcept override {
1250     std::cerr << "client SSL socket connected" << std::endl;
1251     if (sslSocket_->getSSLSessionReused()) {
1252       hit_++;
1253     } else {
1254       miss_++;
1255       if (session_ != nullptr) {
1256         SSL_SESSION_free(session_);
1257       }
1258       session_ = sslSocket_->getSSLSession();
1259     }
1260
1261     // write()
1262     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1263     sslSocket_->write(this, buf_, sizeof(buf_));
1264     sslSocket_->setReadCB(this);
1265     memset(readbuf_, 'b', sizeof(readbuf_));
1266     bytesRead_ = 0;
1267   }
1268
1269   void connectErr(
1270     const AsyncSocketException& ex) noexcept override {
1271     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1272     errors_++;
1273     sslSocket_.reset();
1274   }
1275
1276   void writeSuccess() noexcept override {
1277     std::cerr << "client write success" << std::endl;
1278   }
1279
1280   void writeErr(size_t /* bytesWritten */,
1281                 const AsyncSocketException& ex) noexcept override {
1282     std::cerr << "client writeError: " << ex.what() << std::endl;
1283     if (!sslSocket_) {
1284       writeAfterConnectErrors_++;
1285     }
1286   }
1287
1288   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1289     *bufReturn = readbuf_ + bytesRead_;
1290     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1291   }
1292
1293   void readEOF() noexcept override {
1294     std::cerr << "client readEOF" << std::endl;
1295   }
1296
1297   void readErr(
1298     const AsyncSocketException& ex) noexcept override {
1299     std::cerr << "client readError: " << ex.what() << std::endl;
1300   }
1301
1302   void readDataAvailable(size_t len) noexcept override {
1303     std::cerr << "client read data: " << len << std::endl;
1304     bytesRead_ += len;
1305     if (bytesRead_ == sizeof(buf_)) {
1306       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1307       sslSocket_->closeNow();
1308       sslSocket_.reset();
1309       if (requests_ != 0) {
1310         connect();
1311       }
1312     }
1313   }
1314
1315 };
1316
1317 class SSLHandshakeBase :
1318   public AsyncSSLSocket::HandshakeCB,
1319   private AsyncTransportWrapper::WriteCallback {
1320  public:
1321   explicit SSLHandshakeBase(
1322    AsyncSSLSocket::UniquePtr socket,
1323    bool preverifyResult,
1324    bool verifyResult) :
1325     handshakeVerify_(false),
1326     handshakeSuccess_(false),
1327     handshakeError_(false),
1328     socket_(std::move(socket)),
1329     preverifyResult_(preverifyResult),
1330     verifyResult_(verifyResult) {
1331   }
1332
1333   AsyncSSLSocket::UniquePtr moveSocket() && {
1334     return std::move(socket_);
1335   }
1336
1337   bool handshakeVerify_;
1338   bool handshakeSuccess_;
1339   bool handshakeError_;
1340   std::chrono::nanoseconds handshakeTime;
1341
1342  protected:
1343   AsyncSSLSocket::UniquePtr socket_;
1344   bool preverifyResult_;
1345   bool verifyResult_;
1346
1347   // HandshakeCallback
1348   bool handshakeVer(AsyncSSLSocket* /* sock */,
1349                     bool preverifyOk,
1350                     X509_STORE_CTX* /* ctx */) noexcept override {
1351     handshakeVerify_ = true;
1352
1353     EXPECT_EQ(preverifyResult_, preverifyOk);
1354     return verifyResult_;
1355   }
1356
1357   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1358     LOG(INFO) << "Handshake success";
1359     handshakeSuccess_ = true;
1360     handshakeTime = socket_->getHandshakeTime();
1361   }
1362
1363   void handshakeErr(
1364       AsyncSSLSocket*,
1365       const AsyncSocketException& ex) noexcept override {
1366     LOG(INFO) << "Handshake error " << ex.what();
1367     handshakeError_ = true;
1368     handshakeTime = socket_->getHandshakeTime();
1369   }
1370
1371   // WriteCallback
1372   void writeSuccess() noexcept override {
1373     socket_->close();
1374   }
1375
1376   void writeErr(
1377    size_t bytesWritten,
1378    const AsyncSocketException& ex) noexcept override {
1379     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1380                   << ex.what();
1381   }
1382 };
1383
1384 class SSLHandshakeClient : public SSLHandshakeBase {
1385  public:
1386   SSLHandshakeClient(
1387    AsyncSSLSocket::UniquePtr socket,
1388    bool preverifyResult,
1389    bool verifyResult) :
1390     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1391     socket_->sslConn(this, std::chrono::milliseconds::zero());
1392   }
1393 };
1394
1395 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1396  public:
1397   SSLHandshakeClientNoVerify(
1398    AsyncSSLSocket::UniquePtr socket,
1399    bool preverifyResult,
1400    bool verifyResult) :
1401     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1402     socket_->sslConn(
1403         this,
1404         std::chrono::milliseconds::zero(),
1405         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1406   }
1407 };
1408
1409 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1410  public:
1411   SSLHandshakeClientDoVerify(
1412    AsyncSSLSocket::UniquePtr socket,
1413    bool preverifyResult,
1414    bool verifyResult) :
1415     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1416     socket_->sslConn(
1417         this,
1418         std::chrono::milliseconds::zero(),
1419         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1420   }
1421 };
1422
1423 class SSLHandshakeServer : public SSLHandshakeBase {
1424  public:
1425   SSLHandshakeServer(
1426       AsyncSSLSocket::UniquePtr socket,
1427       bool preverifyResult,
1428       bool verifyResult)
1429     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1430     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1431   }
1432 };
1433
1434 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1435  public:
1436   SSLHandshakeServerParseClientHello(
1437       AsyncSSLSocket::UniquePtr socket,
1438       bool preverifyResult,
1439       bool verifyResult)
1440       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1441     socket_->enableClientHelloParsing();
1442     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1443   }
1444
1445   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1446
1447  protected:
1448   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1449     handshakeSuccess_ = true;
1450     sock->getSSLSharedCiphers(sharedCiphers_);
1451     sock->getSSLServerCiphers(serverCiphers_);
1452     sock->getSSLClientCiphers(clientCiphers_);
1453     chosenCipher_ = sock->getNegotiatedCipherName();
1454   }
1455 };
1456
1457
1458 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1459  public:
1460   SSLHandshakeServerNoVerify(
1461       AsyncSSLSocket::UniquePtr socket,
1462       bool preverifyResult,
1463       bool verifyResult)
1464     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1465     socket_->sslAccept(
1466         this,
1467         std::chrono::milliseconds::zero(),
1468         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1469   }
1470 };
1471
1472 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1473  public:
1474   SSLHandshakeServerDoVerify(
1475       AsyncSSLSocket::UniquePtr socket,
1476       bool preverifyResult,
1477       bool verifyResult)
1478     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1479     socket_->sslAccept(
1480         this,
1481         std::chrono::milliseconds::zero(),
1482         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1483   }
1484 };
1485
1486 class EventBaseAborter : public AsyncTimeout {
1487  public:
1488   EventBaseAborter(EventBase* eventBase,
1489                    uint32_t timeoutMS)
1490     : AsyncTimeout(
1491       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1492     , eventBase_(eventBase) {
1493     scheduleTimeout(timeoutMS);
1494   }
1495
1496   void timeoutExpired() noexcept override {
1497     FAIL() << "test timed out";
1498     eventBase_->terminateLoopSoon();
1499   }
1500
1501  private:
1502   EventBase* eventBase_;
1503 };
1504
1505 } // namespace folly