63349bc909cd5874e993eaaaeed6ac6097ee74c1
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
28
29 #include <gtest/gtest.h>
30 #include <iostream>
31 #include <list>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <poll.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
38
39 namespace folly {
40
41 enum StateEnum {
42   STATE_WAITING,
43   STATE_SUCCEEDED,
44   STATE_FAILED
45 };
46
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
51
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
54 public:
55   WriteCallbackBase()
56       : state(STATE_WAITING)
57       , bytesWritten(0)
58       , exception(AsyncSocketException::UNKNOWN, "none") {}
59
60   ~WriteCallbackBase() {
61     EXPECT_EQ(state, STATE_SUCCEEDED);
62   }
63
64   void setSocket(
65     const std::shared_ptr<AsyncSSLSocket> &socket) {
66     socket_ = socket;
67   }
68
69   void writeSuccess() noexcept override {
70     std::cerr << "writeSuccess" << std::endl;
71     state = STATE_SUCCEEDED;
72   }
73
74   void writeErr(
75     size_t bytesWritten,
76     const AsyncSocketException& ex) noexcept override {
77     std::cerr << "writeError: bytesWritten " << bytesWritten
78          << ", exception " << ex.what() << std::endl;
79
80     state = STATE_FAILED;
81     this->bytesWritten = bytesWritten;
82     exception = ex;
83     socket_->close();
84     socket_->detachEventBase();
85   }
86
87   std::shared_ptr<AsyncSSLSocket> socket_;
88   StateEnum state;
89   size_t bytesWritten;
90   AsyncSocketException exception;
91 };
92
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
95 public:
96   explicit ReadCallbackBase(WriteCallbackBase *wcb)
97       : wcb_(wcb)
98       , state(STATE_WAITING) {}
99
100   ~ReadCallbackBase() {
101     EXPECT_EQ(state, STATE_SUCCEEDED);
102   }
103
104   void setSocket(
105     const std::shared_ptr<AsyncSSLSocket> &socket) {
106     socket_ = socket;
107   }
108
109   void setState(StateEnum s) {
110     state = s;
111     if (wcb_) {
112       wcb_->state = s;
113     }
114   }
115
116   void readErr(
117     const AsyncSocketException& ex) noexcept override {
118     std::cerr << "readError " << ex.what() << std::endl;
119     state = STATE_FAILED;
120     socket_->close();
121     socket_->detachEventBase();
122   }
123
124   void readEOF() noexcept override {
125     std::cerr << "readEOF" << std::endl;
126
127     socket_->close();
128     socket_->detachEventBase();
129   }
130
131   std::shared_ptr<AsyncSSLSocket> socket_;
132   WriteCallbackBase *wcb_;
133   StateEnum state;
134 };
135
136 class ReadCallback : public ReadCallbackBase {
137 public:
138   explicit ReadCallback(WriteCallbackBase *wcb)
139       : ReadCallbackBase(wcb)
140       , buffers() {}
141
142   ~ReadCallback() {
143     for (std::vector<Buffer>::iterator it = buffers.begin();
144          it != buffers.end();
145          ++it) {
146       it->free();
147     }
148     currentBuffer.free();
149   }
150
151   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152     if (!currentBuffer.buffer) {
153       currentBuffer.allocate(4096);
154     }
155     *bufReturn = currentBuffer.buffer;
156     *lenReturn = currentBuffer.length;
157   }
158
159   void readDataAvailable(size_t len) noexcept override {
160     std::cerr << "readDataAvailable, len " << len << std::endl;
161
162     currentBuffer.length = len;
163
164     wcb_->setSocket(socket_);
165
166     // Write back the same data.
167     socket_->write(wcb_, currentBuffer.buffer, len);
168
169     buffers.push_back(currentBuffer);
170     currentBuffer.reset();
171     state = STATE_SUCCEEDED;
172   }
173
174   class Buffer {
175   public:
176     Buffer() : buffer(nullptr), length(0) {}
177     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
178
179     void reset() {
180       buffer = nullptr;
181       length = 0;
182     }
183     void allocate(size_t length) {
184       assert(buffer == nullptr);
185       this->buffer = static_cast<char*>(malloc(length));
186       this->length = length;
187     }
188     void free() {
189       ::free(buffer);
190       reset();
191     }
192
193     char* buffer;
194     size_t length;
195   };
196
197   std::vector<Buffer> buffers;
198   Buffer currentBuffer;
199 };
200
201 class ReadErrorCallback : public ReadCallbackBase {
202 public:
203   explicit ReadErrorCallback(WriteCallbackBase *wcb)
204       : ReadCallbackBase(wcb) {}
205
206   // Return nullptr buffer to trigger readError()
207   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208     *bufReturn = nullptr;
209     *lenReturn = 0;
210   }
211
212   void readDataAvailable(size_t len) noexcept override {
213     // This should never to called.
214     FAIL();
215   }
216
217   void readErr(
218     const AsyncSocketException& ex) noexcept override {
219     ReadCallbackBase::readErr(ex);
220     std::cerr << "ReadErrorCallback::readError" << std::endl;
221     setState(STATE_SUCCEEDED);
222   }
223 };
224
225 class WriteErrorCallback : public ReadCallback {
226 public:
227   explicit WriteErrorCallback(WriteCallbackBase *wcb)
228       : ReadCallback(wcb) {}
229
230   void readDataAvailable(size_t len) noexcept override {
231     std::cerr << "readDataAvailable, len " << len << std::endl;
232
233     currentBuffer.length = len;
234
235     // close the socket before writing to trigger writeError().
236     ::close(socket_->getFd());
237
238     wcb_->setSocket(socket_);
239
240     // Write back the same data.
241     socket_->write(wcb_, currentBuffer.buffer, len);
242
243     if (wcb_->state == STATE_FAILED) {
244       setState(STATE_SUCCEEDED);
245     } else {
246       state = STATE_FAILED;
247     }
248
249     buffers.push_back(currentBuffer);
250     currentBuffer.reset();
251   }
252
253   void readErr(const AsyncSocketException& ex) noexcept override {
254     std::cerr << "readError " << ex.what() << std::endl;
255     // do nothing since this is expected
256   }
257 };
258
259 class EmptyReadCallback : public ReadCallback {
260 public:
261   explicit EmptyReadCallback()
262       : ReadCallback(nullptr) {}
263
264   void readErr(const AsyncSocketException& ex) noexcept override {
265     std::cerr << "readError " << ex.what() << std::endl;
266     state = STATE_FAILED;
267     tcpSocket_->close();
268     tcpSocket_->detachEventBase();
269   }
270
271   void readEOF() noexcept override {
272     std::cerr << "readEOF" << std::endl;
273
274     tcpSocket_->close();
275     tcpSocket_->detachEventBase();
276     state = STATE_SUCCEEDED;
277   }
278
279   std::shared_ptr<AsyncSocket> tcpSocket_;
280 };
281
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
284 public:
285   enum ExpectType {
286     EXPECT_SUCCESS,
287     EXPECT_ERROR
288   };
289
290   explicit HandshakeCallback(ReadCallbackBase *rcb,
291                              ExpectType expect = EXPECT_SUCCESS):
292       state(STATE_WAITING),
293       rcb_(rcb),
294       expect_(expect) {}
295
296   void setSocket(
297     const std::shared_ptr<AsyncSSLSocket> &socket) {
298     socket_ = socket;
299   }
300
301   void setState(StateEnum s) {
302     state = s;
303     rcb_->setState(s);
304   }
305
306   // Functions inherited from AsyncSSLSocketHandshakeCallback
307   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308     EXPECT_EQ(sock, socket_.get());
309     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
310     rcb_->setSocket(socket_);
311     sock->setReadCB(rcb_);
312     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
313   }
314   void handshakeErr(
315     AsyncSSLSocket *sock,
316     const AsyncSocketException& ex) noexcept override {
317     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
318     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
319     if (expect_ == EXPECT_ERROR) {
320       // rcb will never be invoked
321       rcb_->setState(STATE_SUCCEEDED);
322     }
323   }
324
325   ~HandshakeCallback() {
326     EXPECT_EQ(state, STATE_SUCCEEDED);
327   }
328
329   void closeSocket() {
330     socket_->close();
331     state = STATE_SUCCEEDED;
332   }
333
334   StateEnum state;
335   std::shared_ptr<AsyncSSLSocket> socket_;
336   ReadCallbackBase *rcb_;
337   ExpectType expect_;
338 };
339
340 class SSLServerAcceptCallbackBase:
341 public folly::AsyncServerSocket::AcceptCallback {
342 public:
343   explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
344   state(STATE_WAITING), hcb_(hcb) {}
345
346   ~SSLServerAcceptCallbackBase() {
347     EXPECT_EQ(state, STATE_SUCCEEDED);
348   }
349
350   void acceptError(const std::exception& ex) noexcept override {
351     std::cerr << "SSLServerAcceptCallbackBase::acceptError "
352               << ex.what() << std::endl;
353     state = STATE_FAILED;
354   }
355
356   void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
357     noexcept override{
358     printf("Connection accepted\n");
359     std::shared_ptr<AsyncSSLSocket> sslSock;
360     try {
361       // Create a AsyncSSLSocket object with the fd. The socket should be
362       // added to the event base and in the state of accepting SSL connection.
363       sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
364     } catch (const std::exception &e) {
365       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
366         "object with socket " << e.what() << fd;
367       ::close(fd);
368       acceptError(e);
369       return;
370     }
371
372     connAccepted(sslSock);
373   }
374
375   virtual void connAccepted(
376     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
377
378   StateEnum state;
379   HandshakeCallback *hcb_;
380   std::shared_ptr<folly::SSLContext> ctx_;
381   folly::EventBase* base_;
382 };
383
384 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
385 public:
386   uint32_t timeout_;
387
388   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
389                                    uint32_t timeout = 0):
390       SSLServerAcceptCallbackBase(hcb),
391       timeout_(timeout) {}
392
393   virtual ~SSLServerAcceptCallback() {
394     if (timeout_ > 0) {
395       // if we set a timeout, we expect failure
396       EXPECT_EQ(hcb_->state, STATE_FAILED);
397       hcb_->setState(STATE_SUCCEEDED);
398     }
399   }
400
401   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
402   void connAccepted(
403     const std::shared_ptr<folly::AsyncSSLSocket> &s)
404     noexcept override {
405     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
406     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
407
408     hcb_->setSocket(sock);
409     sock->sslAccept(hcb_, timeout_);
410     EXPECT_EQ(sock->getSSLState(),
411                       AsyncSSLSocket::STATE_ACCEPTING);
412
413     state = STATE_SUCCEEDED;
414   }
415 };
416
417 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
418 public:
419   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
420       SSLServerAcceptCallback(hcb) {}
421
422   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
423   void connAccepted(
424     const std::shared_ptr<folly::AsyncSSLSocket> &s)
425     noexcept override {
426
427     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
428
429     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
430               << std::endl;
431     int fd = sock->getFd();
432
433 #ifndef TCP_NOPUSH
434     {
435     // The accepted connection should already have TCP_NODELAY set
436     int value;
437     socklen_t valueLength = sizeof(value);
438     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
439     EXPECT_EQ(rc, 0);
440     EXPECT_EQ(value, 1);
441     }
442 #endif
443
444     // Unset the TCP_NODELAY option.
445     int value = 0;
446     socklen_t valueLength = sizeof(value);
447     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
448     EXPECT_EQ(rc, 0);
449
450     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
451     EXPECT_EQ(rc, 0);
452     EXPECT_EQ(value, 0);
453
454     SSLServerAcceptCallback::connAccepted(sock);
455   }
456 };
457
458 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
459 public:
460   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
461                                              uint32_t timeout = 0):
462     SSLServerAcceptCallback(hcb, timeout) {}
463
464   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
465   void connAccepted(
466     const std::shared_ptr<folly::AsyncSSLSocket> &s)
467     noexcept override {
468     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
469
470     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
471
472     hcb_->setSocket(sock);
473     sock->sslAccept(hcb_, timeout_);
474     ASSERT_TRUE((sock->getSSLState() ==
475                  AsyncSSLSocket::STATE_ACCEPTING) ||
476                 (sock->getSSLState() ==
477                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
478
479     state = STATE_SUCCEEDED;
480   }
481 };
482
483
484 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
485 public:
486   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
487   SSLServerAcceptCallbackBase(hcb)  {}
488
489   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
490   void connAccepted(
491     const std::shared_ptr<folly::AsyncSSLSocket> &s)
492     noexcept override {
493     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
494
495     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
496
497     // The first call to sslAccept() should succeed.
498     hcb_->setSocket(sock);
499     sock->sslAccept(hcb_);
500     EXPECT_EQ(sock->getSSLState(),
501                       AsyncSSLSocket::STATE_ACCEPTING);
502
503     // The second call to sslAccept() should fail.
504     HandshakeCallback callback2(hcb_->rcb_);
505     callback2.setSocket(sock);
506     sock->sslAccept(&callback2);
507     EXPECT_EQ(sock->getSSLState(),
508                       AsyncSSLSocket::STATE_ERROR);
509
510     // Both callbacks should be in the error state.
511     EXPECT_EQ(hcb_->state, STATE_FAILED);
512     EXPECT_EQ(callback2.state, STATE_FAILED);
513
514     sock->detachEventBase();
515
516     state = STATE_SUCCEEDED;
517     hcb_->setState(STATE_SUCCEEDED);
518     callback2.setState(STATE_SUCCEEDED);
519   }
520 };
521
522 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
523 public:
524   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
525   SSLServerAcceptCallbackBase(hcb)  {}
526
527   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
528   void connAccepted(
529     const std::shared_ptr<folly::AsyncSSLSocket> &s)
530     noexcept override {
531     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
532
533     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
534
535     hcb_->setSocket(sock);
536     sock->getEventBase()->tryRunAfterDelay([=] {
537         std::cerr << "Delayed SSL accept, client will have close by now"
538                   << std::endl;
539         // SSL accept will fail
540         EXPECT_EQ(
541           sock->getSSLState(),
542           AsyncSSLSocket::STATE_UNINIT);
543         hcb_->socket_->sslAccept(hcb_);
544         // This registers for an event
545         EXPECT_EQ(
546           sock->getSSLState(),
547           AsyncSSLSocket::STATE_ACCEPTING);
548
549         state = STATE_SUCCEEDED;
550       }, 100);
551   }
552 };
553
554
555 class TestSSLServer {
556  protected:
557   EventBase evb_;
558   std::shared_ptr<folly::SSLContext> ctx_;
559   SSLServerAcceptCallbackBase *acb_;
560   std::shared_ptr<folly::AsyncServerSocket> socket_;
561   folly::SocketAddress address_;
562   pthread_t thread_;
563
564   static void *Main(void *ctx) {
565     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
566     self->evb_.loop();
567     std::cerr << "Server thread exited event loop" << std::endl;
568     return nullptr;
569   }
570
571  public:
572   // Create a TestSSLServer.
573   // This immediately starts listening on the given port.
574   explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
575
576   // Kill the thread.
577   ~TestSSLServer() {
578     evb_.runInEventBaseThread([&](){
579       socket_->stopAccepting();
580     });
581     std::cerr << "Waiting for server thread to exit" << std::endl;
582     pthread_join(thread_, nullptr);
583   }
584
585   EventBase &getEventBase() { return evb_; }
586
587   const folly::SocketAddress& getAddress() const {
588     return address_;
589   }
590 };
591
592 class TestSSLAsyncCacheServer : public TestSSLServer {
593  public:
594   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
595         int lookupDelay = 100) :
596       TestSSLServer(acb) {
597     SSL_CTX *sslCtx = ctx_->getSSLCtx();
598     SSL_CTX_sess_set_get_cb(sslCtx,
599                             TestSSLAsyncCacheServer::getSessionCallback);
600     SSL_CTX_set_session_cache_mode(
601       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
602     asyncCallbacks_ = 0;
603     asyncLookups_ = 0;
604     lookupDelay_ = lookupDelay;
605   }
606
607   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
608   uint32_t getAsyncLookups() const { return asyncLookups_; }
609
610  private:
611   static uint32_t asyncCallbacks_;
612   static uint32_t asyncLookups_;
613   static uint32_t lookupDelay_;
614
615   static SSL_SESSION *getSessionCallback(SSL *ssl,
616                                          unsigned char *sess_id,
617                                          int id_len,
618                                          int *copyflag) {
619     *copyflag = 0;
620     asyncCallbacks_++;
621 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
622     if (!SSL_want_sess_cache_lookup(ssl)) {
623       // libssl.so mismatch
624       std::cerr << "no async support" << std::endl;
625       return nullptr;
626     }
627
628     AsyncSSLSocket *sslSocket =
629         AsyncSSLSocket::getFromSSL(ssl);
630     assert(sslSocket != nullptr);
631     // Going to simulate an async cache by just running delaying the miss 100ms
632     if (asyncCallbacks_ % 2 == 0) {
633       // This socket is already blocked on lookup, return miss
634       std::cerr << "returning miss" << std::endl;
635     } else {
636       // fresh meat - block it
637       std::cerr << "async lookup" << std::endl;
638       sslSocket->getEventBase()->tryRunAfterDelay(
639         std::bind(&AsyncSSLSocket::restartSSLAccept,
640                   sslSocket), lookupDelay_);
641       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
642       asyncLookups_++;
643     }
644 #endif
645     return nullptr;
646   }
647 };
648
649 void getfds(int fds[2]);
650
651 void getctx(
652   std::shared_ptr<folly::SSLContext> clientCtx,
653   std::shared_ptr<folly::SSLContext> serverCtx);
654
655 void sslsocketpair(
656   EventBase* eventBase,
657   AsyncSSLSocket::UniquePtr* clientSock,
658   AsyncSSLSocket::UniquePtr* serverSock);
659
660 class BlockingWriteClient :
661   private AsyncSSLSocket::HandshakeCB,
662   private AsyncTransportWrapper::WriteCallback {
663  public:
664   explicit BlockingWriteClient(
665     AsyncSSLSocket::UniquePtr socket)
666     : socket_(std::move(socket)),
667       bufLen_(2500),
668       iovCount_(2000) {
669     // Fill buf_
670     buf_.reset(new uint8_t[bufLen_]);
671     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
672       buf_[n] = n % 0xff;
673     }
674
675     // Initialize iov_
676     iov_.reset(new struct iovec[iovCount_]);
677     for (uint32_t n = 0; n < iovCount_; ++n) {
678       iov_[n].iov_base = buf_.get() + n;
679       if (n & 0x1) {
680         iov_[n].iov_len = n % bufLen_;
681       } else {
682         iov_[n].iov_len = bufLen_ - (n % bufLen_);
683       }
684     }
685
686     socket_->sslConn(this, 100);
687   }
688
689   struct iovec* getIovec() const {
690     return iov_.get();
691   }
692   uint32_t getIovecCount() const {
693     return iovCount_;
694   }
695
696  private:
697   void handshakeSuc(AsyncSSLSocket*) noexcept override {
698     socket_->writev(this, iov_.get(), iovCount_);
699   }
700   void handshakeErr(
701     AsyncSSLSocket*,
702     const AsyncSocketException& ex) noexcept override {
703     ADD_FAILURE() << "client handshake error: " << ex.what();
704   }
705   void writeSuccess() noexcept override {
706     socket_->close();
707   }
708   void writeErr(
709     size_t bytesWritten,
710     const AsyncSocketException& ex) noexcept override {
711     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
712                   << ex.what();
713   }
714
715   AsyncSSLSocket::UniquePtr socket_;
716   uint32_t bufLen_;
717   uint32_t iovCount_;
718   std::unique_ptr<uint8_t[]> buf_;
719   std::unique_ptr<struct iovec[]> iov_;
720 };
721
722 class BlockingWriteServer :
723     private AsyncSSLSocket::HandshakeCB,
724     private AsyncTransportWrapper::ReadCallback {
725  public:
726   explicit BlockingWriteServer(
727     AsyncSSLSocket::UniquePtr socket)
728     : socket_(std::move(socket)),
729       bufSize_(2500 * 2000),
730       bytesRead_(0) {
731     buf_.reset(new uint8_t[bufSize_]);
732     socket_->sslAccept(this, 100);
733   }
734
735   void checkBuffer(struct iovec* iov, uint32_t count) const {
736     uint32_t idx = 0;
737     for (uint32_t n = 0; n < count; ++n) {
738       size_t bytesLeft = bytesRead_ - idx;
739       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
740                       std::min(iov[n].iov_len, bytesLeft));
741       if (rc != 0) {
742         FAIL() << "buffer mismatch at iovec " << n << "/" << count
743                << ": rc=" << rc;
744
745       }
746       if (iov[n].iov_len > bytesLeft) {
747         FAIL() << "server did not read enough data: "
748                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
749                << " in iovec " << n << "/" << count;
750       }
751
752       idx += iov[n].iov_len;
753     }
754     if (idx != bytesRead_) {
755       ADD_FAILURE() << "server read extra data: " << bytesRead_
756                     << " bytes read; expected " << idx;
757     }
758   }
759
760  private:
761   void handshakeSuc(AsyncSSLSocket*) noexcept override {
762     // Wait 10ms before reading, so the client's writes will initially block.
763     socket_->getEventBase()->tryRunAfterDelay(
764         [this] { socket_->setReadCB(this); }, 10);
765   }
766   void handshakeErr(
767     AsyncSSLSocket*,
768     const AsyncSocketException& ex) noexcept override {
769     ADD_FAILURE() << "server handshake error: " << ex.what();
770   }
771   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
772     *bufReturn = buf_.get() + bytesRead_;
773     *lenReturn = bufSize_ - bytesRead_;
774   }
775   void readDataAvailable(size_t len) noexcept override {
776     bytesRead_ += len;
777     socket_->setReadCB(nullptr);
778     socket_->getEventBase()->tryRunAfterDelay(
779         [this] { socket_->setReadCB(this); }, 2);
780   }
781   void readEOF() noexcept override {
782     socket_->close();
783   }
784   void readErr(
785     const AsyncSocketException& ex) noexcept override {
786     ADD_FAILURE() << "server read error: " << ex.what();
787   }
788
789   AsyncSSLSocket::UniquePtr socket_;
790   uint32_t bufSize_;
791   uint32_t bytesRead_;
792   std::unique_ptr<uint8_t[]> buf_;
793 };
794
795 class NpnClient :
796   private AsyncSSLSocket::HandshakeCB,
797   private AsyncTransportWrapper::WriteCallback {
798  public:
799   explicit NpnClient(
800     AsyncSSLSocket::UniquePtr socket)
801       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
802     socket_->sslConn(this);
803   }
804
805   const unsigned char* nextProto;
806   unsigned nextProtoLength;
807  private:
808   void handshakeSuc(AsyncSSLSocket*) noexcept override {
809     socket_->getSelectedNextProtocol(&nextProto,
810                                      &nextProtoLength);
811   }
812   void handshakeErr(
813     AsyncSSLSocket*,
814     const AsyncSocketException& ex) noexcept override {
815     ADD_FAILURE() << "client handshake error: " << ex.what();
816   }
817   void writeSuccess() noexcept override {
818     socket_->close();
819   }
820   void writeErr(
821     size_t bytesWritten,
822     const AsyncSocketException& ex) noexcept override {
823     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
824                   << ex.what();
825   }
826
827   AsyncSSLSocket::UniquePtr socket_;
828 };
829
830 class NpnServer :
831     private AsyncSSLSocket::HandshakeCB,
832     private AsyncTransportWrapper::ReadCallback {
833  public:
834   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
835       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
836     socket_->sslAccept(this);
837   }
838
839   const unsigned char* nextProto;
840   unsigned nextProtoLength;
841  private:
842   void handshakeSuc(AsyncSSLSocket*) noexcept override {
843     socket_->getSelectedNextProtocol(&nextProto,
844                                      &nextProtoLength);
845   }
846   void handshakeErr(
847     AsyncSSLSocket*,
848     const AsyncSocketException& ex) noexcept override {
849     ADD_FAILURE() << "server handshake error: " << ex.what();
850   }
851   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
852     *lenReturn = 0;
853   }
854   void readDataAvailable(size_t len) noexcept override {
855   }
856   void readEOF() noexcept override {
857     socket_->close();
858   }
859   void readErr(
860     const AsyncSocketException& ex) noexcept override {
861     ADD_FAILURE() << "server read error: " << ex.what();
862   }
863
864   AsyncSSLSocket::UniquePtr socket_;
865 };
866
867 #ifndef OPENSSL_NO_TLSEXT
868 class SNIClient :
869   private AsyncSSLSocket::HandshakeCB,
870   private AsyncTransportWrapper::WriteCallback {
871  public:
872   explicit SNIClient(
873     AsyncSSLSocket::UniquePtr socket)
874       : serverNameMatch(false), socket_(std::move(socket)) {
875     socket_->sslConn(this);
876   }
877
878   bool serverNameMatch;
879
880  private:
881   void handshakeSuc(AsyncSSLSocket*) noexcept override {
882     serverNameMatch = socket_->isServerNameMatch();
883   }
884   void handshakeErr(
885     AsyncSSLSocket*,
886     const AsyncSocketException& ex) noexcept override {
887     ADD_FAILURE() << "client handshake error: " << ex.what();
888   }
889   void writeSuccess() noexcept override {
890     socket_->close();
891   }
892   void writeErr(
893     size_t bytesWritten,
894     const AsyncSocketException& ex) noexcept override {
895     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
896                   << ex.what();
897   }
898
899   AsyncSSLSocket::UniquePtr socket_;
900 };
901
902 class SNIServer :
903     private AsyncSSLSocket::HandshakeCB,
904     private AsyncTransportWrapper::ReadCallback {
905  public:
906   explicit SNIServer(
907     AsyncSSLSocket::UniquePtr socket,
908     const std::shared_ptr<folly::SSLContext>& ctx,
909     const std::shared_ptr<folly::SSLContext>& sniCtx,
910     const std::string& expectedServerName)
911       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
912         expectedServerName_(expectedServerName) {
913     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
914                                          std::placeholders::_1));
915     socket_->sslAccept(this);
916   }
917
918   bool serverNameMatch;
919
920  private:
921   void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
922   void handshakeErr(
923     AsyncSSLSocket*,
924     const AsyncSocketException& ex) noexcept override {
925     ADD_FAILURE() << "server handshake error: " << ex.what();
926   }
927   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
928     *lenReturn = 0;
929   }
930   void readDataAvailable(size_t len) noexcept override {
931   }
932   void readEOF() noexcept override {
933     socket_->close();
934   }
935   void readErr(
936     const AsyncSocketException& ex) noexcept override {
937     ADD_FAILURE() << "server read error: " << ex.what();
938   }
939
940   folly::SSLContext::ServerNameCallbackResult
941     serverNameCallback(SSL *ssl) {
942     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
943     if (sniCtx_ &&
944         sn &&
945         !strcasecmp(expectedServerName_.c_str(), sn)) {
946       AsyncSSLSocket *sslSocket =
947           AsyncSSLSocket::getFromSSL(ssl);
948       sslSocket->switchServerSSLContext(sniCtx_);
949       serverNameMatch = true;
950       return folly::SSLContext::SERVER_NAME_FOUND;
951     } else {
952       serverNameMatch = false;
953       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
954     }
955   }
956
957   AsyncSSLSocket::UniquePtr socket_;
958   std::shared_ptr<folly::SSLContext> sniCtx_;
959   std::string expectedServerName_;
960 };
961 #endif
962
963 class SSLClient : public AsyncSocket::ConnectCallback,
964                   public AsyncTransportWrapper::WriteCallback,
965                   public AsyncTransportWrapper::ReadCallback
966 {
967  private:
968   EventBase *eventBase_;
969   std::shared_ptr<AsyncSSLSocket> sslSocket_;
970   SSL_SESSION *session_;
971   std::shared_ptr<folly::SSLContext> ctx_;
972   uint32_t requests_;
973   folly::SocketAddress address_;
974   uint32_t timeout_;
975   char buf_[128];
976   char readbuf_[128];
977   uint32_t bytesRead_;
978   uint32_t hit_;
979   uint32_t miss_;
980   uint32_t errors_;
981   uint32_t writeAfterConnectErrors_;
982
983   // These settings test that we eventually drain the
984   // socket, even if the maxReadsPerEvent_ is hit during
985   // a event loop iteration.
986   static constexpr size_t kMaxReadsPerEvent = 2;
987   static constexpr size_t kMaxReadBufferSz =
988     sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
989
990  public:
991   SSLClient(EventBase *eventBase,
992             const folly::SocketAddress& address,
993             uint32_t requests,
994             uint32_t timeout = 0)
995       : eventBase_(eventBase),
996         session_(nullptr),
997         requests_(requests),
998         address_(address),
999         timeout_(timeout),
1000         bytesRead_(0),
1001         hit_(0),
1002         miss_(0),
1003         errors_(0),
1004         writeAfterConnectErrors_(0) {
1005     ctx_.reset(new folly::SSLContext());
1006     ctx_->setOptions(SSL_OP_NO_TICKET);
1007     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1008     memset(buf_, 'a', sizeof(buf_));
1009   }
1010
1011   ~SSLClient() {
1012     if (session_) {
1013       SSL_SESSION_free(session_);
1014     }
1015     if (errors_ == 0) {
1016       EXPECT_EQ(bytesRead_, sizeof(buf_));
1017     }
1018   }
1019
1020   uint32_t getHit() const { return hit_; }
1021
1022   uint32_t getMiss() const { return miss_; }
1023
1024   uint32_t getErrors() const { return errors_; }
1025
1026   uint32_t getWriteAfterConnectErrors() const {
1027     return writeAfterConnectErrors_;
1028   }
1029
1030   void connect(bool writeNow = false) {
1031     sslSocket_ = AsyncSSLSocket::newSocket(
1032       ctx_, eventBase_);
1033     if (session_ != nullptr) {
1034       sslSocket_->setSSLSession(session_);
1035     }
1036     requests_--;
1037     sslSocket_->connect(this, address_, timeout_);
1038     if (sslSocket_ && writeNow) {
1039       // write some junk, used in an error test
1040       sslSocket_->write(this, buf_, sizeof(buf_));
1041     }
1042   }
1043
1044   void connectSuccess() noexcept override {
1045     std::cerr << "client SSL socket connected" << std::endl;
1046     if (sslSocket_->getSSLSessionReused()) {
1047       hit_++;
1048     } else {
1049       miss_++;
1050       if (session_ != nullptr) {
1051         SSL_SESSION_free(session_);
1052       }
1053       session_ = sslSocket_->getSSLSession();
1054     }
1055
1056     // write()
1057     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1058     sslSocket_->write(this, buf_, sizeof(buf_));
1059     sslSocket_->setReadCB(this);
1060     memset(readbuf_, 'b', sizeof(readbuf_));
1061     bytesRead_ = 0;
1062   }
1063
1064   void connectErr(
1065     const AsyncSocketException& ex) noexcept override {
1066     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1067     errors_++;
1068     sslSocket_.reset();
1069   }
1070
1071   void writeSuccess() noexcept override {
1072     std::cerr << "client write success" << std::endl;
1073   }
1074
1075   void writeErr(
1076     size_t bytesWritten,
1077     const AsyncSocketException& ex)
1078     noexcept override {
1079     std::cerr << "client writeError: " << ex.what() << std::endl;
1080     if (!sslSocket_) {
1081       writeAfterConnectErrors_++;
1082     }
1083   }
1084
1085   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1086     *bufReturn = readbuf_ + bytesRead_;
1087     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1088   }
1089
1090   void readEOF() noexcept override {
1091     std::cerr << "client readEOF" << std::endl;
1092   }
1093
1094   void readErr(
1095     const AsyncSocketException& ex) noexcept override {
1096     std::cerr << "client readError: " << ex.what() << std::endl;
1097   }
1098
1099   void readDataAvailable(size_t len) noexcept override {
1100     std::cerr << "client read data: " << len << std::endl;
1101     bytesRead_ += len;
1102     if (bytesRead_ == sizeof(buf_)) {
1103       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1104       sslSocket_->closeNow();
1105       sslSocket_.reset();
1106       if (requests_ != 0) {
1107         connect();
1108       }
1109     }
1110   }
1111
1112 };
1113
1114 class SSLHandshakeBase :
1115   public AsyncSSLSocket::HandshakeCB,
1116   private AsyncTransportWrapper::WriteCallback {
1117  public:
1118   explicit SSLHandshakeBase(
1119    AsyncSSLSocket::UniquePtr socket,
1120    bool preverifyResult,
1121    bool verifyResult) :
1122     handshakeVerify_(false),
1123     handshakeSuccess_(false),
1124     handshakeError_(false),
1125     socket_(std::move(socket)),
1126     preverifyResult_(preverifyResult),
1127     verifyResult_(verifyResult) {
1128   }
1129
1130   bool handshakeVerify_;
1131   bool handshakeSuccess_;
1132   bool handshakeError_;
1133   std::chrono::nanoseconds handshakeTime;
1134
1135  protected:
1136   AsyncSSLSocket::UniquePtr socket_;
1137   bool preverifyResult_;
1138   bool verifyResult_;
1139
1140   // HandshakeCallback
1141   bool handshakeVer(
1142    AsyncSSLSocket* sock,
1143    bool preverifyOk,
1144    X509_STORE_CTX* ctx) noexcept override {
1145     handshakeVerify_ = true;
1146
1147     EXPECT_EQ(preverifyResult_, preverifyOk);
1148     return verifyResult_;
1149   }
1150
1151   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1152     handshakeSuccess_ = true;
1153     handshakeTime = socket_->getHandshakeTime();
1154   }
1155
1156   void handshakeErr(
1157    AsyncSSLSocket*,
1158    const AsyncSocketException& ex) noexcept override {
1159     handshakeError_ = true;
1160     handshakeTime = socket_->getHandshakeTime();
1161   }
1162
1163   // WriteCallback
1164   void writeSuccess() noexcept override {
1165     socket_->close();
1166   }
1167
1168   void writeErr(
1169    size_t bytesWritten,
1170    const AsyncSocketException& ex) noexcept override {
1171     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1172                   << ex.what();
1173   }
1174 };
1175
1176 class SSLHandshakeClient : public SSLHandshakeBase {
1177  public:
1178   SSLHandshakeClient(
1179    AsyncSSLSocket::UniquePtr socket,
1180    bool preverifyResult,
1181    bool verifyResult) :
1182     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1183     socket_->sslConn(this, 0);
1184   }
1185 };
1186
1187 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1188  public:
1189   SSLHandshakeClientNoVerify(
1190    AsyncSSLSocket::UniquePtr socket,
1191    bool preverifyResult,
1192    bool verifyResult) :
1193     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1194     socket_->sslConn(this, 0,
1195       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1196   }
1197 };
1198
1199 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1200  public:
1201   SSLHandshakeClientDoVerify(
1202    AsyncSSLSocket::UniquePtr socket,
1203    bool preverifyResult,
1204    bool verifyResult) :
1205     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1206     socket_->sslConn(this, 0,
1207       folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1208   }
1209 };
1210
1211 class SSLHandshakeServer : public SSLHandshakeBase {
1212  public:
1213   SSLHandshakeServer(
1214       AsyncSSLSocket::UniquePtr socket,
1215       bool preverifyResult,
1216       bool verifyResult)
1217     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1218     socket_->sslAccept(this, 0);
1219   }
1220 };
1221
1222 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1223  public:
1224   SSLHandshakeServerParseClientHello(
1225       AsyncSSLSocket::UniquePtr socket,
1226       bool preverifyResult,
1227       bool verifyResult)
1228       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1229     socket_->enableClientHelloParsing();
1230     socket_->sslAccept(this, 0);
1231   }
1232
1233   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1234
1235  protected:
1236   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1237     handshakeSuccess_ = true;
1238     sock->getSSLSharedCiphers(sharedCiphers_);
1239     sock->getSSLServerCiphers(serverCiphers_);
1240     sock->getSSLClientCiphers(clientCiphers_);
1241     chosenCipher_ = sock->getNegotiatedCipherName();
1242   }
1243 };
1244
1245
1246 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1247  public:
1248   SSLHandshakeServerNoVerify(
1249       AsyncSSLSocket::UniquePtr socket,
1250       bool preverifyResult,
1251       bool verifyResult)
1252     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1253     socket_->sslAccept(this, 0,
1254       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1255   }
1256 };
1257
1258 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1259  public:
1260   SSLHandshakeServerDoVerify(
1261       AsyncSSLSocket::UniquePtr socket,
1262       bool preverifyResult,
1263       bool verifyResult)
1264     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1265     socket_->sslAccept(this, 0,
1266       folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1267   }
1268 };
1269
1270 class EventBaseAborter : public AsyncTimeout {
1271  public:
1272   EventBaseAborter(EventBase* eventBase,
1273                    uint32_t timeoutMS)
1274     : AsyncTimeout(
1275       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1276     , eventBase_(eventBase) {
1277     scheduleTimeout(timeoutMS);
1278   }
1279
1280   void timeoutExpired() noexcept override {
1281     FAIL() << "test timed out";
1282     eventBase_->terminateLoopSoon();
1283   }
1284
1285  private:
1286   EventBase* eventBase_;
1287 };
1288
1289 }