Implementing a callback interface for folly::AsyncSocket allowing to supply an ancill...
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <condition_variable>
39 #include <iostream>
40 #include <list>
41
42 namespace folly {
43
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
48
49 class SendMsgParamsCallbackBase :
50       public folly::AsyncSocket::SendMsgParamsCallback {
51  public:
52   SendMsgParamsCallbackBase() {}
53
54   void setSocket(
55     const std::shared_ptr<AsyncSSLSocket> &socket) {
56     socket_ = socket;
57     oldCallback_ = socket_->getSendMsgParamsCB();
58     socket_->setSendMsgParamCB(this);
59   }
60
61   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
62                                                                      override {
63     return oldCallback_->getFlags(flags);
64   }
65
66   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
67     oldCallback_->getAncillaryData(flags, data);
68   }
69
70   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
71     return oldCallback_->getAncillaryDataSize(flags);
72   }
73
74   std::shared_ptr<AsyncSSLSocket> socket_;
75   folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
76 };
77
78 class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
79  public:
80   SendMsgFlagsCallback() {}
81
82   void resetFlags(int flags) {
83     flags_ = flags;
84   }
85
86   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
87                                                                   override {
88     if (flags_) {
89       return flags_;
90     } else {
91       return oldCallback_->getFlags(flags);
92     }
93   }
94
95   int flags_{0};
96 };
97
98 class SendMsgDataCallback : public SendMsgFlagsCallback {
99  public:
100   SendMsgDataCallback() {}
101
102   void resetData(std::vector<char>&& data) {
103     ancillaryData_.swap(data);
104   }
105
106   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
107     if (ancillaryData_.size()) {
108       std::cerr << "getAncillaryData: copying data" << std::endl;
109       memcpy(data, ancillaryData_.data(), ancillaryData_.size());
110     } else {
111       oldCallback_->getAncillaryData(flags, data);
112     }
113   }
114
115   uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
116     if (ancillaryData_.size()) {
117       std::cerr << "getAncillaryDataSize: returning size" << std::endl;
118       return ancillaryData_.size();
119     } else {
120       return oldCallback_->getAncillaryDataSize(flags);
121     }
122   }
123
124   std::vector<char> ancillaryData_;
125 };
126
127 class WriteCallbackBase :
128 public AsyncTransportWrapper::WriteCallback {
129 public:
130   explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
131       : state(STATE_WAITING)
132       , bytesWritten(0)
133       , exception(AsyncSocketException::UNKNOWN, "none")
134       , mcb_(mcb) {}
135
136   ~WriteCallbackBase() {
137     EXPECT_EQ(STATE_SUCCEEDED, state);
138   }
139
140   virtual void setSocket(
141     const std::shared_ptr<AsyncSSLSocket> &socket) {
142     socket_ = socket;
143     if (mcb_) {
144       mcb_->setSocket(socket);
145     }
146   }
147
148   virtual void writeSuccess() noexcept override {
149     std::cerr << "writeSuccess" << std::endl;
150     state = STATE_SUCCEEDED;
151   }
152
153   void writeErr(
154     size_t nBytesWritten,
155     const AsyncSocketException& ex) noexcept override {
156     std::cerr << "writeError: bytesWritten " << nBytesWritten
157          << ", exception " << ex.what() << std::endl;
158
159     state = STATE_FAILED;
160     this->bytesWritten = nBytesWritten;
161     exception = ex;
162     socket_->close();
163   }
164
165   std::shared_ptr<AsyncSSLSocket> socket_;
166   StateEnum state;
167   size_t bytesWritten;
168   AsyncSocketException exception;
169   SendMsgParamsCallbackBase* mcb_;
170 };
171
172 class ExpectWriteErrorCallback :
173 public WriteCallbackBase {
174 public:
175   explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
176       : WriteCallbackBase(mcb) {}
177
178   ~ExpectWriteErrorCallback() {
179     EXPECT_EQ(STATE_FAILED, state);
180     EXPECT_EQ(exception.type_,
181              AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
182     EXPECT_EQ(exception.errno_, 22);
183     // Suppress the assert in  ~WriteCallbackBase()
184     state = STATE_SUCCEEDED;
185   }
186 };
187
188 #ifdef MSG_ERRQUEUE
189 /* copied from include/uapi/linux/net_tstamp.h */
190 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
191 enum SOF_TIMESTAMPING {
192   SOF_TIMESTAMPING_TX_SOFTWARE = (1 << 1),
193   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
194   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
195   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
196   SOF_TIMESTAMPING_TX_ACK = (1 << 9),
197   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
198 };
199
200 class WriteCheckTimestampCallback :
201  public WriteCallbackBase {
202 public:
203   explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
204     : WriteCallbackBase(mcb) {}
205
206   ~WriteCheckTimestampCallback() {
207     EXPECT_EQ(STATE_SUCCEEDED, state);
208     EXPECT_TRUE(gotTimestamp_);
209     EXPECT_TRUE(gotByteSeq_);
210   }
211
212   void setSocket(
213     const std::shared_ptr<AsyncSSLSocket> &socket) override {
214     WriteCallbackBase::setSocket(socket);
215
216     EXPECT_NE(socket_->getFd(), 0);
217     int flags = SOF_TIMESTAMPING_OPT_ID
218                 | SOF_TIMESTAMPING_OPT_TSONLY
219                 | SOF_TIMESTAMPING_SOFTWARE;
220     AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
221     int ret = tstampingOpt.apply(socket_->getFd(), flags);
222     EXPECT_EQ(ret, 0);
223   }
224
225   void checkForTimestampNotifications() noexcept {
226     int fd = socket_->getFd();
227     std::vector<char> ctrl(1024, 0);
228     unsigned char data;
229     struct msghdr msg;
230     iovec entry;
231
232     memset(&msg, 0, sizeof(msg));
233     entry.iov_base = &data;
234     entry.iov_len = sizeof(data);
235     msg.msg_iov = &entry;
236     msg.msg_iovlen = 1;
237     msg.msg_control = ctrl.data();
238     msg.msg_controllen = ctrl.size();
239
240     int ret;
241     while (true) {
242       ret = recvmsg(fd, &msg, MSG_ERRQUEUE);
243       if (ret < 0) {
244         if (errno != EAGAIN) {
245           auto errnoCopy = errno;
246           std::cerr << "::recvmsg exited with code " << ret
247                     << ", errno: " << errnoCopy << std::endl;
248           AsyncSocketException ex(
249             AsyncSocketException::INTERNAL_ERROR,
250             "recvmsg() failed",
251             errnoCopy);
252           exception = ex;
253         }
254         return;
255       }
256
257       for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
258            cmsg != nullptr && cmsg->cmsg_len != 0;
259            cmsg = CMSG_NXTHDR(&msg, cmsg)) {
260         if (cmsg->cmsg_level == SOL_SOCKET &&
261             cmsg->cmsg_type == SCM_TIMESTAMPING) {
262           gotTimestamp_ = true;
263           continue;
264         }
265
266         if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
267             (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
268           gotByteSeq_ = true;
269           continue;
270         }
271       }
272     }
273   }
274
275   bool gotTimestamp_{false};
276   bool gotByteSeq_{false};
277 };
278 #endif // MSG_ERRQUEUE
279
280 class ReadCallbackBase :
281 public AsyncTransportWrapper::ReadCallback {
282  public:
283   explicit ReadCallbackBase(WriteCallbackBase* wcb)
284       : wcb_(wcb), state(STATE_WAITING) {}
285
286   ~ReadCallbackBase() {
287     EXPECT_EQ(STATE_SUCCEEDED, state);
288   }
289
290   void setSocket(
291     const std::shared_ptr<AsyncSSLSocket> &socket) {
292     socket_ = socket;
293   }
294
295   void setState(StateEnum s) {
296     state = s;
297     if (wcb_) {
298       wcb_->state = s;
299     }
300   }
301
302   void readErr(
303     const AsyncSocketException& ex) noexcept override {
304     std::cerr << "readError " << ex.what() << std::endl;
305     state = STATE_FAILED;
306     socket_->close();
307   }
308
309   void readEOF() noexcept override {
310     std::cerr << "readEOF" << std::endl;
311
312     socket_->close();
313   }
314
315   std::shared_ptr<AsyncSSLSocket> socket_;
316   WriteCallbackBase *wcb_;
317   StateEnum state;
318 };
319
320 class ReadCallback : public ReadCallbackBase {
321 public:
322   explicit ReadCallback(WriteCallbackBase *wcb)
323       : ReadCallbackBase(wcb)
324       , buffers() {}
325
326   ~ReadCallback() {
327     for (std::vector<Buffer>::iterator it = buffers.begin();
328          it != buffers.end();
329          ++it) {
330       it->free();
331     }
332     currentBuffer.free();
333   }
334
335   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
336     if (!currentBuffer.buffer) {
337       currentBuffer.allocate(4096);
338     }
339     *bufReturn = currentBuffer.buffer;
340     *lenReturn = currentBuffer.length;
341   }
342
343   void readDataAvailable(size_t len) noexcept override {
344     std::cerr << "readDataAvailable, len " << len << std::endl;
345
346     currentBuffer.length = len;
347
348     wcb_->setSocket(socket_);
349
350     // Write back the same data.
351     socket_->write(wcb_, currentBuffer.buffer, len);
352
353     buffers.push_back(currentBuffer);
354     currentBuffer.reset();
355     state = STATE_SUCCEEDED;
356   }
357
358   class Buffer {
359   public:
360     Buffer() : buffer(nullptr), length(0) {}
361     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
362
363     void reset() {
364       buffer = nullptr;
365       length = 0;
366     }
367     void allocate(size_t len) {
368       assert(buffer == nullptr);
369       this->buffer = static_cast<char*>(malloc(len));
370       this->length = len;
371     }
372     void free() {
373       ::free(buffer);
374       reset();
375     }
376
377     char* buffer;
378     size_t length;
379   };
380
381   std::vector<Buffer> buffers;
382   Buffer currentBuffer;
383 };
384
385 class ReadErrorCallback : public ReadCallbackBase {
386 public:
387   explicit ReadErrorCallback(WriteCallbackBase *wcb)
388       : ReadCallbackBase(wcb) {}
389
390   // Return nullptr buffer to trigger readError()
391   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
392     *bufReturn = nullptr;
393     *lenReturn = 0;
394   }
395
396   void readDataAvailable(size_t /* len */) noexcept override {
397     // This should never to called.
398     FAIL();
399   }
400
401   void readErr(
402     const AsyncSocketException& ex) noexcept override {
403     ReadCallbackBase::readErr(ex);
404     std::cerr << "ReadErrorCallback::readError" << std::endl;
405     setState(STATE_SUCCEEDED);
406   }
407 };
408
409 class ReadEOFCallback : public ReadCallbackBase {
410  public:
411   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
412
413   // Return nullptr buffer to trigger readError()
414   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
415     *bufReturn = nullptr;
416     *lenReturn = 0;
417   }
418
419   void readDataAvailable(size_t /* len */) noexcept override {
420     // This should never to called.
421     FAIL();
422   }
423
424   void readEOF() noexcept override {
425     ReadCallbackBase::readEOF();
426     setState(STATE_SUCCEEDED);
427   }
428 };
429
430 class WriteErrorCallback : public ReadCallback {
431 public:
432   explicit WriteErrorCallback(WriteCallbackBase *wcb)
433       : ReadCallback(wcb) {}
434
435   void readDataAvailable(size_t len) noexcept override {
436     std::cerr << "readDataAvailable, len " << len << std::endl;
437
438     currentBuffer.length = len;
439
440     // close the socket before writing to trigger writeError().
441     ::close(socket_->getFd());
442
443     wcb_->setSocket(socket_);
444
445     // Write back the same data.
446     folly::test::msvcSuppressAbortOnInvalidParams([&] {
447       socket_->write(wcb_, currentBuffer.buffer, len);
448     });
449
450     if (wcb_->state == STATE_FAILED) {
451       setState(STATE_SUCCEEDED);
452     } else {
453       state = STATE_FAILED;
454     }
455
456     buffers.push_back(currentBuffer);
457     currentBuffer.reset();
458   }
459
460   void readErr(const AsyncSocketException& ex) noexcept override {
461     std::cerr << "readError " << ex.what() << std::endl;
462     // do nothing since this is expected
463   }
464 };
465
466 class EmptyReadCallback : public ReadCallback {
467 public:
468   explicit EmptyReadCallback()
469       : ReadCallback(nullptr) {}
470
471   void readErr(const AsyncSocketException& ex) noexcept override {
472     std::cerr << "readError " << ex.what() << std::endl;
473     state = STATE_FAILED;
474     if (tcpSocket_) {
475       tcpSocket_->close();
476     }
477   }
478
479   void readEOF() noexcept override {
480     std::cerr << "readEOF" << std::endl;
481     if (tcpSocket_) {
482       tcpSocket_->close();
483     }
484     state = STATE_SUCCEEDED;
485   }
486
487   std::shared_ptr<AsyncSocket> tcpSocket_;
488 };
489
490 class HandshakeCallback :
491 public AsyncSSLSocket::HandshakeCB {
492 public:
493   enum ExpectType {
494     EXPECT_SUCCESS,
495     EXPECT_ERROR
496   };
497
498   explicit HandshakeCallback(ReadCallbackBase *rcb,
499                              ExpectType expect = EXPECT_SUCCESS):
500       state(STATE_WAITING),
501       rcb_(rcb),
502       expect_(expect) {}
503
504   void setSocket(
505     const std::shared_ptr<AsyncSSLSocket> &socket) {
506     socket_ = socket;
507   }
508
509   void setState(StateEnum s) {
510     state = s;
511     rcb_->setState(s);
512   }
513
514   // Functions inherited from AsyncSSLSocketHandshakeCallback
515   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
516     std::lock_guard<std::mutex> g(mutex_);
517     cv_.notify_all();
518     EXPECT_EQ(sock, socket_.get());
519     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
520     rcb_->setSocket(socket_);
521     sock->setReadCB(rcb_);
522     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
523   }
524   void handshakeErr(AsyncSSLSocket* /* sock */,
525                     const AsyncSocketException& ex) noexcept override {
526     std::lock_guard<std::mutex> g(mutex_);
527     cv_.notify_all();
528     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
529     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
530     if (expect_ == EXPECT_ERROR) {
531       // rcb will never be invoked
532       rcb_->setState(STATE_SUCCEEDED);
533     }
534     errorString_ = ex.what();
535   }
536
537   void waitForHandshake() {
538     std::unique_lock<std::mutex> lock(mutex_);
539     cv_.wait(lock, [this] { return state != STATE_WAITING; });
540   }
541
542   ~HandshakeCallback() {
543     EXPECT_EQ(STATE_SUCCEEDED, state);
544   }
545
546   void closeSocket() {
547     socket_->close();
548     state = STATE_SUCCEEDED;
549   }
550
551   std::shared_ptr<AsyncSSLSocket> getSocket() {
552     return socket_;
553   }
554
555   StateEnum state;
556   std::shared_ptr<AsyncSSLSocket> socket_;
557   ReadCallbackBase *rcb_;
558   ExpectType expect_;
559   std::mutex mutex_;
560   std::condition_variable cv_;
561   std::string errorString_;
562 };
563
564 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
565 public:
566   uint32_t timeout_;
567
568   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
569                                    uint32_t timeout = 0):
570       SSLServerAcceptCallbackBase(hcb),
571       timeout_(timeout) {}
572
573   virtual ~SSLServerAcceptCallback() {
574     if (timeout_ > 0) {
575       // if we set a timeout, we expect failure
576       EXPECT_EQ(hcb_->state, STATE_FAILED);
577       hcb_->setState(STATE_SUCCEEDED);
578     }
579   }
580
581   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
582   void connAccepted(
583     const std::shared_ptr<folly::AsyncSSLSocket> &s)
584     noexcept override {
585     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
586     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
587
588     hcb_->setSocket(sock);
589     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
590     EXPECT_EQ(sock->getSSLState(),
591                       AsyncSSLSocket::STATE_ACCEPTING);
592
593     state = STATE_SUCCEEDED;
594   }
595 };
596
597 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
598 public:
599   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
600       SSLServerAcceptCallback(hcb) {}
601
602   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
603   void connAccepted(
604     const std::shared_ptr<folly::AsyncSSLSocket> &s)
605     noexcept override {
606
607     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
608
609     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
610               << std::endl;
611     int fd = sock->getFd();
612
613 #ifndef TCP_NOPUSH
614     {
615     // The accepted connection should already have TCP_NODELAY set
616     int value;
617     socklen_t valueLength = sizeof(value);
618     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
619     EXPECT_EQ(rc, 0);
620     EXPECT_EQ(value, 1);
621     }
622 #endif
623
624     // Unset the TCP_NODELAY option.
625     int value = 0;
626     socklen_t valueLength = sizeof(value);
627     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
628     EXPECT_EQ(rc, 0);
629
630     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
631     EXPECT_EQ(rc, 0);
632     EXPECT_EQ(value, 0);
633
634     SSLServerAcceptCallback::connAccepted(sock);
635   }
636 };
637
638 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
639 public:
640   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
641                                              uint32_t timeout = 0):
642     SSLServerAcceptCallback(hcb, timeout) {}
643
644   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
645   void connAccepted(
646     const std::shared_ptr<folly::AsyncSSLSocket> &s)
647     noexcept override {
648     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
649
650     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
651
652     hcb_->setSocket(sock);
653     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
654     ASSERT_TRUE((sock->getSSLState() ==
655                  AsyncSSLSocket::STATE_ACCEPTING) ||
656                 (sock->getSSLState() ==
657                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
658
659     state = STATE_SUCCEEDED;
660   }
661 };
662
663
664 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
665 public:
666   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
667   SSLServerAcceptCallbackBase(hcb)  {}
668
669   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
670   void connAccepted(
671     const std::shared_ptr<folly::AsyncSSLSocket> &s)
672     noexcept override {
673     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
674
675     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
676
677     // The first call to sslAccept() should succeed.
678     hcb_->setSocket(sock);
679     sock->sslAccept(hcb_);
680     EXPECT_EQ(sock->getSSLState(),
681                       AsyncSSLSocket::STATE_ACCEPTING);
682
683     // The second call to sslAccept() should fail.
684     HandshakeCallback callback2(hcb_->rcb_);
685     callback2.setSocket(sock);
686     sock->sslAccept(&callback2);
687     EXPECT_EQ(sock->getSSLState(),
688                       AsyncSSLSocket::STATE_ERROR);
689
690     // Both callbacks should be in the error state.
691     EXPECT_EQ(hcb_->state, STATE_FAILED);
692     EXPECT_EQ(callback2.state, STATE_FAILED);
693
694     state = STATE_SUCCEEDED;
695     hcb_->setState(STATE_SUCCEEDED);
696     callback2.setState(STATE_SUCCEEDED);
697   }
698 };
699
700 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
701 public:
702   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
703   SSLServerAcceptCallbackBase(hcb)  {}
704
705   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
706   void connAccepted(
707     const std::shared_ptr<folly::AsyncSSLSocket> &s)
708     noexcept override {
709     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
710
711     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
712
713     hcb_->setSocket(sock);
714     sock->getEventBase()->tryRunAfterDelay([=] {
715         std::cerr << "Delayed SSL accept, client will have close by now"
716                   << std::endl;
717         // SSL accept will fail
718         EXPECT_EQ(
719           sock->getSSLState(),
720           AsyncSSLSocket::STATE_UNINIT);
721         hcb_->socket_->sslAccept(hcb_);
722         // This registers for an event
723         EXPECT_EQ(
724           sock->getSSLState(),
725           AsyncSSLSocket::STATE_ACCEPTING);
726
727         state = STATE_SUCCEEDED;
728       }, 100);
729   }
730 };
731
732 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
733  public:
734   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
735     // We don't care if we get invoked or not.
736     // The client may time out and give up before connAccepted() is even
737     // called.
738     state = STATE_SUCCEEDED;
739   }
740
741   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
742   void connAccepted(
743       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
744     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
745
746     // Just wait a while before closing the socket, so the client
747     // will time out waiting for the handshake to complete.
748     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
749   }
750 };
751
752 class TestSSLAsyncCacheServer : public TestSSLServer {
753  public:
754   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
755         int lookupDelay = 100) :
756       TestSSLServer(acb) {
757     SSL_CTX *sslCtx = ctx_->getSSLCtx();
758     SSL_CTX_sess_set_get_cb(sslCtx,
759                             TestSSLAsyncCacheServer::getSessionCallback);
760     SSL_CTX_set_session_cache_mode(
761       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
762     asyncCallbacks_ = 0;
763     asyncLookups_ = 0;
764     lookupDelay_ = lookupDelay;
765   }
766
767   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
768   uint32_t getAsyncLookups() const { return asyncLookups_; }
769
770  private:
771   static uint32_t asyncCallbacks_;
772   static uint32_t asyncLookups_;
773   static uint32_t lookupDelay_;
774
775   static SSL_SESSION* getSessionCallback(SSL* ssl,
776                                          unsigned char* /* sess_id */,
777                                          int /* id_len */,
778                                          int* copyflag) {
779     *copyflag = 0;
780     asyncCallbacks_++;
781     (void)ssl;
782 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
783     if (!SSL_want_sess_cache_lookup(ssl)) {
784       // libssl.so mismatch
785       std::cerr << "no async support" << std::endl;
786       return nullptr;
787     }
788
789     AsyncSSLSocket *sslSocket =
790         AsyncSSLSocket::getFromSSL(ssl);
791     assert(sslSocket != nullptr);
792     // Going to simulate an async cache by just running delaying the miss 100ms
793     if (asyncCallbacks_ % 2 == 0) {
794       // This socket is already blocked on lookup, return miss
795       std::cerr << "returning miss" << std::endl;
796     } else {
797       // fresh meat - block it
798       std::cerr << "async lookup" << std::endl;
799       sslSocket->getEventBase()->tryRunAfterDelay(
800         std::bind(&AsyncSSLSocket::restartSSLAccept,
801                   sslSocket), lookupDelay_);
802       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
803       asyncLookups_++;
804     }
805 #endif
806     return nullptr;
807   }
808 };
809
810 void getfds(int fds[2]);
811
812 void getctx(
813   std::shared_ptr<folly::SSLContext> clientCtx,
814   std::shared_ptr<folly::SSLContext> serverCtx);
815
816 void sslsocketpair(
817   EventBase* eventBase,
818   AsyncSSLSocket::UniquePtr* clientSock,
819   AsyncSSLSocket::UniquePtr* serverSock);
820
821 class BlockingWriteClient :
822   private AsyncSSLSocket::HandshakeCB,
823   private AsyncTransportWrapper::WriteCallback {
824  public:
825   explicit BlockingWriteClient(
826     AsyncSSLSocket::UniquePtr socket)
827     : socket_(std::move(socket)),
828       bufLen_(2500),
829       iovCount_(2000) {
830     // Fill buf_
831     buf_.reset(new uint8_t[bufLen_]);
832     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
833       buf_[n] = n % 0xff;
834     }
835
836     // Initialize iov_
837     iov_.reset(new struct iovec[iovCount_]);
838     for (uint32_t n = 0; n < iovCount_; ++n) {
839       iov_[n].iov_base = buf_.get() + n;
840       if (n & 0x1) {
841         iov_[n].iov_len = n % bufLen_;
842       } else {
843         iov_[n].iov_len = bufLen_ - (n % bufLen_);
844       }
845     }
846
847     socket_->sslConn(this, std::chrono::milliseconds(100));
848   }
849
850   struct iovec* getIovec() const {
851     return iov_.get();
852   }
853   uint32_t getIovecCount() const {
854     return iovCount_;
855   }
856
857  private:
858   void handshakeSuc(AsyncSSLSocket*) noexcept override {
859     socket_->writev(this, iov_.get(), iovCount_);
860   }
861   void handshakeErr(
862     AsyncSSLSocket*,
863     const AsyncSocketException& ex) noexcept override {
864     ADD_FAILURE() << "client handshake error: " << ex.what();
865   }
866   void writeSuccess() noexcept override {
867     socket_->close();
868   }
869   void writeErr(
870     size_t bytesWritten,
871     const AsyncSocketException& ex) noexcept override {
872     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
873                   << ex.what();
874   }
875
876   AsyncSSLSocket::UniquePtr socket_;
877   uint32_t bufLen_;
878   uint32_t iovCount_;
879   std::unique_ptr<uint8_t[]> buf_;
880   std::unique_ptr<struct iovec[]> iov_;
881 };
882
883 class BlockingWriteServer :
884     private AsyncSSLSocket::HandshakeCB,
885     private AsyncTransportWrapper::ReadCallback {
886  public:
887   explicit BlockingWriteServer(
888     AsyncSSLSocket::UniquePtr socket)
889     : socket_(std::move(socket)),
890       bufSize_(2500 * 2000),
891       bytesRead_(0) {
892     buf_.reset(new uint8_t[bufSize_]);
893     socket_->sslAccept(this, std::chrono::milliseconds(100));
894   }
895
896   void checkBuffer(struct iovec* iov, uint32_t count) const {
897     uint32_t idx = 0;
898     for (uint32_t n = 0; n < count; ++n) {
899       size_t bytesLeft = bytesRead_ - idx;
900       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
901                       std::min(iov[n].iov_len, bytesLeft));
902       if (rc != 0) {
903         FAIL() << "buffer mismatch at iovec " << n << "/" << count
904                << ": rc=" << rc;
905
906       }
907       if (iov[n].iov_len > bytesLeft) {
908         FAIL() << "server did not read enough data: "
909                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
910                << " in iovec " << n << "/" << count;
911       }
912
913       idx += iov[n].iov_len;
914     }
915     if (idx != bytesRead_) {
916       ADD_FAILURE() << "server read extra data: " << bytesRead_
917                     << " bytes read; expected " << idx;
918     }
919   }
920
921  private:
922   void handshakeSuc(AsyncSSLSocket*) noexcept override {
923     // Wait 10ms before reading, so the client's writes will initially block.
924     socket_->getEventBase()->tryRunAfterDelay(
925         [this] { socket_->setReadCB(this); }, 10);
926   }
927   void handshakeErr(
928     AsyncSSLSocket*,
929     const AsyncSocketException& ex) noexcept override {
930     ADD_FAILURE() << "server handshake error: " << ex.what();
931   }
932   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
933     *bufReturn = buf_.get() + bytesRead_;
934     *lenReturn = bufSize_ - bytesRead_;
935   }
936   void readDataAvailable(size_t len) noexcept override {
937     bytesRead_ += len;
938     socket_->setReadCB(nullptr);
939     socket_->getEventBase()->tryRunAfterDelay(
940         [this] { socket_->setReadCB(this); }, 2);
941   }
942   void readEOF() noexcept override {
943     socket_->close();
944   }
945   void readErr(
946     const AsyncSocketException& ex) noexcept override {
947     ADD_FAILURE() << "server read error: " << ex.what();
948   }
949
950   AsyncSSLSocket::UniquePtr socket_;
951   uint32_t bufSize_;
952   uint32_t bytesRead_;
953   std::unique_ptr<uint8_t[]> buf_;
954 };
955
956 class NpnClient :
957   private AsyncSSLSocket::HandshakeCB,
958   private AsyncTransportWrapper::WriteCallback {
959  public:
960   explicit NpnClient(
961     AsyncSSLSocket::UniquePtr socket)
962       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
963     socket_->sslConn(this);
964   }
965
966   const unsigned char* nextProto;
967   unsigned nextProtoLength;
968   SSLContext::NextProtocolType protocolType;
969
970  private:
971   void handshakeSuc(AsyncSSLSocket*) noexcept override {
972     socket_->getSelectedNextProtocol(
973         &nextProto, &nextProtoLength, &protocolType);
974   }
975   void handshakeErr(
976     AsyncSSLSocket*,
977     const AsyncSocketException& ex) noexcept override {
978     ADD_FAILURE() << "client handshake error: " << ex.what();
979   }
980   void writeSuccess() noexcept override {
981     socket_->close();
982   }
983   void writeErr(
984     size_t bytesWritten,
985     const AsyncSocketException& ex) noexcept override {
986     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
987                   << ex.what();
988   }
989
990   AsyncSSLSocket::UniquePtr socket_;
991 };
992
993 class NpnServer :
994     private AsyncSSLSocket::HandshakeCB,
995     private AsyncTransportWrapper::ReadCallback {
996  public:
997   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
998       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
999     socket_->sslAccept(this);
1000   }
1001
1002   const unsigned char* nextProto;
1003   unsigned nextProtoLength;
1004   SSLContext::NextProtocolType protocolType;
1005
1006  private:
1007   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1008     socket_->getSelectedNextProtocol(
1009         &nextProto, &nextProtoLength, &protocolType);
1010   }
1011   void handshakeErr(
1012     AsyncSSLSocket*,
1013     const AsyncSocketException& ex) noexcept override {
1014     ADD_FAILURE() << "server handshake error: " << ex.what();
1015   }
1016   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1017     *lenReturn = 0;
1018   }
1019   void readDataAvailable(size_t /* len */) noexcept override {}
1020   void readEOF() noexcept override {
1021     socket_->close();
1022   }
1023   void readErr(
1024     const AsyncSocketException& ex) noexcept override {
1025     ADD_FAILURE() << "server read error: " << ex.what();
1026   }
1027
1028   AsyncSSLSocket::UniquePtr socket_;
1029 };
1030
1031 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
1032                             public AsyncTransportWrapper::ReadCallback {
1033  public:
1034   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
1035       : socket_(std::move(socket)) {
1036     socket_->sslAccept(this);
1037   }
1038
1039   ~RenegotiatingServer() {
1040     socket_->setReadCB(nullptr);
1041   }
1042
1043   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
1044     LOG(INFO) << "Renegotiating server handshake success";
1045     socket_->setReadCB(this);
1046   }
1047   void handshakeErr(
1048       AsyncSSLSocket*,
1049       const AsyncSocketException& ex) noexcept override {
1050     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
1051   }
1052   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1053     *lenReturn = sizeof(buf);
1054     *bufReturn = buf;
1055   }
1056   void readDataAvailable(size_t /* len */) noexcept override {}
1057   void readEOF() noexcept override {}
1058   void readErr(const AsyncSocketException& ex) noexcept override {
1059     LOG(INFO) << "server got read error " << ex.what();
1060     auto exPtr = dynamic_cast<const SSLException*>(&ex);
1061     ASSERT_NE(nullptr, exPtr);
1062     std::string exStr(ex.what());
1063     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
1064     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
1065     renegotiationError_ = true;
1066   }
1067
1068   AsyncSSLSocket::UniquePtr socket_;
1069   unsigned char buf[128];
1070   bool renegotiationError_{false};
1071 };
1072
1073 #ifndef OPENSSL_NO_TLSEXT
1074 class SNIClient :
1075   private AsyncSSLSocket::HandshakeCB,
1076   private AsyncTransportWrapper::WriteCallback {
1077  public:
1078   explicit SNIClient(
1079     AsyncSSLSocket::UniquePtr socket)
1080       : serverNameMatch(false), socket_(std::move(socket)) {
1081     socket_->sslConn(this);
1082   }
1083
1084   bool serverNameMatch;
1085
1086  private:
1087   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1088     serverNameMatch = socket_->isServerNameMatch();
1089   }
1090   void handshakeErr(
1091     AsyncSSLSocket*,
1092     const AsyncSocketException& ex) noexcept override {
1093     ADD_FAILURE() << "client handshake error: " << ex.what();
1094   }
1095   void writeSuccess() noexcept override {
1096     socket_->close();
1097   }
1098   void writeErr(
1099     size_t bytesWritten,
1100     const AsyncSocketException& ex) noexcept override {
1101     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1102                   << ex.what();
1103   }
1104
1105   AsyncSSLSocket::UniquePtr socket_;
1106 };
1107
1108 class SNIServer :
1109     private AsyncSSLSocket::HandshakeCB,
1110     private AsyncTransportWrapper::ReadCallback {
1111  public:
1112   explicit SNIServer(
1113     AsyncSSLSocket::UniquePtr socket,
1114     const std::shared_ptr<folly::SSLContext>& ctx,
1115     const std::shared_ptr<folly::SSLContext>& sniCtx,
1116     const std::string& expectedServerName)
1117       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
1118         expectedServerName_(expectedServerName) {
1119     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
1120                                          std::placeholders::_1));
1121     socket_->sslAccept(this);
1122   }
1123
1124   bool serverNameMatch;
1125
1126  private:
1127   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
1128   void handshakeErr(
1129     AsyncSSLSocket*,
1130     const AsyncSocketException& ex) noexcept override {
1131     ADD_FAILURE() << "server handshake error: " << ex.what();
1132   }
1133   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
1134     *lenReturn = 0;
1135   }
1136   void readDataAvailable(size_t /* len */) noexcept override {}
1137   void readEOF() noexcept override {
1138     socket_->close();
1139   }
1140   void readErr(
1141     const AsyncSocketException& ex) noexcept override {
1142     ADD_FAILURE() << "server read error: " << ex.what();
1143   }
1144
1145   folly::SSLContext::ServerNameCallbackResult
1146     serverNameCallback(SSL *ssl) {
1147     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
1148     if (sniCtx_ &&
1149         sn &&
1150         !strcasecmp(expectedServerName_.c_str(), sn)) {
1151       AsyncSSLSocket *sslSocket =
1152           AsyncSSLSocket::getFromSSL(ssl);
1153       sslSocket->switchServerSSLContext(sniCtx_);
1154       serverNameMatch = true;
1155       return folly::SSLContext::SERVER_NAME_FOUND;
1156     } else {
1157       serverNameMatch = false;
1158       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
1159     }
1160   }
1161
1162   AsyncSSLSocket::UniquePtr socket_;
1163   std::shared_ptr<folly::SSLContext> sniCtx_;
1164   std::string expectedServerName_;
1165 };
1166 #endif
1167
1168 class SSLClient : public AsyncSocket::ConnectCallback,
1169                   public AsyncTransportWrapper::WriteCallback,
1170                   public AsyncTransportWrapper::ReadCallback
1171 {
1172  private:
1173   EventBase *eventBase_;
1174   std::shared_ptr<AsyncSSLSocket> sslSocket_;
1175   SSL_SESSION *session_;
1176   std::shared_ptr<folly::SSLContext> ctx_;
1177   uint32_t requests_;
1178   folly::SocketAddress address_;
1179   uint32_t timeout_;
1180   char buf_[128];
1181   char readbuf_[128];
1182   uint32_t bytesRead_;
1183   uint32_t hit_;
1184   uint32_t miss_;
1185   uint32_t errors_;
1186   uint32_t writeAfterConnectErrors_;
1187
1188   // These settings test that we eventually drain the
1189   // socket, even if the maxReadsPerEvent_ is hit during
1190   // a event loop iteration.
1191   static constexpr size_t kMaxReadsPerEvent = 2;
1192   // 2 event loop iterations
1193   static constexpr size_t kMaxReadBufferSz =
1194     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1195
1196  public:
1197   SSLClient(EventBase *eventBase,
1198             const folly::SocketAddress& address,
1199             uint32_t requests,
1200             uint32_t timeout = 0)
1201       : eventBase_(eventBase),
1202         session_(nullptr),
1203         requests_(requests),
1204         address_(address),
1205         timeout_(timeout),
1206         bytesRead_(0),
1207         hit_(0),
1208         miss_(0),
1209         errors_(0),
1210         writeAfterConnectErrors_(0) {
1211     ctx_.reset(new folly::SSLContext());
1212     ctx_->setOptions(SSL_OP_NO_TICKET);
1213     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1214     memset(buf_, 'a', sizeof(buf_));
1215   }
1216
1217   ~SSLClient() {
1218     if (session_) {
1219       SSL_SESSION_free(session_);
1220     }
1221     if (errors_ == 0) {
1222       EXPECT_EQ(bytesRead_, sizeof(buf_));
1223     }
1224   }
1225
1226   uint32_t getHit() const { return hit_; }
1227
1228   uint32_t getMiss() const { return miss_; }
1229
1230   uint32_t getErrors() const { return errors_; }
1231
1232   uint32_t getWriteAfterConnectErrors() const {
1233     return writeAfterConnectErrors_;
1234   }
1235
1236   void connect(bool writeNow = false) {
1237     sslSocket_ = AsyncSSLSocket::newSocket(
1238       ctx_, eventBase_);
1239     if (session_ != nullptr) {
1240       sslSocket_->setSSLSession(session_);
1241     }
1242     requests_--;
1243     sslSocket_->connect(this, address_, timeout_);
1244     if (sslSocket_ && writeNow) {
1245       // write some junk, used in an error test
1246       sslSocket_->write(this, buf_, sizeof(buf_));
1247     }
1248   }
1249
1250   void connectSuccess() noexcept override {
1251     std::cerr << "client SSL socket connected" << std::endl;
1252     if (sslSocket_->getSSLSessionReused()) {
1253       hit_++;
1254     } else {
1255       miss_++;
1256       if (session_ != nullptr) {
1257         SSL_SESSION_free(session_);
1258       }
1259       session_ = sslSocket_->getSSLSession();
1260     }
1261
1262     // write()
1263     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1264     sslSocket_->write(this, buf_, sizeof(buf_));
1265     sslSocket_->setReadCB(this);
1266     memset(readbuf_, 'b', sizeof(readbuf_));
1267     bytesRead_ = 0;
1268   }
1269
1270   void connectErr(
1271     const AsyncSocketException& ex) noexcept override {
1272     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1273     errors_++;
1274     sslSocket_.reset();
1275   }
1276
1277   void writeSuccess() noexcept override {
1278     std::cerr << "client write success" << std::endl;
1279   }
1280
1281   void writeErr(size_t /* bytesWritten */,
1282                 const AsyncSocketException& ex) noexcept override {
1283     std::cerr << "client writeError: " << ex.what() << std::endl;
1284     if (!sslSocket_) {
1285       writeAfterConnectErrors_++;
1286     }
1287   }
1288
1289   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1290     *bufReturn = readbuf_ + bytesRead_;
1291     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1292   }
1293
1294   void readEOF() noexcept override {
1295     std::cerr << "client readEOF" << std::endl;
1296   }
1297
1298   void readErr(
1299     const AsyncSocketException& ex) noexcept override {
1300     std::cerr << "client readError: " << ex.what() << std::endl;
1301   }
1302
1303   void readDataAvailable(size_t len) noexcept override {
1304     std::cerr << "client read data: " << len << std::endl;
1305     bytesRead_ += len;
1306     if (bytesRead_ == sizeof(buf_)) {
1307       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1308       sslSocket_->closeNow();
1309       sslSocket_.reset();
1310       if (requests_ != 0) {
1311         connect();
1312       }
1313     }
1314   }
1315
1316 };
1317
1318 class SSLHandshakeBase :
1319   public AsyncSSLSocket::HandshakeCB,
1320   private AsyncTransportWrapper::WriteCallback {
1321  public:
1322   explicit SSLHandshakeBase(
1323    AsyncSSLSocket::UniquePtr socket,
1324    bool preverifyResult,
1325    bool verifyResult) :
1326     handshakeVerify_(false),
1327     handshakeSuccess_(false),
1328     handshakeError_(false),
1329     socket_(std::move(socket)),
1330     preverifyResult_(preverifyResult),
1331     verifyResult_(verifyResult) {
1332   }
1333
1334   AsyncSSLSocket::UniquePtr moveSocket() && {
1335     return std::move(socket_);
1336   }
1337
1338   bool handshakeVerify_;
1339   bool handshakeSuccess_;
1340   bool handshakeError_;
1341   std::chrono::nanoseconds handshakeTime;
1342
1343  protected:
1344   AsyncSSLSocket::UniquePtr socket_;
1345   bool preverifyResult_;
1346   bool verifyResult_;
1347
1348   // HandshakeCallback
1349   bool handshakeVer(AsyncSSLSocket* /* sock */,
1350                     bool preverifyOk,
1351                     X509_STORE_CTX* /* ctx */) noexcept override {
1352     handshakeVerify_ = true;
1353
1354     EXPECT_EQ(preverifyResult_, preverifyOk);
1355     return verifyResult_;
1356   }
1357
1358   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1359     LOG(INFO) << "Handshake success";
1360     handshakeSuccess_ = true;
1361     handshakeTime = socket_->getHandshakeTime();
1362   }
1363
1364   void handshakeErr(
1365       AsyncSSLSocket*,
1366       const AsyncSocketException& ex) noexcept override {
1367     LOG(INFO) << "Handshake error " << ex.what();
1368     handshakeError_ = true;
1369     handshakeTime = socket_->getHandshakeTime();
1370   }
1371
1372   // WriteCallback
1373   void writeSuccess() noexcept override {
1374     socket_->close();
1375   }
1376
1377   void writeErr(
1378    size_t bytesWritten,
1379    const AsyncSocketException& ex) noexcept override {
1380     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1381                   << ex.what();
1382   }
1383 };
1384
1385 class SSLHandshakeClient : public SSLHandshakeBase {
1386  public:
1387   SSLHandshakeClient(
1388    AsyncSSLSocket::UniquePtr socket,
1389    bool preverifyResult,
1390    bool verifyResult) :
1391     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1392     socket_->sslConn(this, std::chrono::milliseconds::zero());
1393   }
1394 };
1395
1396 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1397  public:
1398   SSLHandshakeClientNoVerify(
1399    AsyncSSLSocket::UniquePtr socket,
1400    bool preverifyResult,
1401    bool verifyResult) :
1402     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1403     socket_->sslConn(
1404         this,
1405         std::chrono::milliseconds::zero(),
1406         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1407   }
1408 };
1409
1410 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1411  public:
1412   SSLHandshakeClientDoVerify(
1413    AsyncSSLSocket::UniquePtr socket,
1414    bool preverifyResult,
1415    bool verifyResult) :
1416     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1417     socket_->sslConn(
1418         this,
1419         std::chrono::milliseconds::zero(),
1420         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1421   }
1422 };
1423
1424 class SSLHandshakeServer : public SSLHandshakeBase {
1425  public:
1426   SSLHandshakeServer(
1427       AsyncSSLSocket::UniquePtr socket,
1428       bool preverifyResult,
1429       bool verifyResult)
1430     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1431     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1432   }
1433 };
1434
1435 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1436  public:
1437   SSLHandshakeServerParseClientHello(
1438       AsyncSSLSocket::UniquePtr socket,
1439       bool preverifyResult,
1440       bool verifyResult)
1441       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1442     socket_->enableClientHelloParsing();
1443     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1444   }
1445
1446   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1447
1448  protected:
1449   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1450     handshakeSuccess_ = true;
1451     sock->getSSLSharedCiphers(sharedCiphers_);
1452     sock->getSSLServerCiphers(serverCiphers_);
1453     sock->getSSLClientCiphers(clientCiphers_);
1454     chosenCipher_ = sock->getNegotiatedCipherName();
1455   }
1456 };
1457
1458
1459 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1460  public:
1461   SSLHandshakeServerNoVerify(
1462       AsyncSSLSocket::UniquePtr socket,
1463       bool preverifyResult,
1464       bool verifyResult)
1465     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1466     socket_->sslAccept(
1467         this,
1468         std::chrono::milliseconds::zero(),
1469         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1470   }
1471 };
1472
1473 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1474  public:
1475   SSLHandshakeServerDoVerify(
1476       AsyncSSLSocket::UniquePtr socket,
1477       bool preverifyResult,
1478       bool verifyResult)
1479     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1480     socket_->sslAccept(
1481         this,
1482         std::chrono::milliseconds::zero(),
1483         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1484   }
1485 };
1486
1487 class EventBaseAborter : public AsyncTimeout {
1488  public:
1489   EventBaseAborter(EventBase* eventBase,
1490                    uint32_t timeoutMS)
1491     : AsyncTimeout(
1492       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1493     , eventBase_(eventBase) {
1494     scheduleTimeout(timeoutMS);
1495   }
1496
1497   void timeoutExpired() noexcept override {
1498     FAIL() << "test timed out";
1499     eventBase_->terminateLoopSoon();
1500   }
1501
1502  private:
1503   EventBase* eventBase_;
1504 };
1505
1506 }