Fix leaks in tests
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <gtest/gtest.h>
28 #include <iostream>
29 #include <list>
30 #include <set>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <poll.h>
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <netinet/tcp.h>
37 #include <folly/io/Cursor.h>
38
39 using std::string;
40 using std::vector;
41 using std::min;
42 using std::cerr;
43 using std::endl;
44 using std::list;
45
46 namespace folly {
47 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
48 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
49 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
50
51 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
52 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
53 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
59 ctx_(new folly::SSLContext),
60     acb_(acb),
61   socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
62   // Set up the SSL context
63   ctx_->loadCertificate(testCert);
64   ctx_->loadPrivateKey(testKey);
65   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
66
67   acb_->ctx_ = ctx_;
68   acb_->base_ = &evb_;
69
70   //set up the listening socket
71   socket_->bind(0);
72   socket_->getAddress(&address_);
73   socket_->listen(100);
74   socket_->addAcceptCallback(acb_, &evb_);
75   socket_->startAccepting();
76
77   int ret = pthread_create(&thread_, nullptr, Main, this);
78   assert(ret == 0);
79
80   std::cerr << "Accepting connections on " << address_ << std::endl;
81 }
82
83 void getfds(int fds[2]) {
84   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
85     FAIL() << "failed to create socketpair: " << strerror(errno);
86   }
87   for (int idx = 0; idx < 2; ++idx) {
88     int flags = fcntl(fds[idx], F_GETFL, 0);
89     if (flags == -1) {
90       FAIL() << "failed to get flags for socket " << idx << ": "
91              << strerror(errno);
92     }
93     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
94       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
95              << strerror(errno);
96     }
97   }
98 }
99
100 void getctx(
101   std::shared_ptr<folly::SSLContext> clientCtx,
102   std::shared_ptr<folly::SSLContext> serverCtx) {
103   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
104
105   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
106   serverCtx->loadCertificate(
107       testCert);
108   serverCtx->loadPrivateKey(
109       testKey);
110 }
111
112 void sslsocketpair(
113   EventBase* eventBase,
114   AsyncSSLSocket::UniquePtr* clientSock,
115   AsyncSSLSocket::UniquePtr* serverSock) {
116   auto clientCtx = std::make_shared<folly::SSLContext>();
117   auto serverCtx = std::make_shared<folly::SSLContext>();
118   int fds[2];
119   getfds(fds);
120   getctx(clientCtx, serverCtx);
121   clientSock->reset(new AsyncSSLSocket(
122                       clientCtx, eventBase, fds[0], false));
123   serverSock->reset(new AsyncSSLSocket(
124                       serverCtx, eventBase, fds[1], true));
125
126   // (*clientSock)->setSendTimeout(100);
127   // (*serverSock)->setSendTimeout(100);
128 }
129
130
131 /**
132  * Test connecting to, writing to, reading from, and closing the
133  * connection to the SSL server.
134  */
135 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
136   // Start listening on a local port
137   WriteCallbackBase writeCallback;
138   ReadCallback readCallback(&writeCallback);
139   HandshakeCallback handshakeCallback(&readCallback);
140   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
141   TestSSLServer server(&acceptCallback);
142
143   // Set up SSL context.
144   std::shared_ptr<SSLContext> sslContext(new SSLContext());
145   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
146   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
147   //sslContext->authenticate(true, false);
148
149   // connect
150   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
151                                                  sslContext);
152   socket->open();
153
154   // write()
155   uint8_t buf[128];
156   memset(buf, 'a', sizeof(buf));
157   socket->write(buf, sizeof(buf));
158
159   // read()
160   uint8_t readbuf[128];
161   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
162   EXPECT_EQ(bytesRead, 128);
163   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
164
165   // close()
166   socket->close();
167
168   cerr << "ConnectWriteReadClose test completed" << endl;
169 }
170
171 /**
172  * Negative test for handshakeError().
173  */
174 TEST(AsyncSSLSocketTest, HandshakeError) {
175   // Start listening on a local port
176   WriteCallbackBase writeCallback;
177   ReadCallback readCallback(&writeCallback);
178   HandshakeCallback handshakeCallback(&readCallback);
179   HandshakeErrorCallback acceptCallback(&handshakeCallback);
180   TestSSLServer server(&acceptCallback);
181
182   // Set up SSL context.
183   std::shared_ptr<SSLContext> sslContext(new SSLContext());
184   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
185
186   // connect
187   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
188                                                  sslContext);
189   // read()
190   bool ex = false;
191   try {
192     socket->open();
193
194     uint8_t readbuf[128];
195     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
196   } catch (AsyncSocketException &e) {
197     ex = true;
198   }
199   EXPECT_TRUE(ex);
200
201   // close()
202   socket->close();
203   cerr << "HandshakeError test completed" << endl;
204 }
205
206 /**
207  * Negative test for readError().
208  */
209 TEST(AsyncSSLSocketTest, ReadError) {
210   // Start listening on a local port
211   WriteCallbackBase writeCallback;
212   ReadErrorCallback readCallback(&writeCallback);
213   HandshakeCallback handshakeCallback(&readCallback);
214   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
215   TestSSLServer server(&acceptCallback);
216
217   // Set up SSL context.
218   std::shared_ptr<SSLContext> sslContext(new SSLContext());
219   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
220
221   // connect
222   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
223                                                  sslContext);
224   socket->open();
225
226   // write something to trigger ssl handshake
227   uint8_t buf[128];
228   memset(buf, 'a', sizeof(buf));
229   socket->write(buf, sizeof(buf));
230
231   socket->close();
232   cerr << "ReadError test completed" << endl;
233 }
234
235 /**
236  * Negative test for writeError().
237  */
238 TEST(AsyncSSLSocketTest, WriteError) {
239   // Start listening on a local port
240   WriteCallbackBase writeCallback;
241   WriteErrorCallback readCallback(&writeCallback);
242   HandshakeCallback handshakeCallback(&readCallback);
243   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
244   TestSSLServer server(&acceptCallback);
245
246   // Set up SSL context.
247   std::shared_ptr<SSLContext> sslContext(new SSLContext());
248   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
249
250   // connect
251   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
252                                                  sslContext);
253   socket->open();
254
255   // write something to trigger ssl handshake
256   uint8_t buf[128];
257   memset(buf, 'a', sizeof(buf));
258   socket->write(buf, sizeof(buf));
259
260   socket->close();
261   cerr << "WriteError test completed" << endl;
262 }
263
264 /**
265  * Test a socket with TCP_NODELAY unset.
266  */
267 TEST(AsyncSSLSocketTest, SocketWithDelay) {
268   // Start listening on a local port
269   WriteCallbackBase writeCallback;
270   ReadCallback readCallback(&writeCallback);
271   HandshakeCallback handshakeCallback(&readCallback);
272   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
273   TestSSLServer server(&acceptCallback);
274
275   // Set up SSL context.
276   std::shared_ptr<SSLContext> sslContext(new SSLContext());
277   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
278
279   // connect
280   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
281                                                  sslContext);
282   socket->open();
283
284   // write()
285   uint8_t buf[128];
286   memset(buf, 'a', sizeof(buf));
287   socket->write(buf, sizeof(buf));
288
289   // read()
290   uint8_t readbuf[128];
291   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
292   EXPECT_EQ(bytesRead, 128);
293   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
294
295   // close()
296   socket->close();
297
298   cerr << "SocketWithDelay test completed" << endl;
299 }
300
301 TEST(AsyncSSLSocketTest, NpnTestOverlap) {
302   EventBase eventBase;
303   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
304   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
305   int fds[2];
306   getfds(fds);
307   getctx(clientCtx, serverCtx);
308
309   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
310   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
311
312   AsyncSSLSocket::UniquePtr clientSock(
313     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
314   AsyncSSLSocket::UniquePtr serverSock(
315     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
316   NpnClient client(std::move(clientSock));
317   NpnServer server(std::move(serverSock));
318
319   eventBase.loop();
320
321   EXPECT_TRUE(client.nextProtoLength != 0);
322   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
323   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
324                            server.nextProtoLength), 0);
325   string selected((const char*)client.nextProto, client.nextProtoLength);
326   EXPECT_EQ(selected.compare("baz"), 0);
327 }
328
329 TEST(AsyncSSLSocketTest, NpnTestUnset) {
330   // Identical to above test, except that we want unset NPN before
331   // looping.
332   EventBase eventBase;
333   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
334   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
335   int fds[2];
336   getfds(fds);
337   getctx(clientCtx, serverCtx);
338
339   clientCtx->setAdvertisedNextProtocols({"blub","baz"});
340   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
341
342   AsyncSSLSocket::UniquePtr clientSock(
343     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
344   AsyncSSLSocket::UniquePtr serverSock(
345     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
346
347   // unsetting NPN for any of [client, server] is enought to make NPN not
348   // work
349   clientCtx->unsetNextProtocols();
350
351   NpnClient client(std::move(clientSock));
352   NpnServer server(std::move(serverSock));
353
354   eventBase.loop();
355
356   EXPECT_TRUE(client.nextProtoLength == 0);
357   EXPECT_TRUE(server.nextProtoLength == 0);
358   EXPECT_TRUE(client.nextProto == nullptr);
359   EXPECT_TRUE(server.nextProto == nullptr);
360 }
361
362 TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
363   EventBase eventBase;
364   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
365   std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
366   int fds[2];
367   getfds(fds);
368   getctx(clientCtx, serverCtx);
369
370   clientCtx->setAdvertisedNextProtocols({"blub"});
371   serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
372
373   AsyncSSLSocket::UniquePtr clientSock(
374     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
375   AsyncSSLSocket::UniquePtr serverSock(
376     new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
377   NpnClient client(std::move(clientSock));
378   NpnServer server(std::move(serverSock));
379
380   eventBase.loop();
381
382   EXPECT_TRUE(client.nextProtoLength != 0);
383   EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
384   EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
385                            server.nextProtoLength), 0);
386   string selected((const char*)client.nextProto, client.nextProtoLength);
387   EXPECT_EQ(selected.compare("blub"), 0);
388 }
389
390 TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
391   // Probability that this test will fail is 2^-64, which could be considered
392   // as negligible.
393   const int kTries = 64;
394
395   std::set<string> selectedProtocols;
396   for (int i = 0; i < kTries; ++i) {
397     EventBase eventBase;
398     std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
399     std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
400     int fds[2];
401     getfds(fds);
402     getctx(clientCtx, serverCtx);
403
404     clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
405     serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
406         {1, {"bar"}}});
407
408
409     AsyncSSLSocket::UniquePtr clientSock(
410       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
411     AsyncSSLSocket::UniquePtr serverSock(
412       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
413     NpnClient client(std::move(clientSock));
414     NpnServer server(std::move(serverSock));
415
416     eventBase.loop();
417
418     EXPECT_TRUE(client.nextProtoLength != 0);
419     EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
420     EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
421                              server.nextProtoLength), 0);
422     string selected((const char*)client.nextProto, client.nextProtoLength);
423     selectedProtocols.insert(selected);
424   }
425   EXPECT_EQ(selectedProtocols.size(), 2);
426 }
427
428
429 #ifndef OPENSSL_NO_TLSEXT
430 /**
431  * 1. Client sends TLSEXT_HOSTNAME in client hello.
432  * 2. Server found a match SSL_CTX and use this SSL_CTX to
433  *    continue the SSL handshake.
434  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
435  */
436 TEST(AsyncSSLSocketTest, SNITestMatch) {
437   EventBase eventBase;
438   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
439   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
440   // Use the same SSLContext to continue the handshake after
441   // tlsext_hostname match.
442   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
443   const std::string serverName("xyz.newdev.facebook.com");
444   int fds[2];
445   getfds(fds);
446   getctx(clientCtx, dfServerCtx);
447
448   AsyncSSLSocket::UniquePtr clientSock(
449     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
450   AsyncSSLSocket::UniquePtr serverSock(
451     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
452   SNIClient client(std::move(clientSock));
453   SNIServer server(std::move(serverSock),
454                    dfServerCtx,
455                    hskServerCtx,
456                    serverName);
457
458   eventBase.loop();
459
460   EXPECT_TRUE(client.serverNameMatch);
461   EXPECT_TRUE(server.serverNameMatch);
462 }
463
464 /**
465  * 1. Client sends TLSEXT_HOSTNAME in client hello.
466  * 2. Server cannot find a matching SSL_CTX and continue to use
467  *    the current SSL_CTX to do the handshake.
468  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
469  */
470 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
471   EventBase eventBase;
472   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
473   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
474   // Use the same SSLContext to continue the handshake after
475   // tlsext_hostname match.
476   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
477   const std::string clientRequestingServerName("foo.com");
478   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
479
480   int fds[2];
481   getfds(fds);
482   getctx(clientCtx, dfServerCtx);
483
484   AsyncSSLSocket::UniquePtr clientSock(
485     new AsyncSSLSocket(clientCtx,
486                         &eventBase,
487                         fds[0],
488                         clientRequestingServerName));
489   AsyncSSLSocket::UniquePtr serverSock(
490     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
491   SNIClient client(std::move(clientSock));
492   SNIServer server(std::move(serverSock),
493                    dfServerCtx,
494                    hskServerCtx,
495                    serverExpectedServerName);
496
497   eventBase.loop();
498
499   EXPECT_TRUE(!client.serverNameMatch);
500   EXPECT_TRUE(!server.serverNameMatch);
501 }
502 /**
503  * 1. Client sends TLSEXT_HOSTNAME in client hello.
504  * 2. We then change the serverName.
505  * 3. We expect that we get 'false' as the result for serNameMatch.
506  */
507
508 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
509    EventBase eventBase;
510   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
511   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
512   // Use the same SSLContext to continue the handshake after
513   // tlsext_hostname match.
514   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
515   const std::string serverName("xyz.newdev.facebook.com");
516   int fds[2];
517   getfds(fds);
518   getctx(clientCtx, dfServerCtx);
519
520   AsyncSSLSocket::UniquePtr clientSock(
521     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
522   //Change the server name
523   std::string newName("new.com");
524   clientSock->setServerName(newName);
525   AsyncSSLSocket::UniquePtr serverSock(
526     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
527   SNIClient client(std::move(clientSock));
528   SNIServer server(std::move(serverSock),
529                    dfServerCtx,
530                    hskServerCtx,
531                    serverName);
532
533   eventBase.loop();
534
535   EXPECT_TRUE(!client.serverNameMatch);
536 }
537
538 /**
539  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
540  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
541  */
542 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
543   EventBase eventBase;
544   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
545   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
546   // Use the same SSLContext to continue the handshake after
547   // tlsext_hostname match.
548   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
549   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
550
551   int fds[2];
552   getfds(fds);
553   getctx(clientCtx, dfServerCtx);
554
555   AsyncSSLSocket::UniquePtr clientSock(
556     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
557   AsyncSSLSocket::UniquePtr serverSock(
558     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
559   SNIClient client(std::move(clientSock));
560   SNIServer server(std::move(serverSock),
561                    dfServerCtx,
562                    hskServerCtx,
563                    serverExpectedServerName);
564
565   eventBase.loop();
566
567   EXPECT_TRUE(!client.serverNameMatch);
568   EXPECT_TRUE(!server.serverNameMatch);
569 }
570
571 #endif
572 /**
573  * Test SSL client socket
574  */
575 TEST(AsyncSSLSocketTest, SSLClientTest) {
576   // Start listening on a local port
577   WriteCallbackBase writeCallback;
578   ReadCallback readCallback(&writeCallback);
579   HandshakeCallback handshakeCallback(&readCallback);
580   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
581   TestSSLServer server(&acceptCallback);
582
583   // Set up SSL client
584   EventBase eventBase;
585   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
586                                              1));
587
588   client->connect();
589   EventBaseAborter eba(&eventBase, 3000);
590   eventBase.loop();
591
592   EXPECT_EQ(client->getMiss(), 1);
593   EXPECT_EQ(client->getHit(), 0);
594
595   cerr << "SSLClientTest test completed" << endl;
596 }
597
598
599 /**
600  * Test SSL client socket session re-use
601  */
602 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
603   // Start listening on a local port
604   WriteCallbackBase writeCallback;
605   ReadCallback readCallback(&writeCallback);
606   HandshakeCallback handshakeCallback(&readCallback);
607   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
608   TestSSLServer server(&acceptCallback);
609
610   // Set up SSL client
611   EventBase eventBase;
612   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
613                                              10));
614
615   client->connect();
616   EventBaseAborter eba(&eventBase, 3000);
617   eventBase.loop();
618
619   EXPECT_EQ(client->getMiss(), 1);
620   EXPECT_EQ(client->getHit(), 9);
621
622   cerr << "SSLClientTestReuse test completed" << endl;
623 }
624
625 /**
626  * Test SSL client socket timeout
627  */
628 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
629   // Start listening on a local port
630   EmptyReadCallback readCallback;
631   HandshakeCallback handshakeCallback(&readCallback,
632                                       HandshakeCallback::EXPECT_ERROR);
633   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
634   TestSSLServer server(&acceptCallback);
635
636   // Set up SSL client
637   EventBase eventBase;
638   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
639                                              1, 10));
640   client->connect(true /* write before connect completes */);
641   EventBaseAborter eba(&eventBase, 3000);
642   eventBase.loop();
643
644   usleep(100000);
645   // This is checking that the connectError callback precedes any queued
646   // writeError callbacks.  This matches AsyncSocket's behavior
647   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
648   EXPECT_EQ(client->getErrors(), 1);
649   EXPECT_EQ(client->getMiss(), 0);
650   EXPECT_EQ(client->getHit(), 0);
651
652   cerr << "SSLClientTimeoutTest test completed" << endl;
653 }
654
655
656 /**
657  * Test SSL server async cache
658  */
659 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
660   // Start listening on a local port
661   WriteCallbackBase writeCallback;
662   ReadCallback readCallback(&writeCallback);
663   HandshakeCallback handshakeCallback(&readCallback);
664   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
665   TestSSLAsyncCacheServer server(&acceptCallback);
666
667   // Set up SSL client
668   EventBase eventBase;
669   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
670                                              10, 500));
671
672   client->connect();
673   EventBaseAborter eba(&eventBase, 3000);
674   eventBase.loop();
675
676   EXPECT_EQ(server.getAsyncCallbacks(), 18);
677   EXPECT_EQ(server.getAsyncLookups(), 9);
678   EXPECT_EQ(client->getMiss(), 10);
679   EXPECT_EQ(client->getHit(), 0);
680
681   cerr << "SSLServerAsyncCacheTest test completed" << endl;
682 }
683
684
685 /**
686  * Test SSL server accept timeout with cache path
687  */
688 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
689   // Start listening on a local port
690   WriteCallbackBase writeCallback;
691   ReadCallback readCallback(&writeCallback);
692   EmptyReadCallback clientReadCallback;
693   HandshakeCallback handshakeCallback(&readCallback);
694   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
695   TestSSLAsyncCacheServer server(&acceptCallback);
696
697   // Set up SSL client
698   EventBase eventBase;
699   // only do a TCP connect
700   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
701   sock->connect(nullptr, server.getAddress());
702   clientReadCallback.tcpSocket_ = sock;
703   sock->setReadCB(&clientReadCallback);
704
705   EventBaseAborter eba(&eventBase, 3000);
706   eventBase.loop();
707
708   EXPECT_EQ(readCallback.state, STATE_WAITING);
709
710   cerr << "SSLServerTimeoutTest test completed" << endl;
711 }
712
713 /**
714  * Test SSL server accept timeout with cache path
715  */
716 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
717   // Start listening on a local port
718   WriteCallbackBase writeCallback;
719   ReadCallback readCallback(&writeCallback);
720   HandshakeCallback handshakeCallback(&readCallback);
721   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
722   TestSSLAsyncCacheServer server(&acceptCallback);
723
724   // Set up SSL client
725   EventBase eventBase;
726   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
727                                              2));
728
729   client->connect();
730   EventBaseAborter eba(&eventBase, 3000);
731   eventBase.loop();
732
733   EXPECT_EQ(server.getAsyncCallbacks(), 1);
734   EXPECT_EQ(server.getAsyncLookups(), 1);
735   EXPECT_EQ(client->getErrors(), 1);
736   EXPECT_EQ(client->getMiss(), 1);
737   EXPECT_EQ(client->getHit(), 0);
738
739   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
740 }
741
742 /**
743  * Test SSL server accept timeout with cache path
744  */
745 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
746   // Start listening on a local port
747   WriteCallbackBase writeCallback;
748   ReadCallback readCallback(&writeCallback);
749   HandshakeCallback handshakeCallback(&readCallback,
750                                       HandshakeCallback::EXPECT_ERROR);
751   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
752   TestSSLAsyncCacheServer server(&acceptCallback, 500);
753
754   // Set up SSL client
755   EventBase eventBase;
756   std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
757                                              2, 100));
758
759   client->connect();
760   EventBaseAborter eba(&eventBase, 3000);
761   eventBase.loop();
762
763   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
764       handshakeCallback.closeSocket();});
765   // give time for the cache lookup to come back and find it closed
766   usleep(500000);
767
768   EXPECT_EQ(server.getAsyncCallbacks(), 1);
769   EXPECT_EQ(server.getAsyncLookups(), 1);
770   EXPECT_EQ(client->getErrors(), 1);
771   EXPECT_EQ(client->getMiss(), 1);
772   EXPECT_EQ(client->getHit(), 0);
773
774   cerr << "SSLServerCacheCloseTest test completed" << endl;
775 }
776
777 /**
778  * Verify Client Ciphers obtained using SSL MSG Callback.
779  */
780 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
781   EventBase eventBase;
782   auto clientCtx = std::make_shared<SSLContext>();
783   auto serverCtx = std::make_shared<SSLContext>();
784   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
785   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
786   serverCtx->loadPrivateKey(testKey);
787   serverCtx->loadCertificate(testCert);
788   serverCtx->loadTrustedCertificates(testCA);
789   serverCtx->loadClientCAList(testCA);
790
791   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
792   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
793   clientCtx->loadPrivateKey(testKey);
794   clientCtx->loadCertificate(testCert);
795   clientCtx->loadTrustedCertificates(testCA);
796
797   int fds[2];
798   getfds(fds);
799
800   AsyncSSLSocket::UniquePtr clientSock(
801       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
802   AsyncSSLSocket::UniquePtr serverSock(
803       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
804
805   SSLHandshakeClient client(std::move(clientSock), true, true);
806   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
807
808   eventBase.loop();
809
810   EXPECT_EQ(server.clientCiphers_,
811             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
812   EXPECT_TRUE(client.handshakeVerify_);
813   EXPECT_TRUE(client.handshakeSuccess_);
814   EXPECT_TRUE(!client.handshakeError_);
815   EXPECT_TRUE(server.handshakeVerify_);
816   EXPECT_TRUE(server.handshakeSuccess_);
817   EXPECT_TRUE(!server.handshakeError_);
818 }
819
820 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
821   EventBase eventBase;
822   auto ctx = std::make_shared<SSLContext>();
823
824   int fds[2];
825   getfds(fds);
826
827   int bufLen = 42;
828   uint8_t majorVersion = 18;
829   uint8_t minorVersion = 25;
830
831   // Create callback buf
832   auto buf = IOBuf::create(bufLen);
833   buf->append(bufLen);
834   folly::io::RWPrivateCursor cursor(buf.get());
835   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
836   cursor.write<uint16_t>(0);
837   cursor.write<uint8_t>(38);
838   cursor.write<uint8_t>(majorVersion);
839   cursor.write<uint8_t>(minorVersion);
840   cursor.skip(32);
841   cursor.write<uint32_t>(0);
842
843   SSL* ssl = ctx->createSSL();
844   SCOPE_EXIT { SSL_free(ssl); };
845   AsyncSSLSocket::UniquePtr sock(
846       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
847   sock->enableClientHelloParsing();
848
849   // Test client hello parsing in one packet
850   AsyncSSLSocket::clientHelloParsingCallback(
851       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
852   buf.reset();
853
854   auto parsedClientHello = sock->getClientHelloInfo();
855   EXPECT_TRUE(parsedClientHello != nullptr);
856   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
857   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
858 }
859
860 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
861   EventBase eventBase;
862   auto ctx = std::make_shared<SSLContext>();
863
864   int fds[2];
865   getfds(fds);
866
867   int bufLen = 42;
868   uint8_t majorVersion = 18;
869   uint8_t minorVersion = 25;
870
871   // Create callback buf
872   auto buf = IOBuf::create(bufLen);
873   buf->append(bufLen);
874   folly::io::RWPrivateCursor cursor(buf.get());
875   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
876   cursor.write<uint16_t>(0);
877   cursor.write<uint8_t>(38);
878   cursor.write<uint8_t>(majorVersion);
879   cursor.write<uint8_t>(minorVersion);
880   cursor.skip(32);
881   cursor.write<uint32_t>(0);
882
883   SSL* ssl = ctx->createSSL();
884   SCOPE_EXIT { SSL_free(ssl); };
885   AsyncSSLSocket::UniquePtr sock(
886       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
887   sock->enableClientHelloParsing();
888
889   // Test parsing with two packets with first packet size < 3
890   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
891   AsyncSSLSocket::clientHelloParsingCallback(
892       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
893       ssl, sock.get());
894   bufCopy.reset();
895   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
896   AsyncSSLSocket::clientHelloParsingCallback(
897       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
898       ssl, sock.get());
899   bufCopy.reset();
900
901   auto parsedClientHello = sock->getClientHelloInfo();
902   EXPECT_TRUE(parsedClientHello != nullptr);
903   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
904   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
905 }
906
907 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
908   EventBase eventBase;
909   auto ctx = std::make_shared<SSLContext>();
910
911   int fds[2];
912   getfds(fds);
913
914   int bufLen = 42;
915   uint8_t majorVersion = 18;
916   uint8_t minorVersion = 25;
917
918   // Create callback buf
919   auto buf = IOBuf::create(bufLen);
920   buf->append(bufLen);
921   folly::io::RWPrivateCursor cursor(buf.get());
922   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
923   cursor.write<uint16_t>(0);
924   cursor.write<uint8_t>(38);
925   cursor.write<uint8_t>(majorVersion);
926   cursor.write<uint8_t>(minorVersion);
927   cursor.skip(32);
928   cursor.write<uint32_t>(0);
929
930   SSL* ssl = ctx->createSSL();
931   SCOPE_EXIT { SSL_free(ssl); };
932   AsyncSSLSocket::UniquePtr sock(
933       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
934   sock->enableClientHelloParsing();
935
936   // Test parsing with multiple small packets
937   for (uint64_t i = 0; i < buf->length(); i += 3) {
938     auto bufCopy = folly::IOBuf::copyBuffer(
939         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
940     AsyncSSLSocket::clientHelloParsingCallback(
941         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
942         ssl, sock.get());
943     bufCopy.reset();
944   }
945
946   auto parsedClientHello = sock->getClientHelloInfo();
947   EXPECT_TRUE(parsedClientHello != nullptr);
948   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
949   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
950 }
951
952 /**
953  * Verify sucessful behavior of SSL certificate validation.
954  */
955 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
956   EventBase eventBase;
957   auto clientCtx = std::make_shared<SSLContext>();
958   auto dfServerCtx = std::make_shared<SSLContext>();
959
960   int fds[2];
961   getfds(fds);
962   getctx(clientCtx, dfServerCtx);
963
964   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
965   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
966
967   AsyncSSLSocket::UniquePtr clientSock(
968     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
969   AsyncSSLSocket::UniquePtr serverSock(
970     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
971
972   SSLHandshakeClient client(std::move(clientSock), true, true);
973   clientCtx->loadTrustedCertificates(testCA);
974
975   SSLHandshakeServer server(std::move(serverSock), true, true);
976
977   eventBase.loop();
978
979   EXPECT_TRUE(client.handshakeVerify_);
980   EXPECT_TRUE(client.handshakeSuccess_);
981   EXPECT_TRUE(!client.handshakeError_);
982   EXPECT_TRUE(!server.handshakeVerify_);
983   EXPECT_TRUE(server.handshakeSuccess_);
984   EXPECT_TRUE(!server.handshakeError_);
985 }
986
987 /**
988  * Verify that the client's verification callback is able to fail SSL
989  * connection establishment.
990  */
991 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
992   EventBase eventBase;
993   auto clientCtx = std::make_shared<SSLContext>();
994   auto dfServerCtx = std::make_shared<SSLContext>();
995
996   int fds[2];
997   getfds(fds);
998   getctx(clientCtx, dfServerCtx);
999
1000   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1001   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1002
1003   AsyncSSLSocket::UniquePtr clientSock(
1004     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1005   AsyncSSLSocket::UniquePtr serverSock(
1006     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1007
1008   SSLHandshakeClient client(std::move(clientSock), true, false);
1009   clientCtx->loadTrustedCertificates(testCA);
1010
1011   SSLHandshakeServer server(std::move(serverSock), true, true);
1012
1013   eventBase.loop();
1014
1015   EXPECT_TRUE(client.handshakeVerify_);
1016   EXPECT_TRUE(!client.handshakeSuccess_);
1017   EXPECT_TRUE(client.handshakeError_);
1018   EXPECT_TRUE(!server.handshakeVerify_);
1019   EXPECT_TRUE(!server.handshakeSuccess_);
1020   EXPECT_TRUE(server.handshakeError_);
1021 }
1022
1023 /**
1024  * Verify that the options in SSLContext can be overridden in
1025  * sslConnect/Accept.i.e specifying that no validation should be performed
1026  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1027  * the validation callback.
1028  */
1029 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1030   EventBase eventBase;
1031   auto clientCtx = std::make_shared<SSLContext>();
1032   auto dfServerCtx = std::make_shared<SSLContext>();
1033
1034   int fds[2];
1035   getfds(fds);
1036   getctx(clientCtx, dfServerCtx);
1037
1038   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1039   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1040
1041   AsyncSSLSocket::UniquePtr clientSock(
1042     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1043   AsyncSSLSocket::UniquePtr serverSock(
1044     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1045
1046   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1047   clientCtx->loadTrustedCertificates(testCA);
1048
1049   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1050
1051   eventBase.loop();
1052
1053   EXPECT_TRUE(!client.handshakeVerify_);
1054   EXPECT_TRUE(client.handshakeSuccess_);
1055   EXPECT_TRUE(!client.handshakeError_);
1056   EXPECT_TRUE(!server.handshakeVerify_);
1057   EXPECT_TRUE(server.handshakeSuccess_);
1058   EXPECT_TRUE(!server.handshakeError_);
1059 }
1060
1061 /**
1062  * Verify that the options in SSLContext can be overridden in
1063  * sslConnect/Accept. Enable verification even if context says otherwise.
1064  * Test requireClientCert with client cert
1065  */
1066 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1067   EventBase eventBase;
1068   auto clientCtx = std::make_shared<SSLContext>();
1069   auto serverCtx = std::make_shared<SSLContext>();
1070   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1071   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1072   serverCtx->loadPrivateKey(testKey);
1073   serverCtx->loadCertificate(testCert);
1074   serverCtx->loadTrustedCertificates(testCA);
1075   serverCtx->loadClientCAList(testCA);
1076
1077   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1078   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1079   clientCtx->loadPrivateKey(testKey);
1080   clientCtx->loadCertificate(testCert);
1081   clientCtx->loadTrustedCertificates(testCA);
1082
1083   int fds[2];
1084   getfds(fds);
1085
1086   AsyncSSLSocket::UniquePtr clientSock(
1087       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1088   AsyncSSLSocket::UniquePtr serverSock(
1089       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1090
1091   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1092   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1093
1094   eventBase.loop();
1095
1096   EXPECT_TRUE(client.handshakeVerify_);
1097   EXPECT_TRUE(client.handshakeSuccess_);
1098   EXPECT_FALSE(client.handshakeError_);
1099   EXPECT_TRUE(server.handshakeVerify_);
1100   EXPECT_TRUE(server.handshakeSuccess_);
1101   EXPECT_FALSE(server.handshakeError_);
1102 }
1103
1104 /**
1105  * Verify that the client's verification callback is able to override
1106  * the preverification failure and allow a successful connection.
1107  */
1108 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1109   EventBase eventBase;
1110   auto clientCtx = std::make_shared<SSLContext>();
1111   auto dfServerCtx = std::make_shared<SSLContext>();
1112
1113   int fds[2];
1114   getfds(fds);
1115   getctx(clientCtx, dfServerCtx);
1116
1117   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1118   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1119
1120   AsyncSSLSocket::UniquePtr clientSock(
1121     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1122   AsyncSSLSocket::UniquePtr serverSock(
1123     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1124
1125   SSLHandshakeClient client(std::move(clientSock), false, true);
1126   SSLHandshakeServer server(std::move(serverSock), true, true);
1127
1128   eventBase.loop();
1129
1130   EXPECT_TRUE(client.handshakeVerify_);
1131   EXPECT_TRUE(client.handshakeSuccess_);
1132   EXPECT_TRUE(!client.handshakeError_);
1133   EXPECT_TRUE(!server.handshakeVerify_);
1134   EXPECT_TRUE(server.handshakeSuccess_);
1135   EXPECT_TRUE(!server.handshakeError_);
1136 }
1137
1138 /**
1139  * Verify that specifying that no validation should be performed allows an
1140  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1141  * callback.
1142  */
1143 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1144   EventBase eventBase;
1145   auto clientCtx = std::make_shared<SSLContext>();
1146   auto dfServerCtx = std::make_shared<SSLContext>();
1147
1148   int fds[2];
1149   getfds(fds);
1150   getctx(clientCtx, dfServerCtx);
1151
1152   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1153   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1154
1155   AsyncSSLSocket::UniquePtr clientSock(
1156     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1157   AsyncSSLSocket::UniquePtr serverSock(
1158     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1159
1160   SSLHandshakeClient client(std::move(clientSock), false, false);
1161   SSLHandshakeServer server(std::move(serverSock), false, false);
1162
1163   eventBase.loop();
1164
1165   EXPECT_TRUE(!client.handshakeVerify_);
1166   EXPECT_TRUE(client.handshakeSuccess_);
1167   EXPECT_TRUE(!client.handshakeError_);
1168   EXPECT_TRUE(!server.handshakeVerify_);
1169   EXPECT_TRUE(server.handshakeSuccess_);
1170   EXPECT_TRUE(!server.handshakeError_);
1171 }
1172
1173 /**
1174  * Test requireClientCert with client cert
1175  */
1176 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1177   EventBase eventBase;
1178   auto clientCtx = std::make_shared<SSLContext>();
1179   auto serverCtx = std::make_shared<SSLContext>();
1180   serverCtx->setVerificationOption(
1181       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1182   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1183   serverCtx->loadPrivateKey(testKey);
1184   serverCtx->loadCertificate(testCert);
1185   serverCtx->loadTrustedCertificates(testCA);
1186   serverCtx->loadClientCAList(testCA);
1187
1188   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1189   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1190   clientCtx->loadPrivateKey(testKey);
1191   clientCtx->loadCertificate(testCert);
1192   clientCtx->loadTrustedCertificates(testCA);
1193
1194   int fds[2];
1195   getfds(fds);
1196
1197   AsyncSSLSocket::UniquePtr clientSock(
1198       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1199   AsyncSSLSocket::UniquePtr serverSock(
1200       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1201
1202   SSLHandshakeClient client(std::move(clientSock), true, true);
1203   SSLHandshakeServer server(std::move(serverSock), true, true);
1204
1205   eventBase.loop();
1206
1207   EXPECT_TRUE(client.handshakeVerify_);
1208   EXPECT_TRUE(client.handshakeSuccess_);
1209   EXPECT_FALSE(client.handshakeError_);
1210   EXPECT_TRUE(server.handshakeVerify_);
1211   EXPECT_TRUE(server.handshakeSuccess_);
1212   EXPECT_FALSE(server.handshakeError_);
1213 }
1214
1215
1216 /**
1217  * Test requireClientCert with no client cert
1218  */
1219 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1220   EventBase eventBase;
1221   auto clientCtx = std::make_shared<SSLContext>();
1222   auto serverCtx = std::make_shared<SSLContext>();
1223   serverCtx->setVerificationOption(
1224       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1225   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1226   serverCtx->loadPrivateKey(testKey);
1227   serverCtx->loadCertificate(testCert);
1228   serverCtx->loadTrustedCertificates(testCA);
1229   serverCtx->loadClientCAList(testCA);
1230   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1231   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1232
1233   int fds[2];
1234   getfds(fds);
1235
1236   AsyncSSLSocket::UniquePtr clientSock(
1237       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1238   AsyncSSLSocket::UniquePtr serverSock(
1239       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1240
1241   SSLHandshakeClient client(std::move(clientSock), false, false);
1242   SSLHandshakeServer server(std::move(serverSock), false, false);
1243
1244   eventBase.loop();
1245
1246   EXPECT_FALSE(server.handshakeVerify_);
1247   EXPECT_FALSE(server.handshakeSuccess_);
1248   EXPECT_TRUE(server.handshakeError_);
1249 }
1250
1251 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1252   EventBase eb;
1253
1254   // Set up SSL context.
1255   auto sslContext = std::make_shared<SSLContext>();
1256   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1257
1258   // create SSL socket
1259   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1260
1261   EXPECT_EQ(1500, socket->getMinWriteSize());
1262
1263   socket->setMinWriteSize(0);
1264   EXPECT_EQ(0, socket->getMinWriteSize());
1265   socket->setMinWriteSize(50000);
1266   EXPECT_EQ(50000, socket->getMinWriteSize());
1267 }
1268
1269 class ReadCallbackTerminator : public ReadCallback {
1270  public:
1271   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1272       : ReadCallback(wcb)
1273       , base_(base) {}
1274
1275   // Do not write data back, terminate the loop.
1276   void readDataAvailable(size_t len) noexcept override {
1277     std::cerr << "readDataAvailable, len " << len << std::endl;
1278
1279     currentBuffer.length = len;
1280
1281     buffers.push_back(currentBuffer);
1282     currentBuffer.reset();
1283     state = STATE_SUCCEEDED;
1284
1285     socket_->setReadCB(nullptr);
1286     base_->terminateLoopSoon();
1287   }
1288  private:
1289   EventBase* base_;
1290 };
1291
1292
1293 /**
1294  * Test a full unencrypted codepath
1295  */
1296 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1297   EventBase base;
1298
1299   auto clientCtx = std::make_shared<folly::SSLContext>();
1300   auto serverCtx = std::make_shared<folly::SSLContext>();
1301   int fds[2];
1302   getfds(fds);
1303   getctx(clientCtx, serverCtx);
1304   auto client = AsyncSSLSocket::newSocket(
1305                   clientCtx, &base, fds[0], false, true);
1306   auto server = AsyncSSLSocket::newSocket(
1307                   serverCtx, &base, fds[1], true, true);
1308
1309   ReadCallbackTerminator readCallback(&base, nullptr);
1310   server->setReadCB(&readCallback);
1311   readCallback.setSocket(server);
1312
1313   uint8_t buf[128];
1314   memset(buf, 'a', sizeof(buf));
1315   client->write(nullptr, buf, sizeof(buf));
1316
1317   // Check that bytes are unencrypted
1318   char c;
1319   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1320   EXPECT_EQ('a', c);
1321
1322   EventBaseAborter eba(&base, 3000);
1323   base.loop();
1324
1325   EXPECT_EQ(1, readCallback.buffers.size());
1326   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1327
1328   server->setReadCB(&readCallback);
1329
1330   // Unencrypted
1331   server->sslAccept(nullptr);
1332   client->sslConn(nullptr);
1333
1334   // Do NOT wait for handshake, writing should be queued and happen after
1335
1336   client->write(nullptr, buf, sizeof(buf));
1337
1338   // Check that bytes are *not* unencrypted
1339   char c2;
1340   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1341   EXPECT_NE('a', c2);
1342
1343
1344   base.loop();
1345
1346   EXPECT_EQ(2, readCallback.buffers.size());
1347   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1348 }
1349
1350 } // namespace
1351
1352 ///////////////////////////////////////////////////////////////////////////
1353 // init_unit_test_suite
1354 ///////////////////////////////////////////////////////////////////////////
1355 namespace {
1356 struct Initializer {
1357   Initializer() {
1358     signal(SIGPIPE, SIG_IGN);
1359   }
1360 };
1361 Initializer initializer;
1362 } // anonymous