683b2a6507acfa25a1a9c73cb8b5ad3cf013e490
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncServerSocket.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTransport.h>
25 #include <folly/io/async/EventBase.h>
26 #include <folly/io/async/AsyncTimeout.h>
27 #include <folly/SocketAddress.h>
28
29 #include <gtest/gtest.h>
30 #include <iostream>
31 #include <list>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <poll.h>
35 #include <sys/types.h>
36 #include <sys/socket.h>
37 #include <netinet/tcp.h>
38
39 namespace folly {
40
41 enum StateEnum {
42   STATE_WAITING,
43   STATE_SUCCEEDED,
44   STATE_FAILED
45 };
46
47 // The destructors of all callback classes assert that the state is
48 // STATE_SUCCEEDED, for both possitive and negative tests. The tests
49 // are responsible for setting the succeeded state properly before the
50 // destructors are called.
51
52 class WriteCallbackBase :
53 public AsyncTransportWrapper::WriteCallback {
54 public:
55   WriteCallbackBase()
56       : state(STATE_WAITING)
57       , bytesWritten(0)
58       , exception(AsyncSocketException::UNKNOWN, "none") {}
59
60   ~WriteCallbackBase() {
61     EXPECT_EQ(state, STATE_SUCCEEDED);
62   }
63
64   void setSocket(
65     const std::shared_ptr<AsyncSSLSocket> &socket) {
66     socket_ = socket;
67   }
68
69   void writeSuccess() noexcept override {
70     std::cerr << "writeSuccess" << std::endl;
71     state = STATE_SUCCEEDED;
72   }
73
74   void writeErr(
75     size_t bytesWritten,
76     const AsyncSocketException& ex) noexcept override {
77     std::cerr << "writeError: bytesWritten " << bytesWritten
78          << ", exception " << ex.what() << std::endl;
79
80     state = STATE_FAILED;
81     this->bytesWritten = bytesWritten;
82     exception = ex;
83     socket_->close();
84     socket_->detachEventBase();
85   }
86
87   std::shared_ptr<AsyncSSLSocket> socket_;
88   StateEnum state;
89   size_t bytesWritten;
90   AsyncSocketException exception;
91 };
92
93 class ReadCallbackBase :
94 public AsyncTransportWrapper::ReadCallback {
95 public:
96   explicit ReadCallbackBase(WriteCallbackBase *wcb)
97       : wcb_(wcb)
98       , state(STATE_WAITING) {}
99
100   ~ReadCallbackBase() {
101     EXPECT_EQ(state, STATE_SUCCEEDED);
102   }
103
104   void setSocket(
105     const std::shared_ptr<AsyncSSLSocket> &socket) {
106     socket_ = socket;
107   }
108
109   void setState(StateEnum s) {
110     state = s;
111     if (wcb_) {
112       wcb_->state = s;
113     }
114   }
115
116   void readErr(
117     const AsyncSocketException& ex) noexcept override {
118     std::cerr << "readError " << ex.what() << std::endl;
119     state = STATE_FAILED;
120     socket_->close();
121     socket_->detachEventBase();
122   }
123
124   void readEOF() noexcept override {
125     std::cerr << "readEOF" << std::endl;
126
127     socket_->close();
128     socket_->detachEventBase();
129   }
130
131   std::shared_ptr<AsyncSSLSocket> socket_;
132   WriteCallbackBase *wcb_;
133   StateEnum state;
134 };
135
136 class ReadCallback : public ReadCallbackBase {
137 public:
138   explicit ReadCallback(WriteCallbackBase *wcb)
139       : ReadCallbackBase(wcb)
140       , buffers() {}
141
142   ~ReadCallback() {
143     for (std::vector<Buffer>::iterator it = buffers.begin();
144          it != buffers.end();
145          ++it) {
146       it->free();
147     }
148     currentBuffer.free();
149   }
150
151   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
152     if (!currentBuffer.buffer) {
153       currentBuffer.allocate(4096);
154     }
155     *bufReturn = currentBuffer.buffer;
156     *lenReturn = currentBuffer.length;
157   }
158
159   void readDataAvailable(size_t len) noexcept override {
160     std::cerr << "readDataAvailable, len " << len << std::endl;
161
162     currentBuffer.length = len;
163
164     wcb_->setSocket(socket_);
165
166     // Write back the same data.
167     socket_->write(wcb_, currentBuffer.buffer, len);
168
169     buffers.push_back(currentBuffer);
170     currentBuffer.reset();
171     state = STATE_SUCCEEDED;
172   }
173
174   class Buffer {
175   public:
176     Buffer() : buffer(nullptr), length(0) {}
177     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
178
179     void reset() {
180       buffer = nullptr;
181       length = 0;
182     }
183     void allocate(size_t length) {
184       assert(buffer == nullptr);
185       this->buffer = static_cast<char*>(malloc(length));
186       this->length = length;
187     }
188     void free() {
189       ::free(buffer);
190       reset();
191     }
192
193     char* buffer;
194     size_t length;
195   };
196
197   std::vector<Buffer> buffers;
198   Buffer currentBuffer;
199 };
200
201 class ReadErrorCallback : public ReadCallbackBase {
202 public:
203   explicit ReadErrorCallback(WriteCallbackBase *wcb)
204       : ReadCallbackBase(wcb) {}
205
206   // Return nullptr buffer to trigger readError()
207   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
208     *bufReturn = nullptr;
209     *lenReturn = 0;
210   }
211
212   void readDataAvailable(size_t len) noexcept override {
213     // This should never to called.
214     FAIL();
215   }
216
217   void readErr(
218     const AsyncSocketException& ex) noexcept override {
219     ReadCallbackBase::readErr(ex);
220     std::cerr << "ReadErrorCallback::readError" << std::endl;
221     setState(STATE_SUCCEEDED);
222   }
223 };
224
225 class WriteErrorCallback : public ReadCallback {
226 public:
227   explicit WriteErrorCallback(WriteCallbackBase *wcb)
228       : ReadCallback(wcb) {}
229
230   void readDataAvailable(size_t len) noexcept override {
231     std::cerr << "readDataAvailable, len " << len << std::endl;
232
233     currentBuffer.length = len;
234
235     // close the socket before writing to trigger writeError().
236     ::close(socket_->getFd());
237
238     wcb_->setSocket(socket_);
239
240     // Write back the same data.
241     socket_->write(wcb_, currentBuffer.buffer, len);
242
243     if (wcb_->state == STATE_FAILED) {
244       setState(STATE_SUCCEEDED);
245     } else {
246       state = STATE_FAILED;
247     }
248
249     buffers.push_back(currentBuffer);
250     currentBuffer.reset();
251   }
252
253   void readErr(const AsyncSocketException& ex) noexcept override {
254     std::cerr << "readError " << ex.what() << std::endl;
255     // do nothing since this is expected
256   }
257 };
258
259 class EmptyReadCallback : public ReadCallback {
260 public:
261   explicit EmptyReadCallback()
262       : ReadCallback(nullptr) {}
263
264   void readErr(const AsyncSocketException& ex) noexcept override {
265     std::cerr << "readError " << ex.what() << std::endl;
266     state = STATE_FAILED;
267     tcpSocket_->close();
268     tcpSocket_->detachEventBase();
269   }
270
271   void readEOF() noexcept override {
272     std::cerr << "readEOF" << std::endl;
273
274     tcpSocket_->close();
275     tcpSocket_->detachEventBase();
276     state = STATE_SUCCEEDED;
277   }
278
279   std::shared_ptr<AsyncSocket> tcpSocket_;
280 };
281
282 class HandshakeCallback :
283 public AsyncSSLSocket::HandshakeCB {
284 public:
285   enum ExpectType {
286     EXPECT_SUCCESS,
287     EXPECT_ERROR
288   };
289
290   explicit HandshakeCallback(ReadCallbackBase *rcb,
291                              ExpectType expect = EXPECT_SUCCESS):
292       state(STATE_WAITING),
293       rcb_(rcb),
294       expect_(expect) {}
295
296   void setSocket(
297     const std::shared_ptr<AsyncSSLSocket> &socket) {
298     socket_ = socket;
299   }
300
301   void setState(StateEnum s) {
302     state = s;
303     rcb_->setState(s);
304   }
305
306   // Functions inherited from AsyncSSLSocketHandshakeCallback
307   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
308     EXPECT_EQ(sock, socket_.get());
309     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
310     rcb_->setSocket(socket_);
311     sock->setReadCB(rcb_);
312     state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
313   }
314   void handshakeErr(
315     AsyncSSLSocket *sock,
316     const AsyncSocketException& ex) noexcept override {
317     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
318     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
319     if (expect_ == EXPECT_ERROR) {
320       // rcb will never be invoked
321       rcb_->setState(STATE_SUCCEEDED);
322     }
323   }
324
325   ~HandshakeCallback() {
326     EXPECT_EQ(state, STATE_SUCCEEDED);
327   }
328
329   void closeSocket() {
330     socket_->close();
331     state = STATE_SUCCEEDED;
332   }
333
334   StateEnum state;
335   std::shared_ptr<AsyncSSLSocket> socket_;
336   ReadCallbackBase *rcb_;
337   ExpectType expect_;
338 };
339
340 class SSLServerAcceptCallbackBase:
341 public folly::AsyncServerSocket::AcceptCallback {
342 public:
343   explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
344   state(STATE_WAITING), hcb_(hcb) {}
345
346   ~SSLServerAcceptCallbackBase() {
347     EXPECT_EQ(state, STATE_SUCCEEDED);
348   }
349
350   void acceptError(const std::exception& ex) noexcept override {
351     std::cerr << "SSLServerAcceptCallbackBase::acceptError "
352               << ex.what() << std::endl;
353     state = STATE_FAILED;
354   }
355
356   void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
357     noexcept override{
358     printf("Connection accepted\n");
359     std::shared_ptr<AsyncSSLSocket> sslSock;
360     try {
361       // Create a AsyncSSLSocket object with the fd. The socket should be
362       // added to the event base and in the state of accepting SSL connection.
363       sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
364     } catch (const std::exception &e) {
365       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
366         "object with socket " << e.what() << fd;
367       ::close(fd);
368       acceptError(e);
369       return;
370     }
371
372     connAccepted(sslSock);
373   }
374
375   virtual void connAccepted(
376     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
377
378   StateEnum state;
379   HandshakeCallback *hcb_;
380   std::shared_ptr<folly::SSLContext> ctx_;
381   folly::EventBase* base_;
382 };
383
384 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
385 public:
386   uint32_t timeout_;
387
388   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
389                                    uint32_t timeout = 0):
390       SSLServerAcceptCallbackBase(hcb),
391       timeout_(timeout) {}
392
393   virtual ~SSLServerAcceptCallback() {
394     if (timeout_ > 0) {
395       // if we set a timeout, we expect failure
396       EXPECT_EQ(hcb_->state, STATE_FAILED);
397       hcb_->setState(STATE_SUCCEEDED);
398     }
399   }
400
401   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
402   void connAccepted(
403     const std::shared_ptr<folly::AsyncSSLSocket> &s)
404     noexcept override {
405     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
406     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
407
408     hcb_->setSocket(sock);
409     sock->sslAccept(hcb_, timeout_);
410     EXPECT_EQ(sock->getSSLState(),
411                       AsyncSSLSocket::STATE_ACCEPTING);
412
413     state = STATE_SUCCEEDED;
414   }
415 };
416
417 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
418 public:
419   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
420       SSLServerAcceptCallback(hcb) {}
421
422   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
423   void connAccepted(
424     const std::shared_ptr<folly::AsyncSSLSocket> &s)
425     noexcept override {
426
427     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
428
429     std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
430               << std::endl;
431     int fd = sock->getFd();
432
433 #ifndef TCP_NOPUSH
434     {
435     // The accepted connection should already have TCP_NODELAY set
436     int value;
437     socklen_t valueLength = sizeof(value);
438     int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
439     EXPECT_EQ(rc, 0);
440     EXPECT_EQ(value, 1);
441     }
442 #endif
443
444     // Unset the TCP_NODELAY option.
445     int value = 0;
446     socklen_t valueLength = sizeof(value);
447     int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
448     EXPECT_EQ(rc, 0);
449
450     rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
451     EXPECT_EQ(rc, 0);
452     EXPECT_EQ(value, 0);
453
454     SSLServerAcceptCallback::connAccepted(sock);
455   }
456 };
457
458 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
459 public:
460   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
461                                              uint32_t timeout = 0):
462     SSLServerAcceptCallback(hcb, timeout) {}
463
464   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
465   void connAccepted(
466     const std::shared_ptr<folly::AsyncSSLSocket> &s)
467     noexcept override {
468     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
469
470     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
471
472     hcb_->setSocket(sock);
473     sock->sslAccept(hcb_, timeout_);
474     ASSERT_TRUE((sock->getSSLState() ==
475                  AsyncSSLSocket::STATE_ACCEPTING) ||
476                 (sock->getSSLState() ==
477                  AsyncSSLSocket::STATE_CACHE_LOOKUP));
478
479     state = STATE_SUCCEEDED;
480   }
481 };
482
483
484 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
485 public:
486   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
487   SSLServerAcceptCallbackBase(hcb)  {}
488
489   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
490   void connAccepted(
491     const std::shared_ptr<folly::AsyncSSLSocket> &s)
492     noexcept override {
493     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
494
495     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
496
497     // The first call to sslAccept() should succeed.
498     hcb_->setSocket(sock);
499     sock->sslAccept(hcb_);
500     EXPECT_EQ(sock->getSSLState(),
501                       AsyncSSLSocket::STATE_ACCEPTING);
502
503     // The second call to sslAccept() should fail.
504     HandshakeCallback callback2(hcb_->rcb_);
505     callback2.setSocket(sock);
506     sock->sslAccept(&callback2);
507     EXPECT_EQ(sock->getSSLState(),
508                       AsyncSSLSocket::STATE_ERROR);
509
510     // Both callbacks should be in the error state.
511     EXPECT_EQ(hcb_->state, STATE_FAILED);
512     EXPECT_EQ(callback2.state, STATE_FAILED);
513
514     sock->detachEventBase();
515
516     state = STATE_SUCCEEDED;
517     hcb_->setState(STATE_SUCCEEDED);
518     callback2.setState(STATE_SUCCEEDED);
519   }
520 };
521
522 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
523 public:
524   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
525   SSLServerAcceptCallbackBase(hcb)  {}
526
527   // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
528   void connAccepted(
529     const std::shared_ptr<folly::AsyncSSLSocket> &s)
530     noexcept override {
531     std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
532
533     auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
534
535     hcb_->setSocket(sock);
536     sock->getEventBase()->tryRunAfterDelay([=] {
537         std::cerr << "Delayed SSL accept, client will have close by now"
538                   << std::endl;
539         // SSL accept will fail
540         EXPECT_EQ(
541           sock->getSSLState(),
542           AsyncSSLSocket::STATE_UNINIT);
543         hcb_->socket_->sslAccept(hcb_);
544         // This registers for an event
545         EXPECT_EQ(
546           sock->getSSLState(),
547           AsyncSSLSocket::STATE_ACCEPTING);
548
549         state = STATE_SUCCEEDED;
550       }, 100);
551   }
552 };
553
554
555 class TestSSLServer {
556  protected:
557   EventBase evb_;
558   std::shared_ptr<folly::SSLContext> ctx_;
559   SSLServerAcceptCallbackBase *acb_;
560   std::shared_ptr<folly::AsyncServerSocket> socket_;
561   folly::SocketAddress address_;
562   pthread_t thread_;
563
564   static void *Main(void *ctx) {
565     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
566     self->evb_.loop();
567     std::cerr << "Server thread exited event loop" << std::endl;
568     return nullptr;
569   }
570
571  public:
572   // Create a TestSSLServer.
573   // This immediately starts listening on the given port.
574   explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
575
576   // Kill the thread.
577   ~TestSSLServer() {
578     evb_.runInEventBaseThread([&](){
579       socket_->stopAccepting();
580     });
581     std::cerr << "Waiting for server thread to exit" << std::endl;
582     pthread_join(thread_, nullptr);
583   }
584
585   EventBase &getEventBase() { return evb_; }
586
587   const folly::SocketAddress& getAddress() const {
588     return address_;
589   }
590 };
591
592 class TestSSLAsyncCacheServer : public TestSSLServer {
593  public:
594   explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
595         int lookupDelay = 100) :
596       TestSSLServer(acb) {
597     SSL_CTX *sslCtx = ctx_->getSSLCtx();
598     SSL_CTX_sess_set_get_cb(sslCtx,
599                             TestSSLAsyncCacheServer::getSessionCallback);
600     SSL_CTX_set_session_cache_mode(
601       sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
602     asyncCallbacks_ = 0;
603     asyncLookups_ = 0;
604     lookupDelay_ = lookupDelay;
605   }
606
607   uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
608   uint32_t getAsyncLookups() const { return asyncLookups_; }
609
610  private:
611   static uint32_t asyncCallbacks_;
612   static uint32_t asyncLookups_;
613   static uint32_t lookupDelay_;
614
615   static SSL_SESSION *getSessionCallback(SSL *ssl,
616                                          unsigned char *sess_id,
617                                          int id_len,
618                                          int *copyflag) {
619     *copyflag = 0;
620     asyncCallbacks_++;
621 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
622     if (!SSL_want_sess_cache_lookup(ssl)) {
623       // libssl.so mismatch
624       std::cerr << "no async support" << std::endl;
625       return nullptr;
626     }
627
628     AsyncSSLSocket *sslSocket =
629         AsyncSSLSocket::getFromSSL(ssl);
630     assert(sslSocket != nullptr);
631     // Going to simulate an async cache by just running delaying the miss 100ms
632     if (asyncCallbacks_ % 2 == 0) {
633       // This socket is already blocked on lookup, return miss
634       std::cerr << "returning miss" << std::endl;
635     } else {
636       // fresh meat - block it
637       std::cerr << "async lookup" << std::endl;
638       sslSocket->getEventBase()->tryRunAfterDelay(
639         std::bind(&AsyncSSLSocket::restartSSLAccept,
640                   sslSocket), lookupDelay_);
641       *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
642       asyncLookups_++;
643     }
644 #endif
645     return nullptr;
646   }
647 };
648
649 void getfds(int fds[2]);
650
651 void getctx(
652   std::shared_ptr<folly::SSLContext> clientCtx,
653   std::shared_ptr<folly::SSLContext> serverCtx);
654
655 void sslsocketpair(
656   EventBase* eventBase,
657   AsyncSSLSocket::UniquePtr* clientSock,
658   AsyncSSLSocket::UniquePtr* serverSock);
659
660 class BlockingWriteClient :
661   private AsyncSSLSocket::HandshakeCB,
662   private AsyncTransportWrapper::WriteCallback {
663  public:
664   explicit BlockingWriteClient(
665     AsyncSSLSocket::UniquePtr socket)
666     : socket_(std::move(socket)),
667       bufLen_(2500),
668       iovCount_(2000) {
669     // Fill buf_
670     buf_.reset(new uint8_t[bufLen_]);
671     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
672       buf_[n] = n % 0xff;
673     }
674
675     // Initialize iov_
676     iov_.reset(new struct iovec[iovCount_]);
677     for (uint32_t n = 0; n < iovCount_; ++n) {
678       iov_[n].iov_base = buf_.get() + n;
679       if (n & 0x1) {
680         iov_[n].iov_len = n % bufLen_;
681       } else {
682         iov_[n].iov_len = bufLen_ - (n % bufLen_);
683       }
684     }
685
686     socket_->sslConn(this, 100);
687   }
688
689   struct iovec* getIovec() const {
690     return iov_.get();
691   }
692   uint32_t getIovecCount() const {
693     return iovCount_;
694   }
695
696  private:
697   void handshakeSuc(AsyncSSLSocket*) noexcept override {
698     socket_->writev(this, iov_.get(), iovCount_);
699   }
700   void handshakeErr(
701     AsyncSSLSocket*,
702     const AsyncSocketException& ex) noexcept override {
703     ADD_FAILURE() << "client handshake error: " << ex.what();
704   }
705   void writeSuccess() noexcept override {
706     socket_->close();
707   }
708   void writeErr(
709     size_t bytesWritten,
710     const AsyncSocketException& ex) noexcept override {
711     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
712                   << ex.what();
713   }
714
715   AsyncSSLSocket::UniquePtr socket_;
716   uint32_t bufLen_;
717   uint32_t iovCount_;
718   std::unique_ptr<uint8_t[]> buf_;
719   std::unique_ptr<struct iovec[]> iov_;
720 };
721
722 class BlockingWriteServer :
723     private AsyncSSLSocket::HandshakeCB,
724     private AsyncTransportWrapper::ReadCallback {
725  public:
726   explicit BlockingWriteServer(
727     AsyncSSLSocket::UniquePtr socket)
728     : socket_(std::move(socket)),
729       bufSize_(2500 * 2000),
730       bytesRead_(0) {
731     buf_.reset(new uint8_t[bufSize_]);
732     socket_->sslAccept(this, 100);
733   }
734
735   void checkBuffer(struct iovec* iov, uint32_t count) const {
736     uint32_t idx = 0;
737     for (uint32_t n = 0; n < count; ++n) {
738       size_t bytesLeft = bytesRead_ - idx;
739       int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
740                       std::min(iov[n].iov_len, bytesLeft));
741       if (rc != 0) {
742         FAIL() << "buffer mismatch at iovec " << n << "/" << count
743                << ": rc=" << rc;
744
745       }
746       if (iov[n].iov_len > bytesLeft) {
747         FAIL() << "server did not read enough data: "
748                << "ended at byte " << bytesLeft << "/" << iov[n].iov_len
749                << " in iovec " << n << "/" << count;
750       }
751
752       idx += iov[n].iov_len;
753     }
754     if (idx != bytesRead_) {
755       ADD_FAILURE() << "server read extra data: " << bytesRead_
756                     << " bytes read; expected " << idx;
757     }
758   }
759
760  private:
761   void handshakeSuc(AsyncSSLSocket*) noexcept override {
762     // Wait 10ms before reading, so the client's writes will initially block.
763     socket_->getEventBase()->tryRunAfterDelay(
764         [this] { socket_->setReadCB(this); }, 10);
765   }
766   void handshakeErr(
767     AsyncSSLSocket*,
768     const AsyncSocketException& ex) noexcept override {
769     ADD_FAILURE() << "server handshake error: " << ex.what();
770   }
771   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
772     *bufReturn = buf_.get() + bytesRead_;
773     *lenReturn = bufSize_ - bytesRead_;
774   }
775   void readDataAvailable(size_t len) noexcept override {
776     bytesRead_ += len;
777     socket_->setReadCB(nullptr);
778     socket_->getEventBase()->tryRunAfterDelay(
779         [this] { socket_->setReadCB(this); }, 2);
780   }
781   void readEOF() noexcept override {
782     socket_->close();
783   }
784   void readErr(
785     const AsyncSocketException& ex) noexcept override {
786     ADD_FAILURE() << "server read error: " << ex.what();
787   }
788
789   AsyncSSLSocket::UniquePtr socket_;
790   uint32_t bufSize_;
791   uint32_t bytesRead_;
792   std::unique_ptr<uint8_t[]> buf_;
793 };
794
795 class NpnClient :
796   private AsyncSSLSocket::HandshakeCB,
797   private AsyncTransportWrapper::WriteCallback {
798  public:
799   explicit NpnClient(
800     AsyncSSLSocket::UniquePtr socket)
801       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
802     socket_->sslConn(this);
803   }
804
805   const unsigned char* nextProto;
806   unsigned nextProtoLength;
807   SSLContext::NextProtocolType protocolType;
808
809  private:
810   void handshakeSuc(AsyncSSLSocket*) noexcept override {
811     socket_->getSelectedNextProtocol(
812         &nextProto, &nextProtoLength, &protocolType);
813   }
814   void handshakeErr(
815     AsyncSSLSocket*,
816     const AsyncSocketException& ex) noexcept override {
817     ADD_FAILURE() << "client handshake error: " << ex.what();
818   }
819   void writeSuccess() noexcept override {
820     socket_->close();
821   }
822   void writeErr(
823     size_t bytesWritten,
824     const AsyncSocketException& ex) noexcept override {
825     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
826                   << ex.what();
827   }
828
829   AsyncSSLSocket::UniquePtr socket_;
830 };
831
832 class NpnServer :
833     private AsyncSSLSocket::HandshakeCB,
834     private AsyncTransportWrapper::ReadCallback {
835  public:
836   explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
837       : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
838     socket_->sslAccept(this);
839   }
840
841   const unsigned char* nextProto;
842   unsigned nextProtoLength;
843   SSLContext::NextProtocolType protocolType;
844
845  private:
846   void handshakeSuc(AsyncSSLSocket*) noexcept override {
847     socket_->getSelectedNextProtocol(
848         &nextProto, &nextProtoLength, &protocolType);
849   }
850   void handshakeErr(
851     AsyncSSLSocket*,
852     const AsyncSocketException& ex) noexcept override {
853     ADD_FAILURE() << "server handshake error: " << ex.what();
854   }
855   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
856     *lenReturn = 0;
857   }
858   void readDataAvailable(size_t len) noexcept override {
859   }
860   void readEOF() noexcept override {
861     socket_->close();
862   }
863   void readErr(
864     const AsyncSocketException& ex) noexcept override {
865     ADD_FAILURE() << "server read error: " << ex.what();
866   }
867
868   AsyncSSLSocket::UniquePtr socket_;
869 };
870
871 #ifndef OPENSSL_NO_TLSEXT
872 class SNIClient :
873   private AsyncSSLSocket::HandshakeCB,
874   private AsyncTransportWrapper::WriteCallback {
875  public:
876   explicit SNIClient(
877     AsyncSSLSocket::UniquePtr socket)
878       : serverNameMatch(false), socket_(std::move(socket)) {
879     socket_->sslConn(this);
880   }
881
882   bool serverNameMatch;
883
884  private:
885   void handshakeSuc(AsyncSSLSocket*) noexcept override {
886     serverNameMatch = socket_->isServerNameMatch();
887   }
888   void handshakeErr(
889     AsyncSSLSocket*,
890     const AsyncSocketException& ex) noexcept override {
891     ADD_FAILURE() << "client handshake error: " << ex.what();
892   }
893   void writeSuccess() noexcept override {
894     socket_->close();
895   }
896   void writeErr(
897     size_t bytesWritten,
898     const AsyncSocketException& ex) noexcept override {
899     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
900                   << ex.what();
901   }
902
903   AsyncSSLSocket::UniquePtr socket_;
904 };
905
906 class SNIServer :
907     private AsyncSSLSocket::HandshakeCB,
908     private AsyncTransportWrapper::ReadCallback {
909  public:
910   explicit SNIServer(
911     AsyncSSLSocket::UniquePtr socket,
912     const std::shared_ptr<folly::SSLContext>& ctx,
913     const std::shared_ptr<folly::SSLContext>& sniCtx,
914     const std::string& expectedServerName)
915       : serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
916         expectedServerName_(expectedServerName) {
917     ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
918                                          std::placeholders::_1));
919     socket_->sslAccept(this);
920   }
921
922   bool serverNameMatch;
923
924  private:
925   void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
926   void handshakeErr(
927     AsyncSSLSocket*,
928     const AsyncSocketException& ex) noexcept override {
929     ADD_FAILURE() << "server handshake error: " << ex.what();
930   }
931   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
932     *lenReturn = 0;
933   }
934   void readDataAvailable(size_t len) noexcept override {
935   }
936   void readEOF() noexcept override {
937     socket_->close();
938   }
939   void readErr(
940     const AsyncSocketException& ex) noexcept override {
941     ADD_FAILURE() << "server read error: " << ex.what();
942   }
943
944   folly::SSLContext::ServerNameCallbackResult
945     serverNameCallback(SSL *ssl) {
946     const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
947     if (sniCtx_ &&
948         sn &&
949         !strcasecmp(expectedServerName_.c_str(), sn)) {
950       AsyncSSLSocket *sslSocket =
951           AsyncSSLSocket::getFromSSL(ssl);
952       sslSocket->switchServerSSLContext(sniCtx_);
953       serverNameMatch = true;
954       return folly::SSLContext::SERVER_NAME_FOUND;
955     } else {
956       serverNameMatch = false;
957       return folly::SSLContext::SERVER_NAME_NOT_FOUND;
958     }
959   }
960
961   AsyncSSLSocket::UniquePtr socket_;
962   std::shared_ptr<folly::SSLContext> sniCtx_;
963   std::string expectedServerName_;
964 };
965 #endif
966
967 class SSLClient : public AsyncSocket::ConnectCallback,
968                   public AsyncTransportWrapper::WriteCallback,
969                   public AsyncTransportWrapper::ReadCallback
970 {
971  private:
972   EventBase *eventBase_;
973   std::shared_ptr<AsyncSSLSocket> sslSocket_;
974   SSL_SESSION *session_;
975   std::shared_ptr<folly::SSLContext> ctx_;
976   uint32_t requests_;
977   folly::SocketAddress address_;
978   uint32_t timeout_;
979   char buf_[128];
980   char readbuf_[128];
981   uint32_t bytesRead_;
982   uint32_t hit_;
983   uint32_t miss_;
984   uint32_t errors_;
985   uint32_t writeAfterConnectErrors_;
986
987   // These settings test that we eventually drain the
988   // socket, even if the maxReadsPerEvent_ is hit during
989   // a event loop iteration.
990   static constexpr size_t kMaxReadsPerEvent = 2;
991   static constexpr size_t kMaxReadBufferSz =
992     sizeof(readbuf_) / kMaxReadsPerEvent / 2;  // 2 event loop iterations
993
994  public:
995   SSLClient(EventBase *eventBase,
996             const folly::SocketAddress& address,
997             uint32_t requests,
998             uint32_t timeout = 0)
999       : eventBase_(eventBase),
1000         session_(nullptr),
1001         requests_(requests),
1002         address_(address),
1003         timeout_(timeout),
1004         bytesRead_(0),
1005         hit_(0),
1006         miss_(0),
1007         errors_(0),
1008         writeAfterConnectErrors_(0) {
1009     ctx_.reset(new folly::SSLContext());
1010     ctx_->setOptions(SSL_OP_NO_TICKET);
1011     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1012     memset(buf_, 'a', sizeof(buf_));
1013   }
1014
1015   ~SSLClient() {
1016     if (session_) {
1017       SSL_SESSION_free(session_);
1018     }
1019     if (errors_ == 0) {
1020       EXPECT_EQ(bytesRead_, sizeof(buf_));
1021     }
1022   }
1023
1024   uint32_t getHit() const { return hit_; }
1025
1026   uint32_t getMiss() const { return miss_; }
1027
1028   uint32_t getErrors() const { return errors_; }
1029
1030   uint32_t getWriteAfterConnectErrors() const {
1031     return writeAfterConnectErrors_;
1032   }
1033
1034   void connect(bool writeNow = false) {
1035     sslSocket_ = AsyncSSLSocket::newSocket(
1036       ctx_, eventBase_);
1037     if (session_ != nullptr) {
1038       sslSocket_->setSSLSession(session_);
1039     }
1040     requests_--;
1041     sslSocket_->connect(this, address_, timeout_);
1042     if (sslSocket_ && writeNow) {
1043       // write some junk, used in an error test
1044       sslSocket_->write(this, buf_, sizeof(buf_));
1045     }
1046   }
1047
1048   void connectSuccess() noexcept override {
1049     std::cerr << "client SSL socket connected" << std::endl;
1050     if (sslSocket_->getSSLSessionReused()) {
1051       hit_++;
1052     } else {
1053       miss_++;
1054       if (session_ != nullptr) {
1055         SSL_SESSION_free(session_);
1056       }
1057       session_ = sslSocket_->getSSLSession();
1058     }
1059
1060     // write()
1061     sslSocket_->setMaxReadsPerEvent(kMaxReadsPerEvent);
1062     sslSocket_->write(this, buf_, sizeof(buf_));
1063     sslSocket_->setReadCB(this);
1064     memset(readbuf_, 'b', sizeof(readbuf_));
1065     bytesRead_ = 0;
1066   }
1067
1068   void connectErr(
1069     const AsyncSocketException& ex) noexcept override {
1070     std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
1071     errors_++;
1072     sslSocket_.reset();
1073   }
1074
1075   void writeSuccess() noexcept override {
1076     std::cerr << "client write success" << std::endl;
1077   }
1078
1079   void writeErr(
1080     size_t bytesWritten,
1081     const AsyncSocketException& ex)
1082     noexcept override {
1083     std::cerr << "client writeError: " << ex.what() << std::endl;
1084     if (!sslSocket_) {
1085       writeAfterConnectErrors_++;
1086     }
1087   }
1088
1089   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
1090     *bufReturn = readbuf_ + bytesRead_;
1091     *lenReturn = std::min(kMaxReadBufferSz, sizeof(readbuf_) - bytesRead_);
1092   }
1093
1094   void readEOF() noexcept override {
1095     std::cerr << "client readEOF" << std::endl;
1096   }
1097
1098   void readErr(
1099     const AsyncSocketException& ex) noexcept override {
1100     std::cerr << "client readError: " << ex.what() << std::endl;
1101   }
1102
1103   void readDataAvailable(size_t len) noexcept override {
1104     std::cerr << "client read data: " << len << std::endl;
1105     bytesRead_ += len;
1106     if (bytesRead_ == sizeof(buf_)) {
1107       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
1108       sslSocket_->closeNow();
1109       sslSocket_.reset();
1110       if (requests_ != 0) {
1111         connect();
1112       }
1113     }
1114   }
1115
1116 };
1117
1118 class SSLHandshakeBase :
1119   public AsyncSSLSocket::HandshakeCB,
1120   private AsyncTransportWrapper::WriteCallback {
1121  public:
1122   explicit SSLHandshakeBase(
1123    AsyncSSLSocket::UniquePtr socket,
1124    bool preverifyResult,
1125    bool verifyResult) :
1126     handshakeVerify_(false),
1127     handshakeSuccess_(false),
1128     handshakeError_(false),
1129     socket_(std::move(socket)),
1130     preverifyResult_(preverifyResult),
1131     verifyResult_(verifyResult) {
1132   }
1133
1134   bool handshakeVerify_;
1135   bool handshakeSuccess_;
1136   bool handshakeError_;
1137   std::chrono::nanoseconds handshakeTime;
1138
1139  protected:
1140   AsyncSSLSocket::UniquePtr socket_;
1141   bool preverifyResult_;
1142   bool verifyResult_;
1143
1144   // HandshakeCallback
1145   bool handshakeVer(
1146    AsyncSSLSocket* sock,
1147    bool preverifyOk,
1148    X509_STORE_CTX* ctx) noexcept override {
1149     handshakeVerify_ = true;
1150
1151     EXPECT_EQ(preverifyResult_, preverifyOk);
1152     return verifyResult_;
1153   }
1154
1155   void handshakeSuc(AsyncSSLSocket*) noexcept override {
1156     handshakeSuccess_ = true;
1157     handshakeTime = socket_->getHandshakeTime();
1158   }
1159
1160   void handshakeErr(
1161    AsyncSSLSocket*,
1162    const AsyncSocketException& ex) noexcept override {
1163     handshakeError_ = true;
1164     handshakeTime = socket_->getHandshakeTime();
1165   }
1166
1167   // WriteCallback
1168   void writeSuccess() noexcept override {
1169     socket_->close();
1170   }
1171
1172   void writeErr(
1173    size_t bytesWritten,
1174    const AsyncSocketException& ex) noexcept override {
1175     ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
1176                   << ex.what();
1177   }
1178 };
1179
1180 class SSLHandshakeClient : public SSLHandshakeBase {
1181  public:
1182   SSLHandshakeClient(
1183    AsyncSSLSocket::UniquePtr socket,
1184    bool preverifyResult,
1185    bool verifyResult) :
1186     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1187     socket_->sslConn(this, 0);
1188   }
1189 };
1190
1191 class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
1192  public:
1193   SSLHandshakeClientNoVerify(
1194    AsyncSSLSocket::UniquePtr socket,
1195    bool preverifyResult,
1196    bool verifyResult) :
1197     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1198     socket_->sslConn(this, 0,
1199       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1200   }
1201 };
1202
1203 class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
1204  public:
1205   SSLHandshakeClientDoVerify(
1206    AsyncSSLSocket::UniquePtr socket,
1207    bool preverifyResult,
1208    bool verifyResult) :
1209     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1210     socket_->sslConn(this, 0,
1211       folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
1212   }
1213 };
1214
1215 class SSLHandshakeServer : public SSLHandshakeBase {
1216  public:
1217   SSLHandshakeServer(
1218       AsyncSSLSocket::UniquePtr socket,
1219       bool preverifyResult,
1220       bool verifyResult)
1221     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1222     socket_->sslAccept(this, 0);
1223   }
1224 };
1225
1226 class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
1227  public:
1228   SSLHandshakeServerParseClientHello(
1229       AsyncSSLSocket::UniquePtr socket,
1230       bool preverifyResult,
1231       bool verifyResult)
1232       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1233     socket_->enableClientHelloParsing();
1234     socket_->sslAccept(this, 0);
1235   }
1236
1237   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
1238
1239  protected:
1240   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
1241     handshakeSuccess_ = true;
1242     sock->getSSLSharedCiphers(sharedCiphers_);
1243     sock->getSSLServerCiphers(serverCiphers_);
1244     sock->getSSLClientCiphers(clientCiphers_);
1245     chosenCipher_ = sock->getNegotiatedCipherName();
1246   }
1247 };
1248
1249
1250 class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
1251  public:
1252   SSLHandshakeServerNoVerify(
1253       AsyncSSLSocket::UniquePtr socket,
1254       bool preverifyResult,
1255       bool verifyResult)
1256     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1257     socket_->sslAccept(this, 0,
1258       folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1259   }
1260 };
1261
1262 class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
1263  public:
1264   SSLHandshakeServerDoVerify(
1265       AsyncSSLSocket::UniquePtr socket,
1266       bool preverifyResult,
1267       bool verifyResult)
1268     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
1269     socket_->sslAccept(this, 0,
1270       folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1271   }
1272 };
1273
1274 class EventBaseAborter : public AsyncTimeout {
1275  public:
1276   EventBaseAborter(EventBase* eventBase,
1277                    uint32_t timeoutMS)
1278     : AsyncTimeout(
1279       eventBase, AsyncTimeout::InternalEnum::INTERNAL)
1280     , eventBase_(eventBase) {
1281     scheduleTimeout(timeoutMS);
1282   }
1283
1284   void timeoutExpired() noexcept override {
1285     FAIL() << "test timed out";
1286     eventBase_->terminateLoopSoon();
1287   }
1288
1289  private:
1290   EventBase* eventBase_;
1291 };
1292
1293 }