Split out SSL test server for reuse
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/ExceptionWrapper.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/experimental/TestUtil.h>
24 #include <folly/io/async/AsyncSSLSocket.h>
25 #include <folly/io/async/AsyncServerSocket.h>
26 #include <folly/io/async/AsyncSocket.h>
27 #include <folly/io/async/AsyncTimeout.h>
28 #include <folly/io/async/AsyncTransport.h>
29 #include <folly/io/async/EventBase.h>
30 #include <folly/io/async/ssl/SSLErrors.h>
31 #include <folly/io/async/test/TestSSLServer.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <condition_variable>
39 #include <iostream>
40 #include <list>
41
42 namespace folly {
43
44 // The destructors of all callback classes assert that the state is
45 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
46 // are responsible for setting the succeeded state properly before the
47 // destructors are called.
48
49 class WriteCallbackBase :
50 public AsyncTransportWrapper::WriteCallback {
51 public:
52   WriteCallbackBase()
53       : state(STATE_WAITING)
54       , bytesWritten(0)
55       , exception(AsyncSocketException::UNKNOWN, "none") {}
56
57   ~WriteCallbackBase() {
58     EXPECT_EQ(STATE_SUCCEEDED, state);
59   }
60
61   void setSocket(
62     const std::shared_ptr<AsyncSSLSocket> &socket) {
63     socket_ = socket;
64   }
65
66   void writeSuccess() noexcept override {
67     std::cerr << "writeSuccess" << std::endl;
68     state = STATE_SUCCEEDED;
69   }
70
71   void writeErr(
72     size_t nBytesWritten,
73     const AsyncSocketException& ex) noexcept override {
74     std::cerr << "writeError: bytesWritten " << nBytesWritten
75          << ", exception " << ex.what() << std::endl;
76
77     state = STATE_FAILED;
78     this->bytesWritten = nBytesWritten;
79     exception = ex;
80     socket_->close();
81   }
82
83   std::shared_ptr<AsyncSSLSocket> socket_;
84   StateEnum state;
85   size_t bytesWritten;
86   AsyncSocketException exception;
87 };
88
89 class ReadCallbackBase :
90 public AsyncTransportWrapper::ReadCallback {
91  public:
92   explicit ReadCallbackBase(WriteCallbackBase* wcb)
93       : wcb_(wcb), state(STATE_WAITING) {}
94
95   ~ReadCallbackBase() {
96     EXPECT_EQ(STATE_SUCCEEDED, state);
97   }
98
99   void setSocket(
100     const std::shared_ptr<AsyncSSLSocket> &socket) {
101     socket_ = socket;
102   }
103
104   void setState(StateEnum s) {
105     state = s;
106     if (wcb_) {
107       wcb_->state = s;
108     }
109   }
110
111   void readErr(
112     const AsyncSocketException& ex) noexcept override {
113     std::cerr << "readError " << ex.what() << std::endl;
114     state = STATE_FAILED;
115     socket_->close();
116   }
117
118   void readEOF() noexcept override {
119     std::cerr << "readEOF" << std::endl;
120
121     socket_->close();
122   }
123
124   std::shared_ptr<AsyncSSLSocket> socket_;
125   WriteCallbackBase *wcb_;
126   StateEnum state;
127 };
128
129 class ReadCallback : public ReadCallbackBase {
130 public:
131   explicit ReadCallback(WriteCallbackBase *wcb)
132       : ReadCallbackBase(wcb)
133       , buffers() {}
134
135   ~ReadCallback() {
136     for (std::vector<Buffer>::iterator it = buffers.begin();
137          it != buffers.end();
138          ++it) {
139       it->free();
140     }
141     currentBuffer.free();
142   }
143
144   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
145     if (!currentBuffer.buffer) {
146       currentBuffer.allocate(4096);
147     }
148     *bufReturn = currentBuffer.buffer;
149     *lenReturn = currentBuffer.length;
150   }
151
152   void readDataAvailable(size_t len) noexcept override {
153     std::cerr << "readDataAvailable, len " << len << std::endl;
154
155     currentBuffer.length = len;
156
157     wcb_->setSocket(socket_);
158
159     // Write back the same data.
160     socket_->write(wcb_, currentBuffer.buffer, len);
161
162     buffers.push_back(currentBuffer);
163     currentBuffer.reset();
164     state = STATE_SUCCEEDED;
165   }
166
167   class Buffer {
168   public:
169     Buffer() : buffer(nullptr), length(0) {}
170     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
171
172     void reset() {
173       buffer = nullptr;
174       length = 0;
175     }
176     void allocate(size_t len) {
177       assert(buffer == nullptr);
178       this->buffer = static_cast<char*>(malloc(len));
179       this->length = len;
180     }
181     void free() {
182       ::free(buffer);
183       reset();
184     }
185
186     char* buffer;
187     size_t length;
188   };
189
190   std::vector<Buffer> buffers;
191   Buffer currentBuffer;
192 };
193
194 class ReadErrorCallback : public ReadCallbackBase {
195 public:
196   explicit ReadErrorCallback(WriteCallbackBase *wcb)
197       : ReadCallbackBase(wcb) {}
198
199   // Return nullptr buffer to trigger readError()
200   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
201     *bufReturn = nullptr;
202     *lenReturn = 0;
203   }
204
205   void readDataAvailable(size_t /* len */) noexcept override {
206     // This should never to called.
207     FAIL();
208   }
209
210   void readErr(
211     const AsyncSocketException& ex) noexcept override {
212     ReadCallbackBase::readErr(ex);
213     std::cerr << "ReadErrorCallback::readError" << std::endl;
214     setState(STATE_SUCCEEDED);
215   }
216 };
217
218 class ReadEOFCallback : public ReadCallbackBase {
219  public:
220   explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
221
222   // Return nullptr buffer to trigger readError()
223   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
224     *bufReturn = nullptr;
225     *lenReturn = 0;
226   }
227
228   void readDataAvailable(size_t /* len */) noexcept override {
229     // This should never to called.
230     FAIL();
231   }
232
233   void readEOF() noexcept override {
234     ReadCallbackBase::readEOF();
235     setState(STATE_SUCCEEDED);
236   }
237 };
238
239 class WriteErrorCallback : public ReadCallback {
240 public:
241   explicit WriteErrorCallback(WriteCallbackBase *wcb)
242       : ReadCallback(wcb) {}
243
244   void readDataAvailable(size_t len) noexcept override {
245     std::cerr << "readDataAvailable, len " << len << std::endl;
246
247     currentBuffer.length = len;
248
249     // close the socket before writing to trigger writeError().
250     ::close(socket_->getFd());
251
252     wcb_->setSocket(socket_);
253
254     // Write back the same data.
255     folly::test::msvcSuppressAbortOnInvalidParams([&] {
256       socket_->write(wcb_, currentBuffer.buffer, len);
257     });
258
259     if (wcb_->state == STATE_FAILED) {
260       setState(STATE_SUCCEEDED);
261     } else {
262       state = STATE_FAILED;
263     }
264
265     buffers.push_back(currentBuffer);
266     currentBuffer.reset();
267   }
268
269   void readErr(const AsyncSocketException& ex) noexcept override {
270     std::cerr << "readError " << ex.what() << std::endl;
271     // do nothing since this is expected
272   }
273 };
274
275 class EmptyReadCallback : public ReadCallback {
276 public:
277   explicit EmptyReadCallback()
278       : ReadCallback(nullptr) {}
279
280   void readErr(const AsyncSocketException& ex) noexcept override {
281     std::cerr << "readError " << ex.what() << std::endl;
282     state = STATE_FAILED;
283     if (tcpSocket_) {
284       tcpSocket_->close();
285     }
286   }
287
288   void readEOF() noexcept override {
289     std::cerr << "readEOF" << std::endl;
290     if (tcpSocket_) {
291       tcpSocket_->close();
292     }
293     state = STATE_SUCCEEDED;
294   }
295
296   std::shared_ptr<AsyncSocket> tcpSocket_;
297 };
298
299 class HandshakeCallback :
300 public AsyncSSLSocket::HandshakeCB {
301 public:
302   enum ExpectType {
303     EXPECT_SUCCESS,
304     EXPECT_ERROR
305   };
306
307   explicit HandshakeCallback(ReadCallbackBase *rcb,
308                              ExpectType expect = EXPECT_SUCCESS):
309       state(STATE_WAITING),
310       rcb_(rcb),
311       expect_(expect) {}
312
313   void setSocket(
314     const std::shared_ptr<AsyncSSLSocket> &socket) {
315     socket_ = socket;
316   }
317
318   void setState(StateEnum s) {
319     state = s;
320     rcb_->setState(s);
321   }
322
323   // Functions inherited from AsyncSSLSocketHandshakeCallback
324   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
325     std::lock_guard<std::mutex> g(mutex_);
326     cv_.notify_all();
327     EXPECT_EQ(sock, socket_.get());
328     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
329     rcb_->setSocket(socket_);
330     sock->setReadCB(rcb_);
331     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
332   }
333   void handshakeErr(AsyncSSLSocket* /* sock */,
334                     const AsyncSocketException& ex) noexcept override {
335     std::lock_guard<std::mutex> g(mutex_);
336     cv_.notify_all();
337     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
338     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
339     if (expect_ == EXPECT_ERROR) {
340       // rcb will never be invoked
341       rcb_->setState(STATE_SUCCEEDED);
342     }
343     errorString_ = ex.what();
344   }
345
346   void waitForHandshake() {
347     std::unique_lock<std::mutex> lock(mutex_);
348     cv_.wait(lock, [this] { return state != STATE_WAITING; });
349   }
350
351   ~HandshakeCallback() {
352     EXPECT_EQ(STATE_SUCCEEDED, state);
353   }
354
355   void closeSocket() {
356     socket_->close();
357     state = STATE_SUCCEEDED;
358   }
359
360   std::shared_ptr<AsyncSSLSocket> getSocket() {
361     return socket_;
362   }
363
364   StateEnum state;
365   std::shared_ptr<AsyncSSLSocket> socket_;
366   ReadCallbackBase *rcb_;
367   ExpectType expect_;
368   std::mutex mutex_;
369   std::condition_variable cv_;
370   std::string errorString_;
371 };
372
373 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
374 public:
375   uint32_t timeout_;
376
377   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
378                                    uint32_t timeout = 0):
379       SSLServerAcceptCallbackBase(hcb),
380       timeout_(timeout) {}
381
382   virtual ~SSLServerAcceptCallback() {
383     if (timeout_ > 0) {
384       // if we set a timeout, we expect failure
385       EXPECT_EQ(hcb_->state, STATE_FAILED);
386       hcb_->setState(STATE_SUCCEEDED);
387     }
388   }
389
390   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
391   void connAccepted(
392     const std::shared_ptr<folly::AsyncSSLSocket> &s)
393     noexcept override {
394     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
395     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
396
397     hcb_->setSocket(sock);
398     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
399     EXPECT_EQ(sock->getSSLState(),
400                       AsyncSSLSocket::STATE_ACCEPTING);
401
402     state = STATE_SUCCEEDED;
403   }
404 };
405
406 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
407 public:
408   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
409       SSLServerAcceptCallback(hcb) {}
410
411   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
412   void connAccepted(
413     const std::shared_ptr<folly::AsyncSSLSocket> &s)
414     noexcept override {
415
416     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
417
418     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
419               << std::endl;
420     int fd = sock->getFd();
421
422 #ifndef TCP_NOPUSH
423     {
424     // The accepted connection should already have TCP_NODELAY set
425     int value;
426     socklen_t valueLength = sizeof(value);
427     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
428     EXPECT_EQ(rc, 0);
429     EXPECT_EQ(value, 1);
430     }
431 #endif
432
433     // Unset the TCP_NODELAY option.
434     int value = 0;
435     socklen_t valueLength = sizeof(value);
436     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
437     EXPECT_EQ(rc, 0);
438
439     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
440     EXPECT_EQ(rc, 0);
441     EXPECT_EQ(value, 0);
442
443     SSLServerAcceptCallback::connAccepted(sock);
444   }
445 };
446
447 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
448 public:
449   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
450                                              uint32_t timeout = 0):
451     SSLServerAcceptCallback(hcb, timeout) {}
452
453   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
454   void connAccepted(
455     const std::shared_ptr<folly::AsyncSSLSocket> &s)
456     noexcept override {
457     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
458
459     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
460
461     hcb_->setSocket(sock);
462     sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
463     ASSERT_TRUE((sock->getSSLState() ==
464                  AsyncSSLSocket::STATE_ACCEPTING) ||
465                 (sock->getSSLState() ==
466                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
467
468     state = STATE_SUCCEEDED;
469   }
470 };
471
472
473 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
474 public:
475   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
476   SSLServerAcceptCallbackBase(hcb)  {}
477
478   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
479   void connAccepted(
480     const std::shared_ptr<folly::AsyncSSLSocket> &s)
481     noexcept override {
482     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
483
484     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
485
486     // The first call to sslAccept() should succeed.
487     hcb_->setSocket(sock);
488     sock->sslAccept(hcb_);
489     EXPECT_EQ(sock->getSSLState(),
490                       AsyncSSLSocket::STATE_ACCEPTING);
491
492     // The second call to sslAccept() should fail.
493     HandshakeCallback callback2(hcb_->rcb_);
494     callback2.setSocket(sock);
495     sock->sslAccept(&callback2);
496     EXPECT_EQ(sock->getSSLState(),
497                       AsyncSSLSocket::STATE_ERROR);
498
499     // Both callbacks should be in the error state.
500     EXPECT_EQ(hcb_->state, STATE_FAILED);
501     EXPECT_EQ(callback2.state, STATE_FAILED);
502
503     state = STATE_SUCCEEDED;
504     hcb_->setState(STATE_SUCCEEDED);
505     callback2.setState(STATE_SUCCEEDED);
506   }
507 };
508
509 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
510 public:
511   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
512   SSLServerAcceptCallbackBase(hcb)  {}
513
514   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
515   void connAccepted(
516     const std::shared_ptr<folly::AsyncSSLSocket> &s)
517     noexcept override {
518     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
519
520     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
521
522     hcb_->setSocket(sock);
523     sock->getEventBase()->tryRunAfterDelay([=] {
524         std::cerr << "Delayed SSL accept, client will have close by now"
525                   << std::endl;
526         // SSL accept will fail
527         EXPECT_EQ(
528           sock->getSSLState(),
529           AsyncSSLSocket::STATE_UNINIT);
530         hcb_->socket_->sslAccept(hcb_);
531         // This registers for an event
532         EXPECT_EQ(
533           sock->getSSLState(),
534           AsyncSSLSocket::STATE_ACCEPTING);
535
536         state = STATE_SUCCEEDED;
537       }, 100);
538   }
539 };
540
541 class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
542  public:
543   ConnectTimeoutCallback() : SSLServerAcceptCallbackBase(nullptr) {
544     // We don't care if we get invoked or not.
545     // The client may time out and give up before connAccepted() is even
546     // called.
547     state = STATE_SUCCEEDED;
548   }
549
550   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
551   void connAccepted(
552       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
553     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
554
555     // Just wait a while before closing the socket, so the client
556     // will time out waiting for the handshake to complete.
557     s->getEventBase()->tryRunAfterDelay([=] { s->close(); }, 100);
558   }
559 };
560
561 class TestSSLAsyncCacheServer : public TestSSLServer {
562  public:
563   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
564         int lookupDelay = 100) :
565       TestSSLServer(acb) {
566     SSL_CTX *sslCtx = ctx_->getSSLCtx();
567     SSL_CTX_sess_set_get_cb(sslCtx,
568                             TestSSLAsyncCacheServer::getSessionCallback);
569     SSL_CTX_set_session_cache_mode(
570       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
571     asyncCallbacks_ = 0;
572     asyncLookups_ = 0;
573     lookupDelay_ = lookupDelay;
574   }
575
576   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
577   uint32_t getAsyncLookups() const { return asyncLookups_; }
578
579  private:
580   static uint32_t asyncCallbacks_;
581   static uint32_t asyncLookups_;
582   static uint32_t lookupDelay_;
583
584   static SSL_SESSION* getSessionCallback(SSL* ssl,
585                                          unsigned char* /* sess_id */,
586                                          int /* id_len */,
587                                          int* copyflag) {
588     *copyflag = 0;
589     asyncCallbacks_++;
590     (void)ssl;
591 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
592     if (!SSL_want_sess_cache_lookup(ssl)) {
593       // libssl.so mismatch
594       std::cerr << "no async support" << std::endl;
595       return nullptr;
596     }
597
598     AsyncSSLSocket *sslSocket =
599         AsyncSSLSocket::getFromSSL(ssl);
600     assert(sslSocket != nullptr);
601     // Going to simulate an async cache by just running delaying the miss 100ms
602     if (asyncCallbacks_ % 2 == 0) {
603       // This socket is already blocked on lookup, return miss
604       std::cerr << "returning miss" << std::endl;
605     } else {
606       // fresh meat - block it
607       std::cerr << "async lookup" << std::endl;
608       sslSocket->getEventBase()->tryRunAfterDelay(
609         std::bind(&AsyncSSLSocket::restartSSLAccept,
610                   sslSocket), lookupDelay_);
611       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
612       asyncLookups_++;
613     }
614 #endif
615     return nullptr;
616   }
617 };
618
619 void getfds(int fds[2]);
620
621 void getctx(
622   std::shared_ptr<folly::SSLContext> clientCtx,
623   std::shared_ptr<folly::SSLContext> serverCtx);
624
625 void sslsocketpair(
626   EventBase* eventBase,
627   AsyncSSLSocket::UniquePtr* clientSock,
628   AsyncSSLSocket::UniquePtr* serverSock);
629
630 class BlockingWriteClient :
631   private AsyncSSLSocket::HandshakeCB,
632   private AsyncTransportWrapper::WriteCallback {
633  public:
634   explicit BlockingWriteClient(
635     AsyncSSLSocket::UniquePtr socket)
636     : socket_(std::move(socket)),
637       bufLen_(2500),
638       iovCount_(2000) {
639     // Fill buf_
640     buf_.reset(new uint8_t[bufLen_]);
641     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
642       buf_[n] = n % 0xff;
643     }
644
645     // Initialize iov_
646     iov_.reset(new struct iovec[iovCount_]);
647     for (uint32_t n = 0; n < iovCount_; ++n) {
648       iov_[n].iov_base = buf_.get() + n;
649       if (n & 0x1) {
650         iov_[n].iov_len = n % bufLen_;
651       } else {
652         iov_[n].iov_len = bufLen_ - (n % bufLen_);
653       }
654     }
655
656     socket_->sslConn(this, std::chrono::milliseconds(100));
657   }
658
659   struct iovec* getIovec() const {
660     return iov_.get();
661   }
662   uint32_t getIovecCount() const {
663     return iovCount_;
664   }
665
666  private:
667   void handshakeSuc(AsyncSSLSocket*) noexcept override {
668     socket_->writev(this, iov_.get(), iovCount_);
669   }
670   void handshakeErr(
671     AsyncSSLSocket*,
672     const AsyncSocketException& ex) noexcept override {
673     ADD_FAILURE() << "client handshake error: " << ex.what();
674   }
675   void writeSuccess() noexcept override {
676     socket_->close();
677   }
678   void writeErr(
679     size_t bytesWritten,
680     const AsyncSocketException& ex) noexcept override {
681     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
682                   << ex.what();
683   }
684
685   AsyncSSLSocket::UniquePtr socket_;
686   uint32_t bufLen_;
687   uint32_t iovCount_;
688   std::unique_ptr<uint8_t[]> buf_;
689   std::unique_ptr<struct iovec[]> iov_;
690 };
691
692 class BlockingWriteServer :
693     private AsyncSSLSocket::HandshakeCB,
694     private AsyncTransportWrapper::ReadCallback {
695  public:
696   explicit BlockingWriteServer(
697     AsyncSSLSocket::UniquePtr socket)
698     : socket_(std::move(socket)),
699       bufSize_(2500 * 2000),
700       bytesRead_(0) {
701     buf_.reset(new uint8_t[bufSize_]);
702     socket_->sslAccept(this, std::chrono::milliseconds(100));
703   }
704
705   void checkBuffer(struct iovec* iov, uint32_t count) const {
706     uint32_t idx = 0;
707     for (uint32_t n = 0; n < count; ++n) {
708       size_t bytesLeft = bytesRead_ - idx;
709       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
710                       std::min(iov[n].iov_len, bytesLeft));
711       if (rc != 0) {
712         FAIL() << "buffer mismatch at iovec " << n << "/" << count
713                << ": rc=" << rc;
714
715       }
716       if (iov[n].iov_len > bytesLeft) {
717         FAIL() << "server did not read enough data: "
718                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
719                << " in iovec " << n << "/" << count;
720       }
721
722       idx += iov[n].iov_len;
723     }
724     if (idx != bytesRead_) {
725       ADD_FAILURE() << "server read extra data: " << bytesRead_
726                     << " bytes read; expected " << idx;
727     }
728   }
729
730  private:
731   void handshakeSuc(AsyncSSLSocket*) noexcept override {
732     // Wait 10ms before reading, so the client's writes will initially block.
733     socket_->getEventBase()->tryRunAfterDelay(
734         [this] { socket_->setReadCB(this); }, 10);
735   }
736   void handshakeErr(
737     AsyncSSLSocket*,
738     const AsyncSocketException& ex) noexcept override {
739     ADD_FAILURE() << "server handshake error: " << ex.what();
740   }
741   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
742     *bufReturn = buf_.get() + bytesRead_;
743     *lenReturn = bufSize_ - bytesRead_;
744   }
745   void readDataAvailable(size_t len) noexcept override {
746     bytesRead_ += len;
747     socket_->setReadCB(nullptr);
748     socket_->getEventBase()->tryRunAfterDelay(
749         [this] { socket_->setReadCB(this); }, 2);
750   }
751   void readEOF() noexcept override {
752     socket_->close();
753   }
754   void readErr(
755     const AsyncSocketException& ex) noexcept override {
756     ADD_FAILURE() << "server read error: " << ex.what();
757   }
758
759   AsyncSSLSocket::UniquePtr socket_;
760   uint32_t bufSize_;
761   uint32_t bytesRead_;
762   std::unique_ptr<uint8_t[]> buf_;
763 };
764
765 class NpnClient :
766   private AsyncSSLSocket::HandshakeCB,
767   private AsyncTransportWrapper::WriteCallback {
768  public:
769   explicit NpnClient(
770     AsyncSSLSocket::UniquePtr socket)
771       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
772     socket_->sslConn(this);
773   }
774
775   const unsigned char* nextProto;
776   unsigned nextProtoLength;
777   SSLContext::NextProtocolType protocolType;
778
779  private:
780   void handshakeSuc(AsyncSSLSocket*) noexcept override {
781     socket_->getSelectedNextProtocol(
782         &nextProto, &nextProtoLength, &protocolType);
783   }
784   void handshakeErr(
785     AsyncSSLSocket*,
786     const AsyncSocketException& ex) noexcept override {
787     ADD_FAILURE() << "client handshake error: " << ex.what();
788   }
789   void writeSuccess() noexcept override {
790     socket_->close();
791   }
792   void writeErr(
793     size_t bytesWritten,
794     const AsyncSocketException& ex) noexcept override {
795     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
796                   << ex.what();
797   }
798
799   AsyncSSLSocket::UniquePtr socket_;
800 };
801
802 class NpnServer :
803     private AsyncSSLSocket::HandshakeCB,
804     private AsyncTransportWrapper::ReadCallback {
805  public:
806   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
807       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
808     socket_->sslAccept(this);
809   }
810
811   const unsigned char* nextProto;
812   unsigned nextProtoLength;
813   SSLContext::NextProtocolType protocolType;
814
815  private:
816   void handshakeSuc(AsyncSSLSocket*) noexcept override {
817     socket_->getSelectedNextProtocol(
818         &nextProto, &nextProtoLength, &protocolType);
819   }
820   void handshakeErr(
821     AsyncSSLSocket*,
822     const AsyncSocketException& ex) noexcept override {
823     ADD_FAILURE() << "server handshake error: " << ex.what();
824   }
825   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
826     *lenReturn = 0;
827   }
828   void readDataAvailable(size_t /* len */) noexcept override {}
829   void readEOF() noexcept override {
830     socket_->close();
831   }
832   void readErr(
833     const AsyncSocketException& ex) noexcept override {
834     ADD_FAILURE() << "server read error: " << ex.what();
835   }
836
837   AsyncSSLSocket::UniquePtr socket_;
838 };
839
840 class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
841                             public AsyncTransportWrapper::ReadCallback {
842  public:
843   explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
844       : socket_(std::move(socket)) {
845     socket_->sslAccept(this);
846   }
847
848   ~RenegotiatingServer() {
849     socket_->setReadCB(nullptr);
850   }
851
852   void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
853     LOG(INFO) << "Renegotiating server handshake success";
854     socket_->setReadCB(this);
855   }
856   void handshakeErr(
857       AsyncSSLSocket*,
858       const AsyncSocketException& ex) noexcept override {
859     ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
860   }
861   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
862     *lenReturn = sizeof(buf);
863     *bufReturn = buf;
864   }
865   void readDataAvailable(size_t /* len */) noexcept override {}
866   void readEOF() noexcept override {}
867   void readErr(const AsyncSocketException& ex) noexcept override {
868     LOG(INFO) << "server got read error " << ex.what();
869     auto exPtr = dynamic_cast<const SSLException*>(&ex);
870     ASSERT_NE(nullptr, exPtr);
871     std::string exStr(ex.what());
872     SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
873     ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
874     renegotiationError_ = true;
875   }
876
877   AsyncSSLSocket::UniquePtr socket_;
878   unsigned char buf[128];
879   bool renegotiationError_{false};
880 };
881
882 #ifndef OPENSSL_NO_TLSEXT
883 class SNIClient :
884   private AsyncSSLSocket::HandshakeCB,
885   private AsyncTransportWrapper::WriteCallback {
886  public:
887   explicit SNIClient(
888     AsyncSSLSocket::UniquePtr socket)
889       : serverNameMatch(false), socket_(std::move(socket)) {
890     socket_->sslConn(this);
891   }
892
893   bool serverNameMatch;
894
895  private:
896   void handshakeSuc(AsyncSSLSocket*) noexcept override {
897     serverNameMatch = socket_->isServerNameMatch();
898   }
899   void handshakeErr(
900     AsyncSSLSocket*,
901     const AsyncSocketException& ex) noexcept override {
902     ADD_FAILURE() << "client handshake error: " << ex.what();
903   }
904   void writeSuccess() noexcept override {
905     socket_->close();
906   }
907   void writeErr(
908     size_t bytesWritten,
909     const AsyncSocketException& ex) noexcept override {
910     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
911                   << ex.what();
912   }
913
914   AsyncSSLSocket::UniquePtr socket_;
915 };
916
917 class SNIServer :
918     private AsyncSSLSocket::HandshakeCB,
919     private AsyncTransportWrapper::ReadCallback {
920  public:
921   explicit SNIServer(
922     AsyncSSLSocket::UniquePtr socket,
923     const std::shared_ptr<folly::SSLContext>& ctx,
924     const std::shared_ptr<folly::SSLContext>& sniCtx,
925     const std::string& expectedServerName)
926       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
927         expectedServerName_(expectedServerName) {
928     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
929                                          std::placeholders::_1));
930     socket_->sslAccept(this);
931   }
932
933   bool serverNameMatch;
934
935  private:
936   void handshakeSuc(AsyncSSLSocket* /* ssl */) noexcept override {}
937   void handshakeErr(
938     AsyncSSLSocket*,
939     const AsyncSocketException& ex) noexcept override {
940     ADD_FAILURE() << "server handshake error: " << ex.what();
941   }
942   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
943     *lenReturn = 0;
944   }
945   void readDataAvailable(size_t /* len */) noexcept override {}
946   void readEOF() noexcept override {
947     socket_->close();
948   }
949   void readErr(
950     const AsyncSocketException& ex) noexcept override {
951     ADD_FAILURE() << "server read error: " << ex.what();
952   }
953
954   folly::SSLContext::ServerNameCallbackResult
955     serverNameCallback(SSL *ssl) {
956     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
957     if (sniCtx_ &&
958         sn &&
959         !strcasecmp(expectedServerName_.c_str(), sn)) {
960       AsyncSSLSocket *sslSocket =
961           AsyncSSLSocket::getFromSSL(ssl);
962       sslSocket->switchServerSSLContext(sniCtx_);
963       serverNameMatch = true;
964       return folly::SSLContext::SERVER_NAME_FOUND;
965     } else {
966       serverNameMatch = false;
967       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
968     }
969   }
970
971   AsyncSSLSocket::UniquePtr socket_;
972   std::shared_ptr<folly::SSLContext> sniCtx_;
973   std::string expectedServerName_;
974 };
975 #endif
976
977 class SSLClient : public AsyncSocket::ConnectCallback,
978                   public AsyncTransportWrapper::WriteCallback,
979                   public AsyncTransportWrapper::ReadCallback
980 {
981  private:
982   EventBase *eventBase_;
983   std::shared_ptr<AsyncSSLSocket> sslSocket_;
984   SSL_SESSION *session_;
985   std::shared_ptr<folly::SSLContext> ctx_;
986   uint32_t requests_;
987   folly::SocketAddress address_;
988   uint32_t timeout_;
989   char buf_[128];
990   char readbuf_[128];
991   uint32_t bytesRead_;
992   uint32_t hit_;
993   uint32_t miss_;
994   uint32_t errors_;
995   uint32_t writeAfterConnectErrors_;
996
997   // These settings test that we eventually drain the
998   // socket, even if the maxReadsPerEvent_ is hit during
999   // a event loop iteration.
1000   static constexpr size_t kMaxReadsPerEvent = 2;
1001   // 2 event loop iterations
1002   static constexpr size_t kMaxReadBufferSz =
1003     sizeof(decltype(readbuf_)) / kMaxReadsPerEvent / 2;
1004
1005  public:
1006   SSLClient(EventBase *eventBase,
1007             const folly::SocketAddress& address,
1008             uint32_t requests,
1009             uint32_t timeout = 0)
1010       : eventBase_(eventBase),
1011         session_(nullptr),
1012         requests_(requests),
1013         address_(address),
1014         timeout_(timeout),
1015         bytesRead_(0),
1016         hit_(0),
1017         miss_(0),
1018         errors_(0),
1019         writeAfterConnectErrors_(0) {
1020     ctx_.reset(new folly::SSLContext());
1021     ctx_->setOptions(SSL_OP_NO_TICKET);
1022     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1023     memset(buf_, 'a', sizeof(buf_));
1024   }
1025
1026   ~SSLClient() {
1027     if (session_) {
1028       SSL_SESSION_free(session_);
1029     }
1030     if (errors_ == 0) {
1031       EXPECT_EQ(bytesRead_, sizeof(buf_));
1032     }
1033   }
1034
1035   uint32_t getHit() const { return hit_; }
1036
1037   uint32_t getMiss() const { return miss_; }
1038
1039   uint32_t getErrors() const { return errors_; }
1040
1041   uint32_t getWriteAfterConnectErrors() const {
1042     return writeAfterConnectErrors_;
1043   }
1044
1045   void connect(bool writeNow = false) {
1046     sslSocket_ = AsyncSSLSocket::newSocket(
1047       ctx_, eventBase_);
1048     if (session_ != nullptr) {
1049       sslSocket_->setSSLSession(session_);
1050     }
1051     requests_--;
1052     sslSocket_->connect(this, address_, timeout_);
1053     if (sslSocket_ && writeNow) {
1054       // write some junk, used in an error test
1055       sslSocket_->write(this, buf_, sizeof(buf_));
1056     }
1057   }
1058
1059   void connectSuccess() noexcept override {
1060     std::cerr << "client SSL socket connected" << std::endl;
1061     if (sslSocket_->getSSLSessionReused()) {
1062       hit_++;
1063     } else {
1064       miss_++;
1065       if (session_ != nullptr) {
1066         SSL_SESSION_free(session_);
1067       }
1068       session_ = sslSocket_->getSSLSession();
1069     }
1070
1071     // write()
1072     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1073     sslSocket_->write(this, buf_, sizeof(buf_));
1074     sslSocket_->setReadCB(this);
1075     memset(readbuf_, 'b', sizeof(readbuf_));
1076     bytesRead_ = 0;
1077   }
1078
1079   void connectErr(
1080     const AsyncSocketException& ex) noexcept override {
1081     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1082     errors_++;
1083     sslSocket_.reset();
1084   }
1085
1086   void writeSuccess() noexcept override {
1087     std::cerr << "client write success" << std::endl;
1088   }
1089
1090   void writeErr(size_t /* bytesWritten */,
1091                 const AsyncSocketException& ex) noexcept override {
1092     std::cerr << "client writeError: " << ex.what() << std::endl;
1093     if (!sslSocket_) {
1094       writeAfterConnectErrors_++;
1095     }
1096   }
1097
1098   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1099     *bufReturn = readbuf_ + bytesRead_;
1100     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1101   }
1102
1103   void readEOF() noexcept override {
1104     std::cerr << "client readEOF" << std::endl;
1105   }
1106
1107   void readErr(
1108     const AsyncSocketException& ex) noexcept override {
1109     std::cerr << "client readError: " << ex.what() << std::endl;
1110   }
1111
1112   void readDataAvailable(size_t len) noexcept override {
1113     std::cerr << "client read data: " << len << std::endl;
1114     bytesRead_ += len;
1115     if (bytesRead_ == sizeof(buf_)) {
1116       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1117       sslSocket_->closeNow();
1118       sslSocket_.reset();
1119       if (requests_ != 0) {
1120         connect();
1121       }
1122     }
1123   }
1124
1125 };
1126
1127 class SSLHandshakeBase :
1128   public AsyncSSLSocket::HandshakeCB,
1129   private AsyncTransportWrapper::WriteCallback {
1130  public:
1131   explicit SSLHandshakeBase(
1132    AsyncSSLSocket::UniquePtr socket,
1133    bool preverifyResult,
1134    bool verifyResult) :
1135     handshakeVerify_(false),
1136     handshakeSuccess_(false),
1137     handshakeError_(false),
1138     socket_(std::move(socket)),
1139     preverifyResult_(preverifyResult),
1140     verifyResult_(verifyResult) {
1141   }
1142
1143   AsyncSSLSocket::UniquePtr moveSocket() && {
1144     return std::move(socket_);
1145   }
1146
1147   bool handshakeVerify_;
1148   bool handshakeSuccess_;
1149   bool handshakeError_;
1150   std::chrono::nanoseconds handshakeTime;
1151
1152  protected:
1153   AsyncSSLSocket::UniquePtr socket_;
1154   bool preverifyResult_;
1155   bool verifyResult_;
1156
1157   // HandshakeCallback
1158   bool handshakeVer(AsyncSSLSocket* /* sock */,
1159                     bool preverifyOk,
1160                     X509_STORE_CTX* /* ctx */) noexcept override {
1161     handshakeVerify_ = true;
1162
1163     EXPECT_EQ(preverifyResult_, preverifyOk);
1164     return verifyResult_;
1165   }
1166
1167   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1168     LOG(INFO) << "Handshake success";
1169     handshakeSuccess_ = true;
1170     handshakeTime = socket_->getHandshakeTime();
1171   }
1172
1173   void handshakeErr(
1174       AsyncSSLSocket*,
1175       const AsyncSocketException& ex) noexcept override {
1176     LOG(INFO) << "Handshake error " << ex.what();
1177     handshakeError_ = true;
1178     handshakeTime = socket_->getHandshakeTime();
1179   }
1180
1181   // WriteCallback
1182   void writeSuccess() noexcept override {
1183     socket_->close();
1184   }
1185
1186   void writeErr(
1187    size_t bytesWritten,
1188    const AsyncSocketException& ex) noexcept override {
1189     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1190                   << ex.what();
1191   }
1192 };
1193
1194 class SSLHandshakeClient : public SSLHandshakeBase {
1195  public:
1196   SSLHandshakeClient(
1197    AsyncSSLSocket::UniquePtr socket,
1198    bool preverifyResult,
1199    bool verifyResult) :
1200     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1201     socket_->sslConn(this, std::chrono::milliseconds::zero());
1202   }
1203 };
1204
1205 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1206  public:
1207   SSLHandshakeClientNoVerify(
1208    AsyncSSLSocket::UniquePtr socket,
1209    bool preverifyResult,
1210    bool verifyResult) :
1211     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1212     socket_->sslConn(
1213         this,
1214         std::chrono::milliseconds::zero(),
1215         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1216   }
1217 };
1218
1219 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1220  public:
1221   SSLHandshakeClientDoVerify(
1222    AsyncSSLSocket::UniquePtr socket,
1223    bool preverifyResult,
1224    bool verifyResult) :
1225     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1226     socket_->sslConn(
1227         this,
1228         std::chrono::milliseconds::zero(),
1229         folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1230   }
1231 };
1232
1233 class SSLHandshakeServer : public SSLHandshakeBase {
1234  public:
1235   SSLHandshakeServer(
1236       AsyncSSLSocket::UniquePtr socket,
1237       bool preverifyResult,
1238       bool verifyResult)
1239     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1240     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1241   }
1242 };
1243
1244 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1245  public:
1246   SSLHandshakeServerParseClientHello(
1247       AsyncSSLSocket::UniquePtr socket,
1248       bool preverifyResult,
1249       bool verifyResult)
1250       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1251     socket_->enableClientHelloParsing();
1252     socket_->sslAccept(this, std::chrono::milliseconds::zero());
1253   }
1254
1255   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1256
1257  protected:
1258   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1259     handshakeSuccess_ = true;
1260     sock->getSSLSharedCiphers(sharedCiphers_);
1261     sock->getSSLServerCiphers(serverCiphers_);
1262     sock->getSSLClientCiphers(clientCiphers_);
1263     chosenCipher_ = sock->getNegotiatedCipherName();
1264   }
1265 };
1266
1267
1268 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1269  public:
1270   SSLHandshakeServerNoVerify(
1271       AsyncSSLSocket::UniquePtr socket,
1272       bool preverifyResult,
1273       bool verifyResult)
1274     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1275     socket_->sslAccept(
1276         this,
1277         std::chrono::milliseconds::zero(),
1278         folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1279   }
1280 };
1281
1282 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1283  public:
1284   SSLHandshakeServerDoVerify(
1285       AsyncSSLSocket::UniquePtr socket,
1286       bool preverifyResult,
1287       bool verifyResult)
1288     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1289     socket_->sslAccept(
1290         this,
1291         std::chrono::milliseconds::zero(),
1292         folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1293   }
1294 };
1295
1296 class EventBaseAborter : public AsyncTimeout {
1297  public:
1298   EventBaseAborter(EventBase* eventBase,
1299                    uint32_t timeoutMS)
1300     : AsyncTimeout(
1301       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1302     , eventBase_(eventBase) {
1303     scheduleTimeout(timeoutMS);
1304   }
1305
1306   void timeoutExpired() noexcept override {
1307     FAIL() << "test timed out";
1308     eventBase_->terminateLoopSoon();
1309   }
1310
1311  private:
1312   EventBase* eventBase_;
1313 };
1314
1315 }