Add SO_ZEROCOPY support
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19
20 #include <folly/ExceptionWrapper.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/experimental/TestUtil.h>
23 #include <folly/io/async/AsyncSSLSocket.h>
24 #include <folly/io/async/AsyncServerSocket.h>
25 #include <folly/io/async/AsyncSocket.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/io/async/AsyncTransport.h>
28 #include <folly/io/async/EventBase.h>
29 #include <folly/io/async/ssl/SSLErrors.h>
30 #include <folly/io/async/test/TestSSLServer.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/PThread.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <condition_variable>
39 #include <iostream>
40 #include <list>
41
42 namespace folly {
43
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
48
49 class SendMsgParamsCallbackBase :
50       public folly::AsyncSocket::SendMsgParamsCallback {
51  public:
52   SendMsgParamsCallbackBase() {}
53
54   void setSocket(
55     const std::shared_ptr<AsyncSSLSocket> &socket) {
56     socket_ = socket;
57     oldCallback_ = socket_->getSendMsgParamsCB();
58     socket_->setSendMsgParamCB(this);
59   }
60
61   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
62                                                                      override {
63     return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
64   }
65
66   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67     oldCallback_->getAncillaryData(flags, data);
68   }
69
70   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71     return oldCallback_->getAncillaryDataSize(flags);
72   }
73
74   std::shared_ptr<AsyncSSLSocket> socket_;
75   folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
76 };
77
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
79  public:
80   SendMsgFlagsCallback() {}
81
82   void resetFlags(int flags) {
83     flags_ = flags;
84   }
85
86   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
87                                                                   override {
88     if (flags_) {
89       return flags_;
90     } else {
91       return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
92     }
93   }
94
95   int flags_{0};
96 };
97
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
99  public:
100   SendMsgDataCallback() {}
101
102   void resetData(std::vector<char>&& data) {
103     ancillaryData_.swap(data);
104   }
105
106   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107     if (ancillaryData_.size()) {
108       std::cerr << "getAncillaryData: copying data" << std::endl;
109       memcpy(data, ancillaryData_.data(), ancillaryData_.size());
110     } else {
111       oldCallback_->getAncillaryData(flags, data);
112     }
113   }
114
115   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116     if (ancillaryData_.size()) {
117       std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118       return ancillaryData_.size();
119     } else {
120       return oldCallback_->getAncillaryDataSize(flags);
121     }
122   }
123
124   std::vector<char> ancillaryData_;
125 };
126
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
129  public:
130   explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131       : state(STATE_WAITING)
132       , bytesWritten(0)
133       , exception(AsyncSocketException::UNKNOWN, "none")
134       , mcb_(mcb) {}
135
136   ~WriteCallbackBase() override {
137     EXPECT_EQ(STATE_SUCCEEDED, state);
138   }
139
140   virtual void setSocket(
141     const std::shared_ptr<AsyncSSLSocket> &socket) {
142     socket_ = socket;
143     if (mcb_) {
144       mcb_->setSocket(socket);
145     }
146   }
147
148   void writeSuccess() noexcept override {
149     std::cerr << "writeSuccess" << std::endl;
150     state = STATE_SUCCEEDED;
151   }
152
153   void writeErr(
154     size_t nBytesWritten,
155     const AsyncSocketException& ex) noexcept override {
156     std::cerr << "writeError: bytesWritten " << nBytesWritten
157          << ", exception " << ex.what() << std::endl;
158
159     state = STATE_FAILED;
160     this->bytesWritten = nBytesWritten;
161     exception = ex;
162     socket_->close();
163   }
164
165   std::shared_ptr<AsyncSSLSocket> socket_;
166   StateEnum state;
167   size_t bytesWritten;
168   AsyncSocketException exception;
169   SendMsgParamsCallbackBase* mcb_;
170 };
171
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
174  public:
175   explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176       : WriteCallbackBase(mcb) {}
177
178   ~ExpectWriteErrorCallback() override {
179     EXPECT_EQ(STATE_FAILED, state);
180     EXPECT_EQ(exception.type_,
181              AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
182     EXPECT_EQ(exception.errno_, 22);
183     // Suppress the assert in  ~WriteCallbackBase()
184     state = STATE_SUCCEEDED;
185   }
186 };
187
188 #ifdef MSG_ERRQUEUE
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
191 enum SOF_TIMESTAMPING {
192   SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196   SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
198 };
199
200 class WriteCheckTimestampCallback :
201   public WriteCallbackBase {
202  public:
203   explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204     : WriteCallbackBase(mcb) {}
205
206   ~WriteCheckTimestampCallback() override {
207     EXPECT_EQ(STATE_SUCCEEDED, state);
208     EXPECT_TRUE(gotTimestamp_);
209     EXPECT_TRUE(gotByteSeq_);
210   }
211
212   void setSocket(
213     const std::shared_ptr<AsyncSSLSocket> &socket) override {
214     WriteCallbackBase::setSocket(socket);
215
216     EXPECT_NE(socket_->getFd(), 0);
217     int flags = SOF_TIMESTAMPING_OPT_ID
218                 | SOF_TIMESTAMPING_OPT_TSONLY
219                 | SOF_TIMESTAMPING_SOFTWARE;
220     AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221     int ret = tstampingOpt.apply(socket_->getFd(), flags);
222     EXPECT_EQ(ret, 0);
223   }
224
225   void checkForTimestampNotifications() noexcept {
226     int fd = socket_->getFd();
227     std::vector<char> ctrl(1024, 0);
228     unsigned char data;
229     struct msghdr msg;
230     iovec entry;
231
232     memset(&msg, 0, sizeof(msg));
233     entry.iov_base = &data;
234     entry.iov_len = sizeof(data);
235     msg.msg_iov = &entry;
236     msg.msg_iovlen = 1;
237     msg.msg_control = ctrl.data();
238     msg.msg_controllen = ctrl.size();
239
240     int ret;
241     while (true) {
242       ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
243       if (ret < 0) {
244         if (errno != EAGAIN) {
245           auto errnoCopy = errno;
246           std::cerr << "::recvmsg exited with code " << ret
247                     << ", errno: " << errnoCopy << std::endl;
248           AsyncSocketException ex(
249             AsyncSocketException::INTERNAL_ERROR,
250             "recvmsg() failed",
251             errnoCopy);
252           exception = ex;
253         }
254         return;
255       }
256
257       for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258            cmsg != nullptr && cmsg->cmsg_len != 0;
259            cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260         if (cmsg->cmsg_level == SOL_SOCKET &&
261             cmsg->cmsg_type == SCM_TIMESTAMPING) {
262           gotTimestamp_ = true;
263           continue;
264         }
265
266         if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267             (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
268           gotByteSeq_ = true;
269           continue;
270         }
271       }
272     }
273   }
274
275   bool gotTimestamp_{false};
276   bool gotByteSeq_{false};
277 };
278 #endif // MSG_ERRQUEUE
279
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
282  public:
283   explicit ReadCallbackBase(WriteCallbackBase* wcb)
284       : wcb_(wcb), state(STATE_WAITING) {}
285
286   ~ReadCallbackBase() override {
287     EXPECT_EQ(STATE_SUCCEEDED, state);
288   }
289
290   void setSocket(
291     const std::shared_ptr<AsyncSSLSocket> &socket) {
292     socket_ = socket;
293   }
294
295   void setState(StateEnum s) {
296     state = s;
297     if (wcb_) {
298       wcb_->state = s;
299     }
300   }
301
302   void readErr(
303     const AsyncSocketException& ex) noexcept override {
304     std::cerr << "readError " << ex.what() << std::endl;
305     state = STATE_FAILED;
306     socket_->close();
307   }
308
309   void readEOF() noexcept override {
310     std::cerr << "readEOF" << std::endl;
311
312     socket_->close();
313   }
314
315   std::shared_ptr<AsyncSSLSocket> socket_;
316   WriteCallbackBase *wcb_;
317   StateEnum state;
318 };
319
320 class ReadCallback : public ReadCallbackBase {
321  public:
322   explicit ReadCallback(WriteCallbackBase *wcb)
323       : ReadCallbackBase(wcb)
324       , buffers() {}
325
326   ~ReadCallback() override {
327     for (std::vector<Buffer>::iterator it = buffers.begin();
328          it != buffers.end();
329          ++it) {
330       it->free();
331     }
332     currentBuffer.free();
333   }
334
335   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336     if (!currentBuffer.buffer) {
337       currentBuffer.allocate(4096);
338     }
339     *bufReturn = currentBuffer.buffer;
340     *lenReturn = currentBuffer.length;
341   }
342
343   void readDataAvailable(size_t len) noexcept override {
344     std::cerr << "readDataAvailable, len " << len << std::endl;
345
346     currentBuffer.length = len;
347
348     wcb_->setSocket(socket_);
349
350     // Write back the same data.
351     socket_->write(wcb_, currentBuffer.buffer, len);
352
353     buffers.push_back(currentBuffer);
354     currentBuffer.reset();
355     state = STATE_SUCCEEDED;
356   }
357
358   class Buffer {
359    public:
360     Buffer() : buffer(nullptr), length(0) {}
361     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
362
363     void reset() {
364       buffer = nullptr;
365       length = 0;
366     }
367     void allocate(size_t len) {
368       assert(buffer == nullptr);
369       this->buffer = static_cast<char*>(malloc(len));
370       this->length = len;
371     }
372     void free() {
373       ::free(buffer);
374       reset();
375     }
376
377     char* buffer;
378     size_t length;
379   };
380
381   std::vector<Buffer> buffers;
382   Buffer currentBuffer;
383 };
384
385 class ReadErrorCallback : public ReadCallbackBase {
386  public:
387   explicit ReadErrorCallback(WriteCallbackBase *wcb)
388       : ReadCallbackBase(wcb) {}
389
390   // Return nullptr buffer to trigger readError()
391   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392     *bufReturn = nullptr;
393     *lenReturn = 0;
394   }
395
396   void readDataAvailable(size_t /* len */) noexcept override {
397     // This should never to called.
398     FAIL();
399   }
400
401   void readErr(
402     const AsyncSocketException& ex) noexcept override {
403     ReadCallbackBase::readErr(ex);
404     std::cerr << "ReadErrorCallback::readError" << std::endl;
405     setState(STATE_SUCCEEDED);
406   }
407 };
408
409 class ReadEOFCallback : public ReadCallbackBase {
410  public:
411   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
412
413   // Return nullptr buffer to trigger readError()
414   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415     *bufReturn = nullptr;
416     *lenReturn = 0;
417   }
418
419   void readDataAvailable(size_t /* len */) noexcept override {
420     // This should never to called.
421     FAIL();
422   }
423
424   void readEOF() noexcept override {
425     ReadCallbackBase::readEOF();
426     setState(STATE_SUCCEEDED);
427   }
428 };
429
430 class WriteErrorCallback : public ReadCallback {
431  public:
432   explicit WriteErrorCallback(WriteCallbackBase *wcb)
433       : ReadCallback(wcb) {}
434
435   void readDataAvailable(size_t len) noexcept override {
436     std::cerr << "readDataAvailable, len " << len << std::endl;
437
438     currentBuffer.length = len;
439
440     // close the socket before writing to trigger writeError().
441     ::close(socket_->getFd());
442
443     wcb_->setSocket(socket_);
444
445     // Write back the same data.
446     folly::test::msvcSuppressAbortOnInvalidParams([&] {
447       socket_->write(wcb_, currentBuffer.buffer, len);
448     });
449
450     if (wcb_->state == STATE_FAILED) {
451       setState(STATE_SUCCEEDED);
452     } else {
453       state = STATE_FAILED;
454     }
455
456     buffers.push_back(currentBuffer);
457     currentBuffer.reset();
458   }
459
460   void readErr(const AsyncSocketException& ex) noexcept override {
461     std::cerr << "readError " << ex.what() << std::endl;
462     // do nothing since this is expected
463   }
464 };
465
466 class EmptyReadCallback : public ReadCallback {
467  public:
468   explicit EmptyReadCallback()
469       : ReadCallback(nullptr) {}
470
471   void readErr(const AsyncSocketException& ex) noexcept override {
472     std::cerr << "readError " << ex.what() << std::endl;
473     state = STATE_FAILED;
474     if (tcpSocket_) {
475       tcpSocket_->close();
476     }
477   }
478
479   void readEOF() noexcept override {
480     std::cerr << "readEOF" << std::endl;
481     if (tcpSocket_) {
482       tcpSocket_->close();
483     }
484     state = STATE_SUCCEEDED;
485   }
486
487   std::shared_ptr<AsyncSocket> tcpSocket_;
488 };
489
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
492  public:
493   enum ExpectType {
494     EXPECT_SUCCESS,
495     EXPECT_ERROR
496   };
497
498   explicit HandshakeCallback(ReadCallbackBase *rcb,
499                              ExpectType expect = EXPECT_SUCCESS):
500       state(STATE_WAITING),
501       rcb_(rcb),
502       expect_(expect) {}
503
504   void setSocket(
505     const std::shared_ptr<AsyncSSLSocket> &socket) {
506     socket_ = socket;
507   }
508
509   void setState(StateEnum s) {
510     state = s;
511     rcb_->setState(s);
512   }
513
514   // Functions inherited from AsyncSSLSocketHandshakeCallback
515   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516     std::lock_guard<std::mutex> g(mutex_);
517     cv_.notify_all();
518     EXPECT_EQ(sock, socket_.get());
519     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520     rcb_->setSocket(socket_);
521     sock->setReadCB(rcb_);
522     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
523   }
524   void handshakeErr(AsyncSSLSocket* /* sock */,
525                     const AsyncSocketException& ex) noexcept override {
526     std::lock_guard<std::mutex> g(mutex_);
527     cv_.notify_all();
528     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530     if (expect_ == EXPECT_ERROR) {
531       // rcb will never be invoked
532       rcb_->setState(STATE_SUCCEEDED);
533     }
534     errorString_ = ex.what();
535   }
536
537   void waitForHandshake() {
538     std::unique_lock<std::mutex> lock(mutex_);
539     cv_.wait(lock, [this] { return state != STATE_WAITING; });
540   }
541
542   ~HandshakeCallback() override {
543     EXPECT_EQ(STATE_SUCCEEDED, state);
544   }
545
546   void closeSocket() {
547     socket_->close();
548     state = STATE_SUCCEEDED;
549   }
550
551   std::shared_ptr<AsyncSSLSocket> getSocket() {
552     return socket_;
553   }
554
555   StateEnum state;
556   std::shared_ptr<AsyncSSLSocket> socket_;
557   ReadCallbackBase *rcb_;
558   ExpectType expect_;
559   std::mutex mutex_;
560   std::condition_variable cv_;
561   std::string errorString_;
562 };
563
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
565  public:
566   uint32_t timeout_;
567
568   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569                                    uint32_t timeout = 0):
570       SSLServerAcceptCallbackBase(hcb),
571       timeout_(timeout) {}
572
573   ~SSLServerAcceptCallback() override {
574     if (timeout_ > 0) {
575       // if we set a timeout, we expect failure
576       EXPECT_EQ(hcb_->state, STATE_FAILED);
577       hcb_->setState(STATE_SUCCEEDED);
578     }
579   }
580
581   void connAccepted(
582     const std::shared_ptr<folly::AsyncSSLSocket> &s)
583     noexcept override {
584     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
585     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
586
587     hcb_->setSocket(sock);
588     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
589     EXPECT_EQ(sock->getSSLState(),
590                       AsyncSSLSocket::STATE_ACCEPTING);
591
592     state = STATE_SUCCEEDED;
593   }
594 };
595
596 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
597  public:
598   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
599       SSLServerAcceptCallback(hcb) {}
600
601   void connAccepted(
602     const std::shared_ptr<folly::AsyncSSLSocket> &s)
603     noexcept override {
604
605     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
606
607     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
608               << std::endl;
609     int fd = sock->getFd();
610
611 #ifndef TCP_NOPUSH
612     {
613     // The accepted connection should already have TCP_NODELAY set
614     int value;
615     socklen_t valueLength = sizeof(value);
616     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
617     EXPECT_EQ(rc, 0);
618     EXPECT_EQ(value, 1);
619     }
620 #endif
621
622     // Unset the TCP_NODELAY option.
623     int value = 0;
624     socklen_t valueLength = sizeof(value);
625     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
626     EXPECT_EQ(rc, 0);
627
628     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
629     EXPECT_EQ(rc, 0);
630     EXPECT_EQ(value, 0);
631
632     SSLServerAcceptCallback::connAccepted(sock);
633   }
634 };
635
636 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
637  public:
638   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
639                                              uint32_t timeout = 0):
640     SSLServerAcceptCallback(hcb, timeout) {}
641
642   void connAccepted(
643     const std::shared_ptr<folly::AsyncSSLSocket> &s)
644     noexcept override {
645     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
646
647     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
648
649     hcb_->setSocket(sock);
650     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
651     ASSERT_TRUE((sock->getSSLState() ==
652                  AsyncSSLSocket::STATE_ACCEPTING) ||
653                 (sock->getSSLState() ==
654                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
655
656     state = STATE_SUCCEEDED;
657   }
658 };
659
660
661 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
662  public:
663   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
664   SSLServerAcceptCallbackBase(hcb)  {}
665
666   void connAccepted(
667     const std::shared_ptr<folly::AsyncSSLSocket> &s)
668     noexcept override {
669     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
670
671     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
672
673     // The first call to sslAccept() should succeed.
674     hcb_->setSocket(sock);
675     sock->sslAccept(hcb_);
676     EXPECT_EQ(sock->getSSLState(),
677                       AsyncSSLSocket::STATE_ACCEPTING);
678
679     // The second call to sslAccept() should fail.
680     HandshakeCallback callback2(hcb_->rcb_);
681     callback2.setSocket(sock);
682     sock->sslAccept(&callback2);
683     EXPECT_EQ(sock->getSSLState(),
684                       AsyncSSLSocket::STATE_ERROR);
685
686     // Both callbacks should be in the error state.
687     EXPECT_EQ(hcb_->state, STATE_FAILED);
688     EXPECT_EQ(callback2.state, STATE_FAILED);
689
690     state = STATE_SUCCEEDED;
691     hcb_->setState(STATE_SUCCEEDED);
692     callback2.setState(STATE_SUCCEEDED);
693   }
694 };
695
696 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
697  public:
698   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
699   SSLServerAcceptCallbackBase(hcb)  {}
700
701   void connAccepted(
702     const std::shared_ptr<folly::AsyncSSLSocket> &s)
703     noexcept override {
704     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
705
706     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
707
708     hcb_->setSocket(sock);
709     sock->getEventBase()->tryRunAfterDelay([=] {
710         std::cerr << "Delayed SSL accept, client will have close by now"
711                   << std::endl;
712         // SSL accept will fail
713         EXPECT_EQ(
714           sock->getSSLState(),
715           AsyncSSLSocket::STATE_UNINIT);
716         hcb_->socket_->sslAccept(hcb_);
717         // This registers for an event
718         EXPECT_EQ(
719           sock->getSSLState(),
720           AsyncSSLSocket::STATE_ACCEPTING);
721
722         state = STATE_SUCCEEDED;
723       }, 100);
724   }
725 };
726
727 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
728  public:
729   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
730     // We don't care if we get invoked or not.
731     // The client may time out and give up before connAccepted() is even
732     // called.
733     state = STATE_SUCCEEDED;
734   }
735
736   void connAccepted(
737       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
738     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
739
740     // Just wait a while before closing the socket, so the client
741     // will time out waiting for the handshake to complete.
742     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
743   }
744 };
745
746 class TestSSLAsyncCacheServer : public TestSSLServer {
747  public:
748   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
749         int lookupDelay = 100) :
750       TestSSLServer(acb) {
751     SSL_CTX *sslCtx = ctx_->getSSLCtx();
752 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
753     SSL_CTX_sess_set_get_cb(sslCtx,
754                             TestSSLAsyncCacheServer::getSessionCallback);
755 #endif
756     SSL_CTX_set_session_cache_mode(
757       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
758     asyncCallbacks_ = 0;
759     asyncLookups_ = 0;
760     lookupDelay_ = lookupDelay;
761   }
762
763   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
764   uint32_t getAsyncLookups() const { return asyncLookups_; }
765
766  private:
767   static uint32_t asyncCallbacks_;
768   static uint32_t asyncLookups_;
769   static uint32_t lookupDelay_;
770
771   static SSL_SESSION* getSessionCallback(SSL* ssl,
772                                          unsigned char* /* sess_id */,
773                                          int /* id_len */,
774                                          int* copyflag) {
775     *copyflag = 0;
776     asyncCallbacks_++;
777     (void)ssl;
778 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
779     if (!SSL_want_sess_cache_lookup(ssl)) {
780       // libssl.so mismatch
781       std::cerr << "no async support" << std::endl;
782       return nullptr;
783     }
784
785     AsyncSSLSocket *sslSocket =
786         AsyncSSLSocket::getFromSSL(ssl);
787     assert(sslSocket != nullptr);
788     // Going to simulate an async cache by just running delaying the miss 100ms
789     if (asyncCallbacks_ % 2 == 0) {
790       // This socket is already blocked on lookup, return miss
791       std::cerr << "returning miss" << std::endl;
792     } else {
793       // fresh meat - block it
794       std::cerr << "async lookup" << std::endl;
795       sslSocket->getEventBase()->tryRunAfterDelay(
796         std::bind(&AsyncSSLSocket::restartSSLAccept,
797                   sslSocket), lookupDelay_);
798       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
799       asyncLookups_++;
800     }
801 #endif
802     return nullptr;
803   }
804 };
805
806 void getfds(int fds[2]);
807
808 void getctx(
809   std::shared_ptr<folly::SSLContext> clientCtx,
810   std::shared_ptr<folly::SSLContext> serverCtx);
811
812 void sslsocketpair(
813   EventBase* eventBase,
814   AsyncSSLSocket::UniquePtr* clientSock,
815   AsyncSSLSocket::UniquePtr* serverSock);
816
817 class BlockingWriteClient :
818   private AsyncSSLSocket::HandshakeCB,
819   private AsyncTransportWrapper::WriteCallback {
820  public:
821   explicit BlockingWriteClient(
822     AsyncSSLSocket::UniquePtr socket)
823     : socket_(std::move(socket)),
824       bufLen_(2500),
825       iovCount_(2000) {
826     // Fill buf_
827     buf_.reset(new uint8_t[bufLen_]);
828     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
829       buf_[n] = n % 0xff;
830     }
831
832     // Initialize iov_
833     iov_.reset(new struct iovec[iovCount_]);
834     for (uint32_t n = 0; n < iovCount_; ++n) {
835       iov_[n].iov_base = buf_.get() + n;
836       if (n & 0x1) {
837         iov_[n].iov_len = n % bufLen_;
838       } else {
839         iov_[n].iov_len = bufLen_ - (n % bufLen_);
840       }
841     }
842
843     socket_->sslConn(this, std::chrono::milliseconds(100));
844   }
845
846   struct iovec* getIovec() const {
847     return iov_.get();
848   }
849   uint32_t getIovecCount() const {
850     return iovCount_;
851   }
852
853  private:
854   void handshakeSuc(AsyncSSLSocket*) noexcept override {
855     socket_->writev(this, iov_.get(), iovCount_);
856   }
857   void handshakeErr(
858     AsyncSSLSocket*,
859     const AsyncSocketException& ex) noexcept override {
860     ADD_FAILURE() << "client handshake error: " << ex.what();
861   }
862   void writeSuccess() noexcept override {
863     socket_->close();
864   }
865   void writeErr(
866     size_t bytesWritten,
867     const AsyncSocketException& ex) noexcept override {
868     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
869                   << ex.what();
870   }
871
872   AsyncSSLSocket::UniquePtr socket_;
873   uint32_t bufLen_;
874   uint32_t iovCount_;
875   std::unique_ptr<uint8_t[]> buf_;
876   std::unique_ptr<struct iovec[]> iov_;
877 };
878
879 class BlockingWriteServer :
880     private AsyncSSLSocket::HandshakeCB,
881     private AsyncTransportWrapper::ReadCallback {
882  public:
883   explicit BlockingWriteServer(
884     AsyncSSLSocket::UniquePtr socket)
885     : socket_(std::move(socket)),
886       bufSize_(2500 * 2000),
887       bytesRead_(0) {
888     buf_.reset(new uint8_t[bufSize_]);
889     socket_->sslAccept(this, std::chrono::milliseconds(100));
890   }
891
892   void checkBuffer(struct iovec* iov, uint32_t count) const {
893     uint32_t idx = 0;
894     for (uint32_t n = 0; n < count; ++n) {
895       size_t bytesLeft = bytesRead_ - idx;
896       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
897                       std::min(iov[n].iov_len, bytesLeft));
898       if (rc != 0) {
899         FAIL() << "buffer mismatch at iovec " << n << "/" << count
900                << ": rc=" << rc;
901
902       }
903       if (iov[n].iov_len > bytesLeft) {
904         FAIL() << "server did not read enough data: "
905                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
906                << " in iovec " << n << "/" << count;
907       }
908
909       idx += iov[n].iov_len;
910     }
911     if (idx != bytesRead_) {
912       ADD_FAILURE() << "server read extra data: " << bytesRead_
913                     << " bytes read; expected " << idx;
914     }
915   }
916
917  private:
918   void handshakeSuc(AsyncSSLSocket*) noexcept override {
919     // Wait 10ms before reading, so the client's writes will initially block.
920     socket_->getEventBase()->tryRunAfterDelay(
921         [this] { socket_->setReadCB(this); }, 10);
922   }
923   void handshakeErr(
924     AsyncSSLSocket*,
925     const AsyncSocketException& ex) noexcept override {
926     ADD_FAILURE() << "server handshake error: " << ex.what();
927   }
928   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
929     *bufReturn = buf_.get() + bytesRead_;
930     *lenReturn = bufSize_ - bytesRead_;
931   }
932   void readDataAvailable(size_t len) noexcept override {
933     bytesRead_ += len;
934     socket_->setReadCB(nullptr);
935     socket_->getEventBase()->tryRunAfterDelay(
936         [this] { socket_->setReadCB(this); }, 2);
937   }
938   void readEOF() noexcept override {
939     socket_->close();
940   }
941   void readErr(
942     const AsyncSocketException& ex) noexcept override {
943     ADD_FAILURE() << "server read error: " << ex.what();
944   }
945
946   AsyncSSLSocket::UniquePtr socket_;
947   uint32_t bufSize_;
948   uint32_t bytesRead_;
949   std::unique_ptr<uint8_t[]> buf_;
950 };
951
952 class NpnClient :
953   private AsyncSSLSocket::HandshakeCB,
954   private AsyncTransportWrapper::WriteCallback {
955  public:
956   explicit NpnClient(
957     AsyncSSLSocket::UniquePtr socket)
958       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
959     socket_->sslConn(this);
960   }
961
962   const unsigned char* nextProto;
963   unsigned nextProtoLength;
964   SSLContext::NextProtocolType protocolType;
965   folly::Optional<AsyncSocketException> except;
966
967  private:
968   void handshakeSuc(AsyncSSLSocket*) noexcept override {
969     socket_->getSelectedNextProtocol(
970         &nextProto, &nextProtoLength, &protocolType);
971   }
972   void handshakeErr(
973     AsyncSSLSocket*,
974     const AsyncSocketException& ex) noexcept override {
975     except = ex;
976   }
977   void writeSuccess() noexcept override {
978     socket_->close();
979   }
980   void writeErr(
981     size_t bytesWritten,
982     const AsyncSocketException& ex) noexcept override {
983     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
984                   << ex.what();
985   }
986
987   AsyncSSLSocket::UniquePtr socket_;
988 };
989
990 class NpnServer :
991     private AsyncSSLSocket::HandshakeCB,
992     private AsyncTransportWrapper::ReadCallback {
993  public:
994   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
995       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
996     socket_->sslAccept(this);
997   }
998
999   const unsigned char* nextProto;
1000   unsigned nextProtoLength;
1001   SSLContext::NextProtocolType protocolType;
1002   folly::Optional<AsyncSocketException> except;
1003
1004  private:
1005   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1006     socket_->getSelectedNextProtocol(
1007         &nextProto, &nextProtoLength, &protocolType);
1008   }
1009   void handshakeErr(
1010     AsyncSSLSocket*,
1011     const AsyncSocketException& ex) noexcept override {
1012     except = ex;
1013   }
1014   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1015     *lenReturn = 0;
1016   }
1017   void readDataAvailable(size_t /* len */) noexcept override {}
1018   void readEOF() noexcept override {
1019     socket_->close();
1020   }
1021   void readErr(
1022     const AsyncSocketException& ex) noexcept override {
1023     ADD_FAILURE() << "server read error: " << ex.what();
1024   }
1025
1026   AsyncSSLSocket::UniquePtr socket_;
1027 };
1028
1029 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1030                             public AsyncTransportWrapper::ReadCallback {
1031  public:
1032   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1033       : socket_(std::move(socket)) {
1034     socket_->sslAccept(this);
1035   }
1036
1037   ~RenegotiatingServer() override {
1038     socket_->setReadCB(nullptr);
1039   }
1040
1041   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1042     LOG(INFO) << "Renegotiating server handshake success";
1043     socket_->setReadCB(this);
1044   }
1045   void handshakeErr(
1046       AsyncSSLSocket*,
1047       const AsyncSocketException& ex) noexcept override {
1048     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1049   }
1050   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1051     *lenReturn = sizeof(buf);
1052     *bufReturn = buf;
1053   }
1054   void readDataAvailable(size_t /* len */) noexcept override {}
1055   void readEOF() noexcept override {}
1056   void readErr(const AsyncSocketException& ex) noexcept override {
1057     LOG(INFO) << "server got read error " << ex.what();
1058     auto exPtr = dynamic_cast<const SSLException*>(&ex);
1059     ASSERT_NE(nullptr, exPtr);
1060     std::string exStr(ex.what());
1061     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1062     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1063     renegotiationError_ = true;
1064   }
1065
1066   AsyncSSLSocket::UniquePtr socket_;
1067   unsigned char buf[128];
1068   bool renegotiationError_{false};
1069 };
1070
1071 #ifndef OPENSSL_NO_TLSEXT
1072 class SNIClient :
1073   private AsyncSSLSocket::HandshakeCB,
1074   private AsyncTransportWrapper::WriteCallback {
1075  public:
1076   explicit SNIClient(
1077     AsyncSSLSocket::UniquePtr socket)
1078       : serverNameMatch(false), socket_(std::move(socket)) {
1079     socket_->sslConn(this);
1080   }
1081
1082   bool serverNameMatch;
1083
1084  private:
1085   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1086     serverNameMatch = socket_->isServerNameMatch();
1087   }
1088   void handshakeErr(
1089     AsyncSSLSocket*,
1090     const AsyncSocketException& ex) noexcept override {
1091     ADD_FAILURE() << "client handshake error: " << ex.what();
1092   }
1093   void writeSuccess() noexcept override {
1094     socket_->close();
1095   }
1096   void writeErr(
1097     size_t bytesWritten,
1098     const AsyncSocketException& ex) noexcept override {
1099     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1100                   << ex.what();
1101   }
1102
1103   AsyncSSLSocket::UniquePtr socket_;
1104 };
1105
1106 class SNIServer :
1107     private AsyncSSLSocket::HandshakeCB,
1108     private AsyncTransportWrapper::ReadCallback {
1109  public:
1110   explicit SNIServer(
1111     AsyncSSLSocket::UniquePtr socket,
1112     const std::shared_ptr<folly::SSLContext>& ctx,
1113     const std::shared_ptr<folly::SSLContext>& sniCtx,
1114     const std::string& expectedServerName)
1115       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1116         expectedServerName_(expectedServerName) {
1117     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1118                                          std::placeholders::_1));
1119     socket_->sslAccept(this);
1120   }
1121
1122   bool serverNameMatch;
1123
1124  private:
1125   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1126   void handshakeErr(
1127     AsyncSSLSocket*,
1128     const AsyncSocketException& ex) noexcept override {
1129     ADD_FAILURE() << "server handshake error: " << ex.what();
1130   }
1131   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1132     *lenReturn = 0;
1133   }
1134   void readDataAvailable(size_t /* len */) noexcept override {}
1135   void readEOF() noexcept override {
1136     socket_->close();
1137   }
1138   void readErr(
1139     const AsyncSocketException& ex) noexcept override {
1140     ADD_FAILURE() << "server read error: " << ex.what();
1141   }
1142
1143   folly::SSLContext::ServerNameCallbackResult
1144     serverNameCallback(SSL *ssl) {
1145     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1146     if (sniCtx_ &&
1147         sn &&
1148         !strcasecmp(expectedServerName_.c_str(), sn)) {
1149       AsyncSSLSocket *sslSocket =
1150           AsyncSSLSocket::getFromSSL(ssl);
1151       sslSocket->switchServerSSLContext(sniCtx_);
1152       serverNameMatch = true;
1153       return folly::SSLContext::SERVER_NAME_FOUND;
1154     } else {
1155       serverNameMatch = false;
1156       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1157     }
1158   }
1159
1160   AsyncSSLSocket::UniquePtr socket_;
1161   std::shared_ptr<folly::SSLContext> sniCtx_;
1162   std::string expectedServerName_;
1163 };
1164 #endif
1165
1166 class SSLClient : public AsyncSocket::ConnectCallback,
1167                   public AsyncTransportWrapper::WriteCallback,
1168                   public AsyncTransportWrapper::ReadCallback
1169 {
1170  private:
1171   EventBase *eventBase_;
1172   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1173   SSL_SESSION *session_;
1174   std::shared_ptr<folly::SSLContext> ctx_;
1175   uint32_t requests_;
1176   folly::SocketAddress address_;
1177   uint32_t timeout_;
1178   char buf_[128];
1179   char readbuf_[128];
1180   uint32_t bytesRead_;
1181   uint32_t hit_;
1182   uint32_t miss_;
1183   uint32_t errors_;
1184   uint32_t writeAfterConnectErrors_;
1185
1186   // These settings test that we eventually drain the
1187   // socket, even if the maxReadsPerEvent_ is hit during
1188   // a event loop iteration.
1189   static constexpr size_t kMaxReadsPerEvent = 2;
1190   // 2 event loop iterations
1191   static constexpr size_t kMaxReadBufferSz =
1192     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1193
1194  public:
1195   SSLClient(EventBase *eventBase,
1196             const folly::SocketAddress& address,
1197             uint32_t requests,
1198             uint32_t timeout = 0)
1199       : eventBase_(eventBase),
1200         session_(nullptr),
1201         requests_(requests),
1202         address_(address),
1203         timeout_(timeout),
1204         bytesRead_(0),
1205         hit_(0),
1206         miss_(0),
1207         errors_(0),
1208         writeAfterConnectErrors_(0) {
1209     ctx_.reset(new folly::SSLContext());
1210     ctx_->setOptions(SSL_OP_NO_TICKET);
1211     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1212     memset(buf_, 'a', sizeof(buf_));
1213   }
1214
1215   ~SSLClient() override {
1216     if (session_) {
1217       SSL_SESSION_free(session_);
1218     }
1219     if (errors_ == 0) {
1220       EXPECT_EQ(bytesRead_, sizeof(buf_));
1221     }
1222   }
1223
1224   uint32_t getHit() const { return hit_; }
1225
1226   uint32_t getMiss() const { return miss_; }
1227
1228   uint32_t getErrors() const { return errors_; }
1229
1230   uint32_t getWriteAfterConnectErrors() const {
1231     return writeAfterConnectErrors_;
1232   }
1233
1234   void connect(bool writeNow = false) {
1235     sslSocket_ = AsyncSSLSocket::newSocket(
1236       ctx_, eventBase_);
1237     if (session_ != nullptr) {
1238       sslSocket_->setSSLSession(session_);
1239     }
1240     requests_--;
1241     sslSocket_->connect(this, address_, timeout_);
1242     if (sslSocket_ && writeNow) {
1243       // write some junk, used in an error test
1244       sslSocket_->write(this, buf_, sizeof(buf_));
1245     }
1246   }
1247
1248   void connectSuccess() noexcept override {
1249     std::cerr << "client SSL socket connected" << std::endl;
1250     if (sslSocket_->getSSLSessionReused()) {
1251       hit_++;
1252     } else {
1253       miss_++;
1254       if (session_ != nullptr) {
1255         SSL_SESSION_free(session_);
1256       }
1257       session_ = sslSocket_->getSSLSession();
1258     }
1259
1260     // write()
1261     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1262     sslSocket_->write(this, buf_, sizeof(buf_));
1263     sslSocket_->setReadCB(this);
1264     memset(readbuf_, 'b', sizeof(readbuf_));
1265     bytesRead_ = 0;
1266   }
1267
1268   void connectErr(
1269     const AsyncSocketException& ex) noexcept override {
1270     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1271     errors_++;
1272     sslSocket_.reset();
1273   }
1274
1275   void writeSuccess() noexcept override {
1276     std::cerr << "client write success" << std::endl;
1277   }
1278
1279   void writeErr(size_t /* bytesWritten */,
1280                 const AsyncSocketException& ex) noexcept override {
1281     std::cerr << "client writeError: " << ex.what() << std::endl;
1282     if (!sslSocket_) {
1283       writeAfterConnectErrors_++;
1284     }
1285   }
1286
1287   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1288     *bufReturn = readbuf_ + bytesRead_;
1289     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1290   }
1291
1292   void readEOF() noexcept override {
1293     std::cerr << "client readEOF" << std::endl;
1294   }
1295
1296   void readErr(
1297     const AsyncSocketException& ex) noexcept override {
1298     std::cerr << "client readError: " << ex.what() << std::endl;
1299   }
1300
1301   void readDataAvailable(size_t len) noexcept override {
1302     std::cerr << "client read data: " << len << std::endl;
1303     bytesRead_ += len;
1304     if (bytesRead_ == sizeof(buf_)) {
1305       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1306       sslSocket_->closeNow();
1307       sslSocket_.reset();
1308       if (requests_ != 0) {
1309         connect();
1310       }
1311     }
1312   }
1313
1314 };
1315
1316 class SSLHandshakeBase :
1317   public AsyncSSLSocket::HandshakeCB,
1318   private AsyncTransportWrapper::WriteCallback {
1319  public:
1320   explicit SSLHandshakeBase(
1321    AsyncSSLSocket::UniquePtr socket,
1322    bool preverifyResult,
1323    bool verifyResult) :
1324     handshakeVerify_(false),
1325     handshakeSuccess_(false),
1326     handshakeError_(false),
1327     socket_(std::move(socket)),
1328     preverifyResult_(preverifyResult),
1329     verifyResult_(verifyResult) {
1330   }
1331
1332   AsyncSSLSocket::UniquePtr moveSocket() && {
1333     return std::move(socket_);
1334   }
1335
1336   bool handshakeVerify_;
1337   bool handshakeSuccess_;
1338   bool handshakeError_;
1339   std::chrono::nanoseconds handshakeTime;
1340
1341  protected:
1342   AsyncSSLSocket::UniquePtr socket_;
1343   bool preverifyResult_;
1344   bool verifyResult_;
1345
1346   // HandshakeCallback
1347   bool handshakeVer(AsyncSSLSocket* /* sock */,
1348                     bool preverifyOk,
1349                     X509_STORE_CTX* /* ctx */) noexcept override {
1350     handshakeVerify_ = true;
1351
1352     EXPECT_EQ(preverifyResult_, preverifyOk);
1353     return verifyResult_;
1354   }
1355
1356   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1357     LOG(INFO) << "Handshake success";
1358     handshakeSuccess_ = true;
1359     handshakeTime = socket_->getHandshakeTime();
1360   }
1361
1362   void handshakeErr(
1363       AsyncSSLSocket*,
1364       const AsyncSocketException& ex) noexcept override {
1365     LOG(INFO) << "Handshake error " << ex.what();
1366     handshakeError_ = true;
1367     handshakeTime = socket_->getHandshakeTime();
1368   }
1369
1370   // WriteCallback
1371   void writeSuccess() noexcept override {
1372     socket_->close();
1373   }
1374
1375   void writeErr(
1376    size_t bytesWritten,
1377    const AsyncSocketException& ex) noexcept override {
1378     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1379                   << ex.what();
1380   }
1381 };
1382
1383 class SSLHandshakeClient : public SSLHandshakeBase {
1384  public:
1385   SSLHandshakeClient(
1386    AsyncSSLSocket::UniquePtr socket,
1387    bool preverifyResult,
1388    bool verifyResult) :
1389     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1390     socket_->sslConn(this, std::chrono::milliseconds::zero());
1391   }
1392 };
1393
1394 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1395  public:
1396   SSLHandshakeClientNoVerify(
1397    AsyncSSLSocket::UniquePtr socket,
1398    bool preverifyResult,
1399    bool verifyResult) :
1400     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1401     socket_->sslConn(
1402         this,
1403         std::chrono::milliseconds::zero(),
1404         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1405   }
1406 };
1407
1408 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1409  public:
1410   SSLHandshakeClientDoVerify(
1411    AsyncSSLSocket::UniquePtr socket,
1412    bool preverifyResult,
1413    bool verifyResult) :
1414     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1415     socket_->sslConn(
1416         this,
1417         std::chrono::milliseconds::zero(),
1418         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1419   }
1420 };
1421
1422 class SSLHandshakeServer : public SSLHandshakeBase {
1423  public:
1424   SSLHandshakeServer(
1425       AsyncSSLSocket::UniquePtr socket,
1426       bool preverifyResult,
1427       bool verifyResult)
1428     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1429     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1430   }
1431 };
1432
1433 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1434  public:
1435   SSLHandshakeServerParseClientHello(
1436       AsyncSSLSocket::UniquePtr socket,
1437       bool preverifyResult,
1438       bool verifyResult)
1439       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1440     socket_->enableClientHelloParsing();
1441     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1442   }
1443
1444   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1445
1446  protected:
1447   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1448     handshakeSuccess_ = true;
1449     sock->getSSLSharedCiphers(sharedCiphers_);
1450     sock->getSSLServerCiphers(serverCiphers_);
1451     sock->getSSLClientCiphers(clientCiphers_);
1452     chosenCipher_ = sock->getNegotiatedCipherName();
1453   }
1454 };
1455
1456
1457 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1458  public:
1459   SSLHandshakeServerNoVerify(
1460       AsyncSSLSocket::UniquePtr socket,
1461       bool preverifyResult,
1462       bool verifyResult)
1463     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1464     socket_->sslAccept(
1465         this,
1466         std::chrono::milliseconds::zero(),
1467         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1468   }
1469 };
1470
1471 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1472  public:
1473   SSLHandshakeServerDoVerify(
1474       AsyncSSLSocket::UniquePtr socket,
1475       bool preverifyResult,
1476       bool verifyResult)
1477     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1478     socket_->sslAccept(
1479         this,
1480         std::chrono::milliseconds::zero(),
1481         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1482   }
1483 };
1484
1485 class EventBaseAborter : public AsyncTimeout {
1486  public:
1487   EventBaseAborter(EventBase* eventBase,
1488                    uint32_t timeoutMS)
1489     : AsyncTimeout(
1490       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1491     , eventBase_(eventBase) {
1492     scheduleTimeout(timeoutMS);
1493   }
1494
1495   void timeoutExpired() noexcept override {
1496     FAIL() << "test timed out";
1497     eventBase_->terminateLoopSoon();
1498   }
1499
1500  private:
1501   EventBase* eventBase_;
1502 };
1503
1504 }