Fix AsyncSSLSocket handshake error reporting.
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
28
29 #include <gtest/gtest.h>
30 #include <iostream>
31 #include <list>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <poll.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
38
39 namespace folly {
40
41 enum StateEnum {
42   STATE_WAITING,
43   STATE_SUCCEEDED,
44   STATE_FAILED
45 };
46
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
51
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
54 public:
55   WriteCallbackBase()
56       : state(STATE_WAITING)
57       , bytesWritten(0)
58       , exception(AsyncSocketException::UNKNOWN, "none") {}
59
60   ~WriteCallbackBase() {
61     EXPECT_EQ(state, STATE_SUCCEEDED);
62   }
63
64   void setSocket(
65     const std::shared_ptr<AsyncSSLSocket> &socket) {
66     socket_ = socket;
67   }
68
69   void writeSuccess() noexcept override {
70     std::cerr << "writeSuccess" << std::endl;
71     state = STATE_SUCCEEDED;
72   }
73
74   void writeErr(
75     size_t bytesWritten,
76     const AsyncSocketException& ex) noexcept override {
77     std::cerr << "writeError: bytesWritten " << bytesWritten
78          << ", exception " << ex.what() << std::endl;
79
80     state = STATE_FAILED;
81     this->bytesWritten = bytesWritten;
82     exception = ex;
83     socket_->close();
84     socket_->detachEventBase();
85   }
86
87   std::shared_ptr<AsyncSSLSocket> socket_;
88   StateEnum state;
89   size_t bytesWritten;
90   AsyncSocketException exception;
91 };
92
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
95 public:
96   explicit ReadCallbackBase(WriteCallbackBase *wcb)
97       : wcb_(wcb)
98       , state(STATE_WAITING) {}
99
100   ~ReadCallbackBase() {
101     EXPECT_EQ(state, STATE_SUCCEEDED);
102   }
103
104   void setSocket(
105     const std::shared_ptr<AsyncSSLSocket> &socket) {
106     socket_ = socket;
107   }
108
109   void setState(StateEnum s) {
110     state = s;
111     if (wcb_) {
112       wcb_->state = s;
113     }
114   }
115
116   void readErr(
117     const AsyncSocketException& ex) noexcept override {
118     std::cerr << "readError " << ex.what() << std::endl;
119     state = STATE_FAILED;
120     socket_->close();
121     socket_->detachEventBase();
122   }
123
124   void readEOF() noexcept override {
125     std::cerr << "readEOF" << std::endl;
126
127     socket_->close();
128     socket_->detachEventBase();
129   }
130
131   std::shared_ptr<AsyncSSLSocket> socket_;
132   WriteCallbackBase *wcb_;
133   StateEnum state;
134 };
135
136 class ReadCallback : public ReadCallbackBase {
137 public:
138   explicit ReadCallback(WriteCallbackBase *wcb)
139       : ReadCallbackBase(wcb)
140       , buffers() {}
141
142   ~ReadCallback() {
143     for (std::vector<Buffer>::iterator it = buffers.begin();
144          it != buffers.end();
145          ++it) {
146       it->free();
147     }
148     currentBuffer.free();
149   }
150
151   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152     if (!currentBuffer.buffer) {
153       currentBuffer.allocate(4096);
154     }
155     *bufReturn = currentBuffer.buffer;
156     *lenReturn = currentBuffer.length;
157   }
158
159   void readDataAvailable(size_t len) noexcept override {
160     std::cerr << "readDataAvailable, len " << len << std::endl;
161
162     currentBuffer.length = len;
163
164     wcb_->setSocket(socket_);
165
166     // Write back the same data.
167     socket_->write(wcb_, currentBuffer.buffer, len);
168
169     buffers.push_back(currentBuffer);
170     currentBuffer.reset();
171     state = STATE_SUCCEEDED;
172   }
173
174   class Buffer {
175   public:
176     Buffer() : buffer(nullptr), length(0) {}
177     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
178
179     void reset() {
180       buffer = nullptr;
181       length = 0;
182     }
183     void allocate(size_t length) {
184       assert(buffer == nullptr);
185       this->buffer = static_cast<char*>(malloc(length));
186       this->length = length;
187     }
188     void free() {
189       ::free(buffer);
190       reset();
191     }
192
193     char* buffer;
194     size_t length;
195   };
196
197   std::vector<Buffer> buffers;
198   Buffer currentBuffer;
199 };
200
201 class ReadErrorCallback : public ReadCallbackBase {
202 public:
203   explicit ReadErrorCallback(WriteCallbackBase *wcb)
204       : ReadCallbackBase(wcb) {}
205
206   // Return nullptr buffer to trigger readError()
207   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208     *bufReturn = nullptr;
209     *lenReturn = 0;
210   }
211
212   void readDataAvailable(size_t /* len */) noexcept override {
213     // This should never to called.
214     FAIL();
215   }
216
217   void readErr(
218     const AsyncSocketException& ex) noexcept override {
219     ReadCallbackBase::readErr(ex);
220     std::cerr << "ReadErrorCallback::readError" << std::endl;
221     setState(STATE_SUCCEEDED);
222   }
223 };
224
225 class WriteErrorCallback : public ReadCallback {
226 public:
227   explicit WriteErrorCallback(WriteCallbackBase *wcb)
228       : ReadCallback(wcb) {}
229
230   void readDataAvailable(size_t len) noexcept override {
231     std::cerr << "readDataAvailable, len " << len << std::endl;
232
233     currentBuffer.length = len;
234
235     // close the socket before writing to trigger writeError().
236     ::close(socket_->getFd());
237
238     wcb_->setSocket(socket_);
239
240     // Write back the same data.
241     socket_->write(wcb_, currentBuffer.buffer, len);
242
243     if (wcb_->state == STATE_FAILED) {
244       setState(STATE_SUCCEEDED);
245     } else {
246       state = STATE_FAILED;
247     }
248
249     buffers.push_back(currentBuffer);
250     currentBuffer.reset();
251   }
252
253   void readErr(const AsyncSocketException& ex) noexcept override {
254     std::cerr << "readError " << ex.what() << std::endl;
255     // do nothing since this is expected
256   }
257 };
258
259 class EmptyReadCallback : public ReadCallback {
260 public:
261   explicit EmptyReadCallback()
262       : ReadCallback(nullptr) {}
263
264   void readErr(const AsyncSocketException& ex) noexcept override {
265     std::cerr << "readError " << ex.what() << std::endl;
266     state = STATE_FAILED;
267     tcpSocket_->close();
268     tcpSocket_->detachEventBase();
269   }
270
271   void readEOF() noexcept override {
272     std::cerr << "readEOF" << std::endl;
273
274     tcpSocket_->close();
275     tcpSocket_->detachEventBase();
276     state = STATE_SUCCEEDED;
277   }
278
279   std::shared_ptr<AsyncSocket> tcpSocket_;
280 };
281
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
284 public:
285   enum ExpectType {
286     EXPECT_SUCCESS,
287     EXPECT_ERROR
288   };
289
290   explicit HandshakeCallback(ReadCallbackBase *rcb,
291                              ExpectType expect = EXPECT_SUCCESS):
292       state(STATE_WAITING),
293       rcb_(rcb),
294       expect_(expect) {}
295
296   void setSocket(
297     const std::shared_ptr<AsyncSSLSocket> &socket) {
298     socket_ = socket;
299   }
300
301   void setState(StateEnum s) {
302     state = s;
303     rcb_->setState(s);
304   }
305
306   // Functions inherited from AsyncSSLSocketHandshakeCallback
307   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308     std::lock_guard<std::mutex> g(mutex_);
309     cv_.notify_all();
310     EXPECT_EQ(sock, socket_.get());
311     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
312     rcb_->setSocket(socket_);
313     sock->setReadCB(rcb_);
314     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
315   }
316   void handshakeErr(AsyncSSLSocket* /* sock */,
317                     const AsyncSocketException& ex) noexcept override {
318     std::lock_guard<std::mutex> g(mutex_);
319     cv_.notify_all();
320     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
321     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
322     if (expect_ == EXPECT_ERROR) {
323       // rcb will never be invoked
324       rcb_->setState(STATE_SUCCEEDED);
325     }
326     errorString_ = ex.what();
327   }
328
329   void waitForHandshake() {
330     std::unique_lock<std::mutex> lock(mutex_);
331     cv_.wait(lock, [this] { return state != STATE_WAITING; });
332   }
333
334   ~HandshakeCallback() {
335     EXPECT_EQ(state, STATE_SUCCEEDED);
336   }
337
338   void closeSocket() {
339     socket_->close();
340     state = STATE_SUCCEEDED;
341   }
342
343   StateEnum state;
344   std::shared_ptr<AsyncSSLSocket> socket_;
345   ReadCallbackBase *rcb_;
346   ExpectType expect_;
347   std::mutex mutex_;
348   std::condition_variable cv_;
349   std::string errorString_;
350 };
351
352 class SSLServerAcceptCallbackBase:
353 public folly::AsyncServerSocket::AcceptCallback {
354 public:
355   explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
356   state(STATE_WAITING), hcb_(hcb) {}
357
358   ~SSLServerAcceptCallbackBase() {
359     EXPECT_EQ(state, STATE_SUCCEEDED);
360   }
361
362   void acceptError(const std::exception& ex) noexcept override {
363     std::cerr << "SSLServerAcceptCallbackBase::acceptError "
364               << ex.what() << std::endl;
365     state = STATE_FAILED;
366   }
367
368   void connectionAccepted(
369       int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
370     printf("Connection accepted\n");
371     std::shared_ptr<AsyncSSLSocket> sslSock;
372     try {
373       // Create a AsyncSSLSocket object with the fd. The socket should be
374       // added to the event base and in the state of accepting SSL connection.
375       sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
376     } catch (const std::exception &e) {
377       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
378         "object with socket " << e.what() << fd;
379       ::close(fd);
380       acceptError(e);
381       return;
382     }
383
384     connAccepted(sslSock);
385   }
386
387   virtual void connAccepted(
388     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
389
390   StateEnum state;
391   HandshakeCallback *hcb_;
392   std::shared_ptr<folly::SSLContext> ctx_;
393   folly::EventBase* base_;
394 };
395
396 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
397 public:
398   uint32_t timeout_;
399
400   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
401                                    uint32_t timeout = 0):
402       SSLServerAcceptCallbackBase(hcb),
403       timeout_(timeout) {}
404
405   virtual ~SSLServerAcceptCallback() {
406     if (timeout_ > 0) {
407       // if we set a timeout, we expect failure
408       EXPECT_EQ(hcb_->state, STATE_FAILED);
409       hcb_->setState(STATE_SUCCEEDED);
410     }
411   }
412
413   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
414   void connAccepted(
415     const std::shared_ptr<folly::AsyncSSLSocket> &s)
416     noexcept override {
417     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
418     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
419
420     hcb_->setSocket(sock);
421     sock->sslAccept(hcb_, timeout_);
422     EXPECT_EQ(sock->getSSLState(),
423                       AsyncSSLSocket::STATE_ACCEPTING);
424
425     state = STATE_SUCCEEDED;
426   }
427 };
428
429 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
430 public:
431   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
432       SSLServerAcceptCallback(hcb) {}
433
434   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
435   void connAccepted(
436     const std::shared_ptr<folly::AsyncSSLSocket> &s)
437     noexcept override {
438
439     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
440
441     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
442               << std::endl;
443     int fd = sock->getFd();
444
445 #ifndef TCP_NOPUSH
446     {
447     // The accepted connection should already have TCP_NODELAY set
448     int value;
449     socklen_t valueLength = sizeof(value);
450     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
451     EXPECT_EQ(rc, 0);
452     EXPECT_EQ(value, 1);
453     }
454 #endif
455
456     // Unset the TCP_NODELAY option.
457     int value = 0;
458     socklen_t valueLength = sizeof(value);
459     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
460     EXPECT_EQ(rc, 0);
461
462     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
463     EXPECT_EQ(rc, 0);
464     EXPECT_EQ(value, 0);
465
466     SSLServerAcceptCallback::connAccepted(sock);
467   }
468 };
469
470 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
471 public:
472   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
473                                              uint32_t timeout = 0):
474     SSLServerAcceptCallback(hcb, timeout) {}
475
476   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
477   void connAccepted(
478     const std::shared_ptr<folly::AsyncSSLSocket> &s)
479     noexcept override {
480     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
481
482     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
483
484     hcb_->setSocket(sock);
485     sock->sslAccept(hcb_, timeout_);
486     ASSERT_TRUE((sock->getSSLState() ==
487                  AsyncSSLSocket::STATE_ACCEPTING) ||
488                 (sock->getSSLState() ==
489                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
490
491     state = STATE_SUCCEEDED;
492   }
493 };
494
495
496 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
497 public:
498   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
499   SSLServerAcceptCallbackBase(hcb)  {}
500
501   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
502   void connAccepted(
503     const std::shared_ptr<folly::AsyncSSLSocket> &s)
504     noexcept override {
505     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
506
507     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
508
509     // The first call to sslAccept() should succeed.
510     hcb_->setSocket(sock);
511     sock->sslAccept(hcb_);
512     EXPECT_EQ(sock->getSSLState(),
513                       AsyncSSLSocket::STATE_ACCEPTING);
514
515     // The second call to sslAccept() should fail.
516     HandshakeCallback callback2(hcb_->rcb_);
517     callback2.setSocket(sock);
518     sock->sslAccept(&callback2);
519     EXPECT_EQ(sock->getSSLState(),
520                       AsyncSSLSocket::STATE_ERROR);
521
522     // Both callbacks should be in the error state.
523     EXPECT_EQ(hcb_->state, STATE_FAILED);
524     EXPECT_EQ(callback2.state, STATE_FAILED);
525
526     sock->detachEventBase();
527
528     state = STATE_SUCCEEDED;
529     hcb_->setState(STATE_SUCCEEDED);
530     callback2.setState(STATE_SUCCEEDED);
531   }
532 };
533
534 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
535 public:
536   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
537   SSLServerAcceptCallbackBase(hcb)  {}
538
539   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
540   void connAccepted(
541     const std::shared_ptr<folly::AsyncSSLSocket> &s)
542     noexcept override {
543     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
544
545     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
546
547     hcb_->setSocket(sock);
548     sock->getEventBase()->tryRunAfterDelay([=] {
549         std::cerr << "Delayed SSL accept, client will have close by now"
550                   << std::endl;
551         // SSL accept will fail
552         EXPECT_EQ(
553           sock->getSSLState(),
554           AsyncSSLSocket::STATE_UNINIT);
555         hcb_->socket_->sslAccept(hcb_);
556         // This registers for an event
557         EXPECT_EQ(
558           sock->getSSLState(),
559           AsyncSSLSocket::STATE_ACCEPTING);
560
561         state = STATE_SUCCEEDED;
562       }, 100);
563   }
564 };
565
566
567 class TestSSLServer {
568  protected:
569   EventBase evb_;
570   std::shared_ptr<folly::SSLContext> ctx_;
571   SSLServerAcceptCallbackBase *acb_;
572   std::shared_ptr<folly::AsyncServerSocket> socket_;
573   folly::SocketAddress address_;
574   pthread_t thread_;
575
576   static void *Main(void *ctx) {
577     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
578     self->evb_.loop();
579     std::cerr << "Server thread exited event loop" << std::endl;
580     return nullptr;
581   }
582
583  public:
584   // Create a TestSSLServer.
585   // This immediately starts listening on the given port.
586   explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
587
588   // Kill the thread.
589   ~TestSSLServer() {
590     evb_.runInEventBaseThread([&](){
591       socket_->stopAccepting();
592     });
593     std::cerr << "Waiting for server thread to exit" << std::endl;
594     pthread_join(thread_, nullptr);
595   }
596
597   EventBase &getEventBase() { return evb_; }
598
599   const folly::SocketAddress& getAddress() const {
600     return address_;
601   }
602 };
603
604 class TestSSLAsyncCacheServer : public TestSSLServer {
605  public:
606   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
607         int lookupDelay = 100) :
608       TestSSLServer(acb) {
609     SSL_CTX *sslCtx = ctx_->getSSLCtx();
610     SSL_CTX_sess_set_get_cb(sslCtx,
611                             TestSSLAsyncCacheServer::getSessionCallback);
612     SSL_CTX_set_session_cache_mode(
613       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
614     asyncCallbacks_ = 0;
615     asyncLookups_ = 0;
616     lookupDelay_ = lookupDelay;
617   }
618
619   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
620   uint32_t getAsyncLookups() const { return asyncLookups_; }
621
622  private:
623   static uint32_t asyncCallbacks_;
624   static uint32_t asyncLookups_;
625   static uint32_t lookupDelay_;
626
627   static SSL_SESSION* getSessionCallback(SSL* ssl,
628                                          unsigned char* /* sess_id */,
629                                          int /* id_len */,
630                                          int* copyflag) {
631     *copyflag = 0;
632     asyncCallbacks_++;
633 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
634     if (!SSL_want_sess_cache_lookup(ssl)) {
635       // libssl.so mismatch
636       std::cerr << "no async support" << std::endl;
637       return nullptr;
638     }
639
640     AsyncSSLSocket *sslSocket =
641         AsyncSSLSocket::getFromSSL(ssl);
642     assert(sslSocket != nullptr);
643     // Going to simulate an async cache by just running delaying the miss 100ms
644     if (asyncCallbacks_ % 2 == 0) {
645       // This socket is already blocked on lookup, return miss
646       std::cerr << "returning miss" << std::endl;
647     } else {
648       // fresh meat - block it
649       std::cerr << "async lookup" << std::endl;
650       sslSocket->getEventBase()->tryRunAfterDelay(
651         std::bind(&AsyncSSLSocket::restartSSLAccept,
652                   sslSocket), lookupDelay_);
653       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
654       asyncLookups_++;
655     }
656 #endif
657     return nullptr;
658   }
659 };
660
661 void getfds(int fds[2]);
662
663 void getctx(
664   std::shared_ptr<folly::SSLContext> clientCtx,
665   std::shared_ptr<folly::SSLContext> serverCtx);
666
667 void sslsocketpair(
668   EventBase* eventBase,
669   AsyncSSLSocket::UniquePtr* clientSock,
670   AsyncSSLSocket::UniquePtr* serverSock);
671
672 class BlockingWriteClient :
673   private AsyncSSLSocket::HandshakeCB,
674   private AsyncTransportWrapper::WriteCallback {
675  public:
676   explicit BlockingWriteClient(
677     AsyncSSLSocket::UniquePtr socket)
678     : socket_(std::move(socket)),
679       bufLen_(2500),
680       iovCount_(2000) {
681     // Fill buf_
682     buf_.reset(new uint8_t[bufLen_]);
683     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
684       buf_[n] = n % 0xff;
685     }
686
687     // Initialize iov_
688     iov_.reset(new struct iovec[iovCount_]);
689     for (uint32_t n = 0; n < iovCount_; ++n) {
690       iov_[n].iov_base = buf_.get() + n;
691       if (n & 0x1) {
692         iov_[n].iov_len = n % bufLen_;
693       } else {
694         iov_[n].iov_len = bufLen_ - (n % bufLen_);
695       }
696     }
697
698     socket_->sslConn(this, 100);
699   }
700
701   struct iovec* getIovec() const {
702     return iov_.get();
703   }
704   uint32_t getIovecCount() const {
705     return iovCount_;
706   }
707
708  private:
709   void handshakeSuc(AsyncSSLSocket*) noexcept override {
710     socket_->writev(this, iov_.get(), iovCount_);
711   }
712   void handshakeErr(
713     AsyncSSLSocket*,
714     const AsyncSocketException& ex) noexcept override {
715     ADD_FAILURE() << "client handshake error: " << ex.what();
716   }
717   void writeSuccess() noexcept override {
718     socket_->close();
719   }
720   void writeErr(
721     size_t bytesWritten,
722     const AsyncSocketException& ex) noexcept override {
723     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
724                   << ex.what();
725   }
726
727   AsyncSSLSocket::UniquePtr socket_;
728   uint32_t bufLen_;
729   uint32_t iovCount_;
730   std::unique_ptr<uint8_t[]> buf_;
731   std::unique_ptr<struct iovec[]> iov_;
732 };
733
734 class BlockingWriteServer :
735     private AsyncSSLSocket::HandshakeCB,
736     private AsyncTransportWrapper::ReadCallback {
737  public:
738   explicit BlockingWriteServer(
739     AsyncSSLSocket::UniquePtr socket)
740     : socket_(std::move(socket)),
741       bufSize_(2500 * 2000),
742       bytesRead_(0) {
743     buf_.reset(new uint8_t[bufSize_]);
744     socket_->sslAccept(this, 100);
745   }
746
747   void checkBuffer(struct iovec* iov, uint32_t count) const {
748     uint32_t idx = 0;
749     for (uint32_t n = 0; n < count; ++n) {
750       size_t bytesLeft = bytesRead_ - idx;
751       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
752                       std::min(iov[n].iov_len, bytesLeft));
753       if (rc != 0) {
754         FAIL() << "buffer mismatch at iovec " << n << "/" << count
755                << ": rc=" << rc;
756
757       }
758       if (iov[n].iov_len > bytesLeft) {
759         FAIL() << "server did not read enough data: "
760                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
761                << " in iovec " << n << "/" << count;
762       }
763
764       idx += iov[n].iov_len;
765     }
766     if (idx != bytesRead_) {
767       ADD_FAILURE() << "server read extra data: " << bytesRead_
768                     << " bytes read; expected " << idx;
769     }
770   }
771
772  private:
773   void handshakeSuc(AsyncSSLSocket*) noexcept override {
774     // Wait 10ms before reading, so the client's writes will initially block.
775     socket_->getEventBase()->tryRunAfterDelay(
776         [this] { socket_->setReadCB(this); }, 10);
777   }
778   void handshakeErr(
779     AsyncSSLSocket*,
780     const AsyncSocketException& ex) noexcept override {
781     ADD_FAILURE() << "server handshake error: " << ex.what();
782   }
783   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
784     *bufReturn = buf_.get() + bytesRead_;
785     *lenReturn = bufSize_ - bytesRead_;
786   }
787   void readDataAvailable(size_t len) noexcept override {
788     bytesRead_ += len;
789     socket_->setReadCB(nullptr);
790     socket_->getEventBase()->tryRunAfterDelay(
791         [this] { socket_->setReadCB(this); }, 2);
792   }
793   void readEOF() noexcept override {
794     socket_->close();
795   }
796   void readErr(
797     const AsyncSocketException& ex) noexcept override {
798     ADD_FAILURE() << "server read error: " << ex.what();
799   }
800
801   AsyncSSLSocket::UniquePtr socket_;
802   uint32_t bufSize_;
803   uint32_t bytesRead_;
804   std::unique_ptr<uint8_t[]> buf_;
805 };
806
807 class NpnClient :
808   private AsyncSSLSocket::HandshakeCB,
809   private AsyncTransportWrapper::WriteCallback {
810  public:
811   explicit NpnClient(
812     AsyncSSLSocket::UniquePtr socket)
813       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
814     socket_->sslConn(this);
815   }
816
817   const unsigned char* nextProto;
818   unsigned nextProtoLength;
819   SSLContext::NextProtocolType protocolType;
820
821  private:
822   void handshakeSuc(AsyncSSLSocket*) noexcept override {
823     socket_->getSelectedNextProtocol(
824         &nextProto, &nextProtoLength, &protocolType);
825   }
826   void handshakeErr(
827     AsyncSSLSocket*,
828     const AsyncSocketException& ex) noexcept override {
829     ADD_FAILURE() << "client handshake error: " << ex.what();
830   }
831   void writeSuccess() noexcept override {
832     socket_->close();
833   }
834   void writeErr(
835     size_t bytesWritten,
836     const AsyncSocketException& ex) noexcept override {
837     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
838                   << ex.what();
839   }
840
841   AsyncSSLSocket::UniquePtr socket_;
842 };
843
844 class NpnServer :
845     private AsyncSSLSocket::HandshakeCB,
846     private AsyncTransportWrapper::ReadCallback {
847  public:
848   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
849       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
850     socket_->sslAccept(this);
851   }
852
853   const unsigned char* nextProto;
854   unsigned nextProtoLength;
855   SSLContext::NextProtocolType protocolType;
856
857  private:
858   void handshakeSuc(AsyncSSLSocket*) noexcept override {
859     socket_->getSelectedNextProtocol(
860         &nextProto, &nextProtoLength, &protocolType);
861   }
862   void handshakeErr(
863     AsyncSSLSocket*,
864     const AsyncSocketException& ex) noexcept override {
865     ADD_FAILURE() << "server handshake error: " << ex.what();
866   }
867   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
868     *lenReturn = 0;
869   }
870   void readDataAvailable(size_t /* len */) noexcept override {}
871   void readEOF() noexcept override {
872     socket_->close();
873   }
874   void readErr(
875     const AsyncSocketException& ex) noexcept override {
876     ADD_FAILURE() << "server read error: " << ex.what();
877   }
878
879   AsyncSSLSocket::UniquePtr socket_;
880 };
881
882 #ifndef OPENSSL_NO_TLSEXT
883 class SNIClient :
884   private AsyncSSLSocket::HandshakeCB,
885   private AsyncTransportWrapper::WriteCallback {
886  public:
887   explicit SNIClient(
888     AsyncSSLSocket::UniquePtr socket)
889       : serverNameMatch(false), socket_(std::move(socket)) {
890     socket_->sslConn(this);
891   }
892
893   bool serverNameMatch;
894
895  private:
896   void handshakeSuc(AsyncSSLSocket*) noexcept override {
897     serverNameMatch = socket_->isServerNameMatch();
898   }
899   void handshakeErr(
900     AsyncSSLSocket*,
901     const AsyncSocketException& ex) noexcept override {
902     ADD_FAILURE() << "client handshake error: " << ex.what();
903   }
904   void writeSuccess() noexcept override {
905     socket_->close();
906   }
907   void writeErr(
908     size_t bytesWritten,
909     const AsyncSocketException& ex) noexcept override {
910     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
911                   << ex.what();
912   }
913
914   AsyncSSLSocket::UniquePtr socket_;
915 };
916
917 class SNIServer :
918     private AsyncSSLSocket::HandshakeCB,
919     private AsyncTransportWrapper::ReadCallback {
920  public:
921   explicit SNIServer(
922     AsyncSSLSocket::UniquePtr socket,
923     const std::shared_ptr<folly::SSLContext>& ctx,
924     const std::shared_ptr<folly::SSLContext>& sniCtx,
925     const std::string& expectedServerName)
926       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
927         expectedServerName_(expectedServerName) {
928     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
929                                          std::placeholders::_1));
930     socket_->sslAccept(this);
931   }
932
933   bool serverNameMatch;
934
935  private:
936   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
937   void handshakeErr(
938     AsyncSSLSocket*,
939     const AsyncSocketException& ex) noexcept override {
940     ADD_FAILURE() << "server handshake error: " << ex.what();
941   }
942   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
943     *lenReturn = 0;
944   }
945   void readDataAvailable(size_t /* len */) noexcept override {}
946   void readEOF() noexcept override {
947     socket_->close();
948   }
949   void readErr(
950     const AsyncSocketException& ex) noexcept override {
951     ADD_FAILURE() << "server read error: " << ex.what();
952   }
953
954   folly::SSLContext::ServerNameCallbackResult
955     serverNameCallback(SSL *ssl) {
956     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
957     if (sniCtx_ &&
958         sn &&
959         !strcasecmp(expectedServerName_.c_str(), sn)) {
960       AsyncSSLSocket *sslSocket =
961           AsyncSSLSocket::getFromSSL(ssl);
962       sslSocket->switchServerSSLContext(sniCtx_);
963       serverNameMatch = true;
964       return folly::SSLContext::SERVER_NAME_FOUND;
965     } else {
966       serverNameMatch = false;
967       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
968     }
969   }
970
971   AsyncSSLSocket::UniquePtr socket_;
972   std::shared_ptr<folly::SSLContext> sniCtx_;
973   std::string expectedServerName_;
974 };
975 #endif
976
977 class SSLClient : public AsyncSocket::ConnectCallback,
978                   public AsyncTransportWrapper::WriteCallback,
979                   public AsyncTransportWrapper::ReadCallback
980 {
981  private:
982   EventBase *eventBase_;
983   std::shared_ptr<AsyncSSLSocket> sslSocket_;
984   SSL_SESSION *session_;
985   std::shared_ptr<folly::SSLContext> ctx_;
986   uint32_t requests_;
987   folly::SocketAddress address_;
988   uint32_t timeout_;
989   char buf_[128];
990   char readbuf_[128];
991   uint32_t bytesRead_;
992   uint32_t hit_;
993   uint32_t miss_;
994   uint32_t errors_;
995   uint32_t writeAfterConnectErrors_;
996
997   // These settings test that we eventually drain the
998   // socket, even if the maxReadsPerEvent_ is hit during
999   // a event loop iteration.
1000   static constexpr size_t kMaxReadsPerEvent = 2;
1001   static constexpr size_t kMaxReadBufferSz =
1002     sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
1003
1004  public:
1005   SSLClient(EventBase *eventBase,
1006             const folly::SocketAddress& address,
1007             uint32_t requests,
1008             uint32_t timeout = 0)
1009       : eventBase_(eventBase),
1010         session_(nullptr),
1011         requests_(requests),
1012         address_(address),
1013         timeout_(timeout),
1014         bytesRead_(0),
1015         hit_(0),
1016         miss_(0),
1017         errors_(0),
1018         writeAfterConnectErrors_(0) {
1019     ctx_.reset(new folly::SSLContext());
1020     ctx_->setOptions(SSL_OP_NO_TICKET);
1021     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1022     memset(buf_, 'a', sizeof(buf_));
1023   }
1024
1025   ~SSLClient() {
1026     if (session_) {
1027       SSL_SESSION_free(session_);
1028     }
1029     if (errors_ == 0) {
1030       EXPECT_EQ(bytesRead_, sizeof(buf_));
1031     }
1032   }
1033
1034   uint32_t getHit() const { return hit_; }
1035
1036   uint32_t getMiss() const { return miss_; }
1037
1038   uint32_t getErrors() const { return errors_; }
1039
1040   uint32_t getWriteAfterConnectErrors() const {
1041     return writeAfterConnectErrors_;
1042   }
1043
1044   void connect(bool writeNow = false) {
1045     sslSocket_ = AsyncSSLSocket::newSocket(
1046       ctx_, eventBase_);
1047     if (session_ != nullptr) {
1048       sslSocket_->setSSLSession(session_);
1049     }
1050     requests_--;
1051     sslSocket_->connect(this, address_, timeout_);
1052     if (sslSocket_ && writeNow) {
1053       // write some junk, used in an error test
1054       sslSocket_->write(this, buf_, sizeof(buf_));
1055     }
1056   }
1057
1058   void connectSuccess() noexcept override {
1059     std::cerr << "client SSL socket connected" << std::endl;
1060     if (sslSocket_->getSSLSessionReused()) {
1061       hit_++;
1062     } else {
1063       miss_++;
1064       if (session_ != nullptr) {
1065         SSL_SESSION_free(session_);
1066       }
1067       session_ = sslSocket_->getSSLSession();
1068     }
1069
1070     // write()
1071     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1072     sslSocket_->write(this, buf_, sizeof(buf_));
1073     sslSocket_->setReadCB(this);
1074     memset(readbuf_, 'b', sizeof(readbuf_));
1075     bytesRead_ = 0;
1076   }
1077
1078   void connectErr(
1079     const AsyncSocketException& ex) noexcept override {
1080     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1081     errors_++;
1082     sslSocket_.reset();
1083   }
1084
1085   void writeSuccess() noexcept override {
1086     std::cerr << "client write success" << std::endl;
1087   }
1088
1089   void writeErr(size_t /* bytesWritten */,
1090                 const AsyncSocketException& ex) noexcept override {
1091     std::cerr << "client writeError: " << ex.what() << std::endl;
1092     if (!sslSocket_) {
1093       writeAfterConnectErrors_++;
1094     }
1095   }
1096
1097   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1098     *bufReturn = readbuf_ + bytesRead_;
1099     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1100   }
1101
1102   void readEOF() noexcept override {
1103     std::cerr << "client readEOF" << std::endl;
1104   }
1105
1106   void readErr(
1107     const AsyncSocketException& ex) noexcept override {
1108     std::cerr << "client readError: " << ex.what() << std::endl;
1109   }
1110
1111   void readDataAvailable(size_t len) noexcept override {
1112     std::cerr << "client read data: " << len << std::endl;
1113     bytesRead_ += len;
1114     if (bytesRead_ == sizeof(buf_)) {
1115       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1116       sslSocket_->closeNow();
1117       sslSocket_.reset();
1118       if (requests_ != 0) {
1119         connect();
1120       }
1121     }
1122   }
1123
1124 };
1125
1126 class SSLHandshakeBase :
1127   public AsyncSSLSocket::HandshakeCB,
1128   private AsyncTransportWrapper::WriteCallback {
1129  public:
1130   explicit SSLHandshakeBase(
1131    AsyncSSLSocket::UniquePtr socket,
1132    bool preverifyResult,
1133    bool verifyResult) :
1134     handshakeVerify_(false),
1135     handshakeSuccess_(false),
1136     handshakeError_(false),
1137     socket_(std::move(socket)),
1138     preverifyResult_(preverifyResult),
1139     verifyResult_(verifyResult) {
1140   }
1141
1142   bool handshakeVerify_;
1143   bool handshakeSuccess_;
1144   bool handshakeError_;
1145   std::chrono::nanoseconds handshakeTime;
1146
1147  protected:
1148   AsyncSSLSocket::UniquePtr socket_;
1149   bool preverifyResult_;
1150   bool verifyResult_;
1151
1152   // HandshakeCallback
1153   bool handshakeVer(AsyncSSLSocket* /* sock */,
1154                     bool preverifyOk,
1155                     X509_STORE_CTX* /* ctx */) noexcept override {
1156     handshakeVerify_ = true;
1157
1158     EXPECT_EQ(preverifyResult_, preverifyOk);
1159     return verifyResult_;
1160   }
1161
1162   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1163     handshakeSuccess_ = true;
1164     handshakeTime = socket_->getHandshakeTime();
1165   }
1166
1167   void handshakeErr(AsyncSSLSocket*,
1168                     const AsyncSocketException& /* ex */) noexcept override {
1169     handshakeError_ = true;
1170     handshakeTime = socket_->getHandshakeTime();
1171   }
1172
1173   // WriteCallback
1174   void writeSuccess() noexcept override {
1175     socket_->close();
1176   }
1177
1178   void writeErr(
1179    size_t bytesWritten,
1180    const AsyncSocketException& ex) noexcept override {
1181     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1182                   << ex.what();
1183   }
1184 };
1185
1186 class SSLHandshakeClient : public SSLHandshakeBase {
1187  public:
1188   SSLHandshakeClient(
1189    AsyncSSLSocket::UniquePtr socket,
1190    bool preverifyResult,
1191    bool verifyResult) :
1192     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1193     socket_->sslConn(this, 0);
1194   }
1195 };
1196
1197 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1198  public:
1199   SSLHandshakeClientNoVerify(
1200    AsyncSSLSocket::UniquePtr socket,
1201    bool preverifyResult,
1202    bool verifyResult) :
1203     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1204     socket_->sslConn(this, 0,
1205       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1206   }
1207 };
1208
1209 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1210  public:
1211   SSLHandshakeClientDoVerify(
1212    AsyncSSLSocket::UniquePtr socket,
1213    bool preverifyResult,
1214    bool verifyResult) :
1215     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1216     socket_->sslConn(this, 0,
1217       folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1218   }
1219 };
1220
1221 class SSLHandshakeServer : public SSLHandshakeBase {
1222  public:
1223   SSLHandshakeServer(
1224       AsyncSSLSocket::UniquePtr socket,
1225       bool preverifyResult,
1226       bool verifyResult)
1227     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1228     socket_->sslAccept(this, 0);
1229   }
1230 };
1231
1232 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1233  public:
1234   SSLHandshakeServerParseClientHello(
1235       AsyncSSLSocket::UniquePtr socket,
1236       bool preverifyResult,
1237       bool verifyResult)
1238       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1239     socket_->enableClientHelloParsing();
1240     socket_->sslAccept(this, 0);
1241   }
1242
1243   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1244
1245  protected:
1246   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1247     handshakeSuccess_ = true;
1248     sock->getSSLSharedCiphers(sharedCiphers_);
1249     sock->getSSLServerCiphers(serverCiphers_);
1250     sock->getSSLClientCiphers(clientCiphers_);
1251     chosenCipher_ = sock->getNegotiatedCipherName();
1252   }
1253 };
1254
1255
1256 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1257  public:
1258   SSLHandshakeServerNoVerify(
1259       AsyncSSLSocket::UniquePtr socket,
1260       bool preverifyResult,
1261       bool verifyResult)
1262     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1263     socket_->sslAccept(this, 0,
1264       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1265   }
1266 };
1267
1268 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1269  public:
1270   SSLHandshakeServerDoVerify(
1271       AsyncSSLSocket::UniquePtr socket,
1272       bool preverifyResult,
1273       bool verifyResult)
1274     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1275     socket_->sslAccept(this, 0,
1276       folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1277   }
1278 };
1279
1280 class EventBaseAborter : public AsyncTimeout {
1281  public:
1282   EventBaseAborter(EventBase* eventBase,
1283                    uint32_t timeoutMS)
1284     : AsyncTimeout(
1285       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1286     , eventBase_(eventBase) {
1287     scheduleTimeout(timeoutMS);
1288   }
1289
1290   void timeoutExpired() noexcept override {
1291     FAIL() << "test timed out";
1292     eventBase_->terminateLoopSoon();
1293   }
1294
1295  private:
1296   EventBase* eventBase_;
1297 };
1298
1299 }