Add support for ALPN
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <gtest/gtest.h>
28 #include <iostream>
29 #include <list>
30 #include <set>
31 #include <unistd.h>
32 #include <fcntl.h>
33 #include <poll.h>
34 #include <sys/types.h>
35 #include <sys/socket.h>
36 #include <netinet/tcp.h>
37 #include <folly/io/Cursor.h>
38
39 using std::string;
40 using std::vector;
41 using std::min;
42 using std::cerr;
43 using std::endl;
44 using std::list;
45
46 namespace folly {
47 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
48 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
49 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
50
51 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
52 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
53 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
59 ctx_(new folly::SSLContext),
60     acb_(acb),
61   socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
62   // Set up the SSL context
63   ctx_->loadCertificate(testCert);
64   ctx_->loadPrivateKey(testKey);
65   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
66
67   acb_->ctx_ = ctx_;
68   acb_->base_ = &evb_;
69
70   //set up the listening socket
71   socket_->bind(0);
72   socket_->getAddress(&address_);
73   socket_->listen(100);
74   socket_->addAcceptCallback(acb_, &evb_);
75   socket_->startAccepting();
76
77   int ret = pthread_create(&thread_, nullptr, Main, this);
78   assert(ret == 0);
79   (void)ret;
80
81   std::cerr << "Accepting connections on " << address_ << std::endl;
82 }
83
84 void getfds(int fds[2]) {
85   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
86     FAIL() << "failed to create socketpair: " << strerror(errno);
87   }
88   for (int idx = 0; idx < 2; ++idx) {
89     int flags = fcntl(fds[idx], F_GETFL, 0);
90     if (flags == -1) {
91       FAIL() << "failed to get flags for socket " << idx << ": "
92              << strerror(errno);
93     }
94     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
95       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
96              << strerror(errno);
97     }
98   }
99 }
100
101 void getctx(
102   std::shared_ptr<folly::SSLContext> clientCtx,
103   std::shared_ptr<folly::SSLContext> serverCtx) {
104   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
105
106   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107   serverCtx->loadCertificate(
108       testCert);
109   serverCtx->loadPrivateKey(
110       testKey);
111 }
112
113 void sslsocketpair(
114   EventBase* eventBase,
115   AsyncSSLSocket::UniquePtr* clientSock,
116   AsyncSSLSocket::UniquePtr* serverSock) {
117   auto clientCtx = std::make_shared<folly::SSLContext>();
118   auto serverCtx = std::make_shared<folly::SSLContext>();
119   int fds[2];
120   getfds(fds);
121   getctx(clientCtx, serverCtx);
122   clientSock->reset(new AsyncSSLSocket(
123                       clientCtx, eventBase, fds[0], false));
124   serverSock->reset(new AsyncSSLSocket(
125                       serverCtx, eventBase, fds[1], true));
126
127   // (*clientSock)->setSendTimeout(100);
128   // (*serverSock)->setSendTimeout(100);
129 }
130
131 // client protocol filters
132 bool clientProtoFilterPickPony(unsigned char** client,
133   unsigned int* client_len, const unsigned char*, unsigned int ) {
134   //the protocol string in length prefixed byte string. the
135   //length byte is not included in the length
136   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
137   *client = p;
138   *client_len = 7;
139   return true;
140 }
141
142 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
143   const unsigned char*, unsigned int) {
144   return false;
145 }
146
147 /**
148  * Test connecting to, writing to, reading from, and closing the
149  * connection to the SSL server.
150  */
151 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
152   // Start listening on a local port
153   WriteCallbackBase writeCallback;
154   ReadCallback readCallback(&writeCallback);
155   HandshakeCallback handshakeCallback(&readCallback);
156   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
157   TestSSLServer server(&acceptCallback);
158
159   // Set up SSL context.
160   std::shared_ptr<SSLContext> sslContext(new SSLContext());
161   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
162   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
163   //sslContext->authenticate(true, false);
164
165   // connect
166   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
167                                                  sslContext);
168   socket->open();
169
170   // write()
171   uint8_t buf[128];
172   memset(buf, 'a', sizeof(buf));
173   socket->write(buf, sizeof(buf));
174
175   // read()
176   uint8_t readbuf[128];
177   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
178   EXPECT_EQ(bytesRead, 128);
179   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
180
181   // close()
182   socket->close();
183
184   cerr << "ConnectWriteReadClose test completed" << endl;
185 }
186
187 /**
188  * Negative test for handshakeError().
189  */
190 TEST(AsyncSSLSocketTest, HandshakeError) {
191   // Start listening on a local port
192   WriteCallbackBase writeCallback;
193   ReadCallback readCallback(&writeCallback);
194   HandshakeCallback handshakeCallback(&readCallback);
195   HandshakeErrorCallback acceptCallback(&handshakeCallback);
196   TestSSLServer server(&acceptCallback);
197
198   // Set up SSL context.
199   std::shared_ptr<SSLContext> sslContext(new SSLContext());
200   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
201
202   // connect
203   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
204                                                  sslContext);
205   // read()
206   bool ex = false;
207   try {
208     socket->open();
209
210     uint8_t readbuf[128];
211     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
212     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
213   } catch (AsyncSocketException &e) {
214     ex = true;
215   }
216   EXPECT_TRUE(ex);
217
218   // close()
219   socket->close();
220   cerr << "HandshakeError test completed" << endl;
221 }
222
223 /**
224  * Negative test for readError().
225  */
226 TEST(AsyncSSLSocketTest, ReadError) {
227   // Start listening on a local port
228   WriteCallbackBase writeCallback;
229   ReadErrorCallback readCallback(&writeCallback);
230   HandshakeCallback handshakeCallback(&readCallback);
231   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
232   TestSSLServer server(&acceptCallback);
233
234   // Set up SSL context.
235   std::shared_ptr<SSLContext> sslContext(new SSLContext());
236   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
237
238   // connect
239   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
240                                                  sslContext);
241   socket->open();
242
243   // write something to trigger ssl handshake
244   uint8_t buf[128];
245   memset(buf, 'a', sizeof(buf));
246   socket->write(buf, sizeof(buf));
247
248   socket->close();
249   cerr << "ReadError test completed" << endl;
250 }
251
252 /**
253  * Negative test for writeError().
254  */
255 TEST(AsyncSSLSocketTest, WriteError) {
256   // Start listening on a local port
257   WriteCallbackBase writeCallback;
258   WriteErrorCallback readCallback(&writeCallback);
259   HandshakeCallback handshakeCallback(&readCallback);
260   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
261   TestSSLServer server(&acceptCallback);
262
263   // Set up SSL context.
264   std::shared_ptr<SSLContext> sslContext(new SSLContext());
265   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
266
267   // connect
268   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
269                                                  sslContext);
270   socket->open();
271
272   // write something to trigger ssl handshake
273   uint8_t buf[128];
274   memset(buf, 'a', sizeof(buf));
275   socket->write(buf, sizeof(buf));
276
277   socket->close();
278   cerr << "WriteError test completed" << endl;
279 }
280
281 /**
282  * Test a socket with TCP_NODELAY unset.
283  */
284 TEST(AsyncSSLSocketTest, SocketWithDelay) {
285   // Start listening on a local port
286   WriteCallbackBase writeCallback;
287   ReadCallback readCallback(&writeCallback);
288   HandshakeCallback handshakeCallback(&readCallback);
289   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
290   TestSSLServer server(&acceptCallback);
291
292   // Set up SSL context.
293   std::shared_ptr<SSLContext> sslContext(new SSLContext());
294   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
295
296   // connect
297   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
298                                                  sslContext);
299   socket->open();
300
301   // write()
302   uint8_t buf[128];
303   memset(buf, 'a', sizeof(buf));
304   socket->write(buf, sizeof(buf));
305
306   // read()
307   uint8_t readbuf[128];
308   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
309   EXPECT_EQ(bytesRead, 128);
310   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
311
312   // close()
313   socket->close();
314
315   cerr << "SocketWithDelay test completed" << endl;
316 }
317
318 using NextProtocolTypePair =
319     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
320
321 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
322   // For matching protos
323  public:
324   void SetUp() override { getctx(clientCtx, serverCtx); }
325
326   void connect(bool unset = false) {
327     getfds(fds);
328
329     if (unset) {
330       // unsetting NPN for any of [client, server] is enough to make NPN not
331       // work
332       clientCtx->unsetNextProtocols();
333     }
334
335     AsyncSSLSocket::UniquePtr clientSock(
336       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
337     AsyncSSLSocket::UniquePtr serverSock(
338       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
339     client = folly::make_unique<NpnClient>(std::move(clientSock));
340     server = folly::make_unique<NpnServer>(std::move(serverSock));
341
342     eventBase.loop();
343   }
344
345   void expectProtocol(const std::string& proto) {
346     EXPECT_NE(client->nextProtoLength, 0);
347     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
348     EXPECT_EQ(
349         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
350         0);
351     string selected((const char*)client->nextProto, client->nextProtoLength);
352     EXPECT_EQ(proto, selected);
353   }
354
355   void expectNoProtocol() {
356     EXPECT_EQ(client->nextProtoLength, 0);
357     EXPECT_EQ(server->nextProtoLength, 0);
358     EXPECT_EQ(client->nextProto, nullptr);
359     EXPECT_EQ(server->nextProto, nullptr);
360   }
361
362   void expectProtocolType() {
363     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
364         GetParam().second == SSLContext::NextProtocolType::ANY) {
365       EXPECT_EQ(client->protocolType, server->protocolType);
366     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
367                GetParam().second == SSLContext::NextProtocolType::ANY) {
368       // Well not much we can say
369     } else {
370       expectProtocolType(GetParam());
371     }
372   }
373
374   void expectProtocolType(NextProtocolTypePair expected) {
375     EXPECT_EQ(client->protocolType, expected.first);
376     EXPECT_EQ(server->protocolType, expected.second);
377   }
378
379   EventBase eventBase;
380   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
381   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
382   int fds[2];
383   std::unique_ptr<NpnClient> client;
384   std::unique_ptr<NpnServer> server;
385 };
386
387 class NextProtocolNPNOnlyTest : public NextProtocolTest {
388   // For mismatching protos
389 };
390
391 class NextProtocolMismatchTest : public NextProtocolTest {
392   // For mismatching protos
393 };
394
395 TEST_P(NextProtocolTest, NpnTestOverlap) {
396   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
397   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
398                                         GetParam().second);
399
400   connect();
401
402   expectProtocol("baz");
403   expectProtocolType();
404 }
405
406 TEST_P(NextProtocolTest, NpnTestUnset) {
407   // Identical to above test, except that we want unset NPN before
408   // looping.
409   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
410   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
411                                         GetParam().second);
412
413   connect(true /* unset */);
414
415   // if alpn negotiation fails, type will appear as npn
416   expectNoProtocol();
417   EXPECT_EQ(client->protocolType, server->protocolType);
418 }
419
420 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
421   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
422   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
423                                         GetParam().second);
424
425   connect();
426
427   expectNoProtocol();
428   expectProtocolType(
429       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
430 }
431
432 TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
433   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
434   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
435                                         GetParam().second);
436
437   connect();
438
439   expectProtocol("blub");
440   expectProtocolType();
441 }
442
443 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
444   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
445   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
446   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
447                                         GetParam().second);
448
449   connect();
450
451   expectProtocol("ponies");
452   expectProtocolType();
453 }
454
455 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
456   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
457   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
458   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
459                                         GetParam().second);
460
461   connect();
462
463   expectProtocol("blub");
464   expectProtocolType();
465 }
466
467 TEST_P(NextProtocolTest, RandomizedNpnTest) {
468   // Probability that this test will fail is 2^-64, which could be considered
469   // as negligible.
470   const int kTries = 64;
471
472   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
473                                         GetParam().first);
474   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
475                                                   GetParam().second);
476
477   std::set<string> selectedProtocols;
478   for (int i = 0; i < kTries; ++i) {
479     connect();
480
481     EXPECT_NE(client->nextProtoLength, 0);
482     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
483     EXPECT_EQ(
484         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
485         0);
486     string selected((const char*)client->nextProto, client->nextProtoLength);
487     selectedProtocols.insert(selected);
488     expectProtocolType();
489   }
490   EXPECT_EQ(selectedProtocols.size(), 2);
491 }
492
493 INSTANTIATE_TEST_CASE_P(
494     AsyncSSLSocketTest,
495     NextProtocolTest,
496     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
497                                            SSLContext::NextProtocolType::NPN),
498 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
499                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
500                                            SSLContext::NextProtocolType::ALPN),
501 #endif
502                       NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
503                                            SSLContext::NextProtocolType::ANY),
504 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
505                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
506                                            SSLContext::NextProtocolType::ANY),
507 #endif
508                       NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
509                                            SSLContext::NextProtocolType::ANY)));
510
511 INSTANTIATE_TEST_CASE_P(
512     AsyncSSLSocketTest,
513     NextProtocolNPNOnlyTest,
514     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
515                                            SSLContext::NextProtocolType::NPN)));
516
517 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
518 INSTANTIATE_TEST_CASE_P(
519     AsyncSSLSocketTest,
520     NextProtocolMismatchTest,
521     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
522                                            SSLContext::NextProtocolType::ALPN),
523                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
524                                            SSLContext::NextProtocolType::NPN)));
525 #endif
526
527 #ifndef OPENSSL_NO_TLSEXT
528 /**
529  * 1. Client sends TLSEXT_HOSTNAME in client hello.
530  * 2. Server found a match SSL_CTX and use this SSL_CTX to
531  *    continue the SSL handshake.
532  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
533  */
534 TEST(AsyncSSLSocketTest, SNITestMatch) {
535   EventBase eventBase;
536   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
537   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
538   // Use the same SSLContext to continue the handshake after
539   // tlsext_hostname match.
540   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
541   const std::string serverName("xyz.newdev.facebook.com");
542   int fds[2];
543   getfds(fds);
544   getctx(clientCtx, dfServerCtx);
545
546   AsyncSSLSocket::UniquePtr clientSock(
547     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
548   AsyncSSLSocket::UniquePtr serverSock(
549     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
550   SNIClient client(std::move(clientSock));
551   SNIServer server(std::move(serverSock),
552                    dfServerCtx,
553                    hskServerCtx,
554                    serverName);
555
556   eventBase.loop();
557
558   EXPECT_TRUE(client.serverNameMatch);
559   EXPECT_TRUE(server.serverNameMatch);
560 }
561
562 /**
563  * 1. Client sends TLSEXT_HOSTNAME in client hello.
564  * 2. Server cannot find a matching SSL_CTX and continue to use
565  *    the current SSL_CTX to do the handshake.
566  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
567  */
568 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
569   EventBase eventBase;
570   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
571   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
572   // Use the same SSLContext to continue the handshake after
573   // tlsext_hostname match.
574   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
575   const std::string clientRequestingServerName("foo.com");
576   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
577
578   int fds[2];
579   getfds(fds);
580   getctx(clientCtx, dfServerCtx);
581
582   AsyncSSLSocket::UniquePtr clientSock(
583     new AsyncSSLSocket(clientCtx,
584                         &eventBase,
585                         fds[0],
586                         clientRequestingServerName));
587   AsyncSSLSocket::UniquePtr serverSock(
588     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
589   SNIClient client(std::move(clientSock));
590   SNIServer server(std::move(serverSock),
591                    dfServerCtx,
592                    hskServerCtx,
593                    serverExpectedServerName);
594
595   eventBase.loop();
596
597   EXPECT_TRUE(!client.serverNameMatch);
598   EXPECT_TRUE(!server.serverNameMatch);
599 }
600 /**
601  * 1. Client sends TLSEXT_HOSTNAME in client hello.
602  * 2. We then change the serverName.
603  * 3. We expect that we get 'false' as the result for serNameMatch.
604  */
605
606 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
607    EventBase eventBase;
608   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
609   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
610   // Use the same SSLContext to continue the handshake after
611   // tlsext_hostname match.
612   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
613   const std::string serverName("xyz.newdev.facebook.com");
614   int fds[2];
615   getfds(fds);
616   getctx(clientCtx, dfServerCtx);
617
618   AsyncSSLSocket::UniquePtr clientSock(
619     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
620   //Change the server name
621   std::string newName("new.com");
622   clientSock->setServerName(newName);
623   AsyncSSLSocket::UniquePtr serverSock(
624     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
625   SNIClient client(std::move(clientSock));
626   SNIServer server(std::move(serverSock),
627                    dfServerCtx,
628                    hskServerCtx,
629                    serverName);
630
631   eventBase.loop();
632
633   EXPECT_TRUE(!client.serverNameMatch);
634 }
635
636 /**
637  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
638  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
639  */
640 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
641   EventBase eventBase;
642   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
643   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
644   // Use the same SSLContext to continue the handshake after
645   // tlsext_hostname match.
646   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
647   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
648
649   int fds[2];
650   getfds(fds);
651   getctx(clientCtx, dfServerCtx);
652
653   AsyncSSLSocket::UniquePtr clientSock(
654     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
655   AsyncSSLSocket::UniquePtr serverSock(
656     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
657   SNIClient client(std::move(clientSock));
658   SNIServer server(std::move(serverSock),
659                    dfServerCtx,
660                    hskServerCtx,
661                    serverExpectedServerName);
662
663   eventBase.loop();
664
665   EXPECT_TRUE(!client.serverNameMatch);
666   EXPECT_TRUE(!server.serverNameMatch);
667 }
668
669 #endif
670 /**
671  * Test SSL client socket
672  */
673 TEST(AsyncSSLSocketTest, SSLClientTest) {
674   // Start listening on a local port
675   WriteCallbackBase writeCallback;
676   ReadCallback readCallback(&writeCallback);
677   HandshakeCallback handshakeCallback(&readCallback);
678   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
679   TestSSLServer server(&acceptCallback);
680
681   // Set up SSL client
682   EventBase eventBase;
683   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
684
685   client->connect();
686   EventBaseAborter eba(&eventBase, 3000);
687   eventBase.loop();
688
689   EXPECT_EQ(client->getMiss(), 1);
690   EXPECT_EQ(client->getHit(), 0);
691
692   cerr << "SSLClientTest test completed" << endl;
693 }
694
695
696 /**
697  * Test SSL client socket session re-use
698  */
699 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
700   // Start listening on a local port
701   WriteCallbackBase writeCallback;
702   ReadCallback readCallback(&writeCallback);
703   HandshakeCallback handshakeCallback(&readCallback);
704   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
705   TestSSLServer server(&acceptCallback);
706
707   // Set up SSL client
708   EventBase eventBase;
709   auto client =
710       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
711
712   client->connect();
713   EventBaseAborter eba(&eventBase, 3000);
714   eventBase.loop();
715
716   EXPECT_EQ(client->getMiss(), 1);
717   EXPECT_EQ(client->getHit(), 9);
718
719   cerr << "SSLClientTestReuse test completed" << endl;
720 }
721
722 /**
723  * Test SSL client socket timeout
724  */
725 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
726   // Start listening on a local port
727   EmptyReadCallback readCallback;
728   HandshakeCallback handshakeCallback(&readCallback,
729                                       HandshakeCallback::EXPECT_ERROR);
730   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
731   TestSSLServer server(&acceptCallback);
732
733   // Set up SSL client
734   EventBase eventBase;
735   auto client =
736       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
737   client->connect(true /* write before connect completes */);
738   EventBaseAborter eba(&eventBase, 3000);
739   eventBase.loop();
740
741   usleep(100000);
742   // This is checking that the connectError callback precedes any queued
743   // writeError callbacks.  This matches AsyncSocket's behavior
744   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
745   EXPECT_EQ(client->getErrors(), 1);
746   EXPECT_EQ(client->getMiss(), 0);
747   EXPECT_EQ(client->getHit(), 0);
748
749   cerr << "SSLClientTimeoutTest test completed" << endl;
750 }
751
752
753 /**
754  * Test SSL server async cache
755  */
756 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
757   // Start listening on a local port
758   WriteCallbackBase writeCallback;
759   ReadCallback readCallback(&writeCallback);
760   HandshakeCallback handshakeCallback(&readCallback);
761   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
762   TestSSLAsyncCacheServer server(&acceptCallback);
763
764   // Set up SSL client
765   EventBase eventBase;
766   auto client =
767       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
768
769   client->connect();
770   EventBaseAborter eba(&eventBase, 3000);
771   eventBase.loop();
772
773   EXPECT_EQ(server.getAsyncCallbacks(), 18);
774   EXPECT_EQ(server.getAsyncLookups(), 9);
775   EXPECT_EQ(client->getMiss(), 10);
776   EXPECT_EQ(client->getHit(), 0);
777
778   cerr << "SSLServerAsyncCacheTest test completed" << endl;
779 }
780
781
782 /**
783  * Test SSL server accept timeout with cache path
784  */
785 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
786   // Start listening on a local port
787   WriteCallbackBase writeCallback;
788   ReadCallback readCallback(&writeCallback);
789   EmptyReadCallback clientReadCallback;
790   HandshakeCallback handshakeCallback(&readCallback);
791   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
792   TestSSLAsyncCacheServer server(&acceptCallback);
793
794   // Set up SSL client
795   EventBase eventBase;
796   // only do a TCP connect
797   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
798   sock->connect(nullptr, server.getAddress());
799   clientReadCallback.tcpSocket_ = sock;
800   sock->setReadCB(&clientReadCallback);
801
802   EventBaseAborter eba(&eventBase, 3000);
803   eventBase.loop();
804
805   EXPECT_EQ(readCallback.state, STATE_WAITING);
806
807   cerr << "SSLServerTimeoutTest test completed" << endl;
808 }
809
810 /**
811  * Test SSL server accept timeout with cache path
812  */
813 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
814   // Start listening on a local port
815   WriteCallbackBase writeCallback;
816   ReadCallback readCallback(&writeCallback);
817   HandshakeCallback handshakeCallback(&readCallback);
818   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
819   TestSSLAsyncCacheServer server(&acceptCallback);
820
821   // Set up SSL client
822   EventBase eventBase;
823   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
824
825   client->connect();
826   EventBaseAborter eba(&eventBase, 3000);
827   eventBase.loop();
828
829   EXPECT_EQ(server.getAsyncCallbacks(), 1);
830   EXPECT_EQ(server.getAsyncLookups(), 1);
831   EXPECT_EQ(client->getErrors(), 1);
832   EXPECT_EQ(client->getMiss(), 1);
833   EXPECT_EQ(client->getHit(), 0);
834
835   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
836 }
837
838 /**
839  * Test SSL server accept timeout with cache path
840  */
841 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
842   // Start listening on a local port
843   WriteCallbackBase writeCallback;
844   ReadCallback readCallback(&writeCallback);
845   HandshakeCallback handshakeCallback(&readCallback,
846                                       HandshakeCallback::EXPECT_ERROR);
847   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
848   TestSSLAsyncCacheServer server(&acceptCallback, 500);
849
850   // Set up SSL client
851   EventBase eventBase;
852   auto client =
853       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
854
855   client->connect();
856   EventBaseAborter eba(&eventBase, 3000);
857   eventBase.loop();
858
859   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
860       handshakeCallback.closeSocket();});
861   // give time for the cache lookup to come back and find it closed
862   usleep(500000);
863
864   EXPECT_EQ(server.getAsyncCallbacks(), 1);
865   EXPECT_EQ(server.getAsyncLookups(), 1);
866   EXPECT_EQ(client->getErrors(), 1);
867   EXPECT_EQ(client->getMiss(), 1);
868   EXPECT_EQ(client->getHit(), 0);
869
870   cerr << "SSLServerCacheCloseTest test completed" << endl;
871 }
872
873 /**
874  * Verify Client Ciphers obtained using SSL MSG Callback.
875  */
876 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
877   EventBase eventBase;
878   auto clientCtx = std::make_shared<SSLContext>();
879   auto serverCtx = std::make_shared<SSLContext>();
880   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
881   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
882   serverCtx->loadPrivateKey(testKey);
883   serverCtx->loadCertificate(testCert);
884   serverCtx->loadTrustedCertificates(testCA);
885   serverCtx->loadClientCAList(testCA);
886
887   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
888   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
889   clientCtx->loadPrivateKey(testKey);
890   clientCtx->loadCertificate(testCert);
891   clientCtx->loadTrustedCertificates(testCA);
892
893   int fds[2];
894   getfds(fds);
895
896   AsyncSSLSocket::UniquePtr clientSock(
897       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
898   AsyncSSLSocket::UniquePtr serverSock(
899       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
900
901   SSLHandshakeClient client(std::move(clientSock), true, true);
902   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
903
904   eventBase.loop();
905
906   EXPECT_EQ(server.clientCiphers_,
907             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
908   EXPECT_TRUE(client.handshakeVerify_);
909   EXPECT_TRUE(client.handshakeSuccess_);
910   EXPECT_TRUE(!client.handshakeError_);
911   EXPECT_TRUE(server.handshakeVerify_);
912   EXPECT_TRUE(server.handshakeSuccess_);
913   EXPECT_TRUE(!server.handshakeError_);
914 }
915
916 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
917   EventBase eventBase;
918   auto ctx = std::make_shared<SSLContext>();
919
920   int fds[2];
921   getfds(fds);
922
923   int bufLen = 42;
924   uint8_t majorVersion = 18;
925   uint8_t minorVersion = 25;
926
927   // Create callback buf
928   auto buf = IOBuf::create(bufLen);
929   buf->append(bufLen);
930   folly::io::RWPrivateCursor cursor(buf.get());
931   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
932   cursor.write<uint16_t>(0);
933   cursor.write<uint8_t>(38);
934   cursor.write<uint8_t>(majorVersion);
935   cursor.write<uint8_t>(minorVersion);
936   cursor.skip(32);
937   cursor.write<uint32_t>(0);
938
939   SSL* ssl = ctx->createSSL();
940   SCOPE_EXIT { SSL_free(ssl); };
941   AsyncSSLSocket::UniquePtr sock(
942       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
943   sock->enableClientHelloParsing();
944
945   // Test client hello parsing in one packet
946   AsyncSSLSocket::clientHelloParsingCallback(
947       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
948   buf.reset();
949
950   auto parsedClientHello = sock->getClientHelloInfo();
951   EXPECT_TRUE(parsedClientHello != nullptr);
952   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
953   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
954 }
955
956 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
957   EventBase eventBase;
958   auto ctx = std::make_shared<SSLContext>();
959
960   int fds[2];
961   getfds(fds);
962
963   int bufLen = 42;
964   uint8_t majorVersion = 18;
965   uint8_t minorVersion = 25;
966
967   // Create callback buf
968   auto buf = IOBuf::create(bufLen);
969   buf->append(bufLen);
970   folly::io::RWPrivateCursor cursor(buf.get());
971   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
972   cursor.write<uint16_t>(0);
973   cursor.write<uint8_t>(38);
974   cursor.write<uint8_t>(majorVersion);
975   cursor.write<uint8_t>(minorVersion);
976   cursor.skip(32);
977   cursor.write<uint32_t>(0);
978
979   SSL* ssl = ctx->createSSL();
980   SCOPE_EXIT { SSL_free(ssl); };
981   AsyncSSLSocket::UniquePtr sock(
982       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
983   sock->enableClientHelloParsing();
984
985   // Test parsing with two packets with first packet size < 3
986   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
987   AsyncSSLSocket::clientHelloParsingCallback(
988       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
989       ssl, sock.get());
990   bufCopy.reset();
991   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
992   AsyncSSLSocket::clientHelloParsingCallback(
993       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
994       ssl, sock.get());
995   bufCopy.reset();
996
997   auto parsedClientHello = sock->getClientHelloInfo();
998   EXPECT_TRUE(parsedClientHello != nullptr);
999   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1000   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1001 }
1002
1003 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1004   EventBase eventBase;
1005   auto ctx = std::make_shared<SSLContext>();
1006
1007   int fds[2];
1008   getfds(fds);
1009
1010   int bufLen = 42;
1011   uint8_t majorVersion = 18;
1012   uint8_t minorVersion = 25;
1013
1014   // Create callback buf
1015   auto buf = IOBuf::create(bufLen);
1016   buf->append(bufLen);
1017   folly::io::RWPrivateCursor cursor(buf.get());
1018   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1019   cursor.write<uint16_t>(0);
1020   cursor.write<uint8_t>(38);
1021   cursor.write<uint8_t>(majorVersion);
1022   cursor.write<uint8_t>(minorVersion);
1023   cursor.skip(32);
1024   cursor.write<uint32_t>(0);
1025
1026   SSL* ssl = ctx->createSSL();
1027   SCOPE_EXIT { SSL_free(ssl); };
1028   AsyncSSLSocket::UniquePtr sock(
1029       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1030   sock->enableClientHelloParsing();
1031
1032   // Test parsing with multiple small packets
1033   for (uint64_t i = 0; i < buf->length(); i += 3) {
1034     auto bufCopy = folly::IOBuf::copyBuffer(
1035         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1036     AsyncSSLSocket::clientHelloParsingCallback(
1037         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1038         ssl, sock.get());
1039     bufCopy.reset();
1040   }
1041
1042   auto parsedClientHello = sock->getClientHelloInfo();
1043   EXPECT_TRUE(parsedClientHello != nullptr);
1044   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1045   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1046 }
1047
1048 /**
1049  * Verify sucessful behavior of SSL certificate validation.
1050  */
1051 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1052   EventBase eventBase;
1053   auto clientCtx = std::make_shared<SSLContext>();
1054   auto dfServerCtx = std::make_shared<SSLContext>();
1055
1056   int fds[2];
1057   getfds(fds);
1058   getctx(clientCtx, dfServerCtx);
1059
1060   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1061   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1062
1063   AsyncSSLSocket::UniquePtr clientSock(
1064     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1065   AsyncSSLSocket::UniquePtr serverSock(
1066     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1067
1068   SSLHandshakeClient client(std::move(clientSock), true, true);
1069   clientCtx->loadTrustedCertificates(testCA);
1070
1071   SSLHandshakeServer server(std::move(serverSock), true, true);
1072
1073   eventBase.loop();
1074
1075   EXPECT_TRUE(client.handshakeVerify_);
1076   EXPECT_TRUE(client.handshakeSuccess_);
1077   EXPECT_TRUE(!client.handshakeError_);
1078   EXPECT_LE(0, client.handshakeTime.count());
1079   EXPECT_TRUE(!server.handshakeVerify_);
1080   EXPECT_TRUE(server.handshakeSuccess_);
1081   EXPECT_TRUE(!server.handshakeError_);
1082   EXPECT_LE(0, server.handshakeTime.count());
1083 }
1084
1085 /**
1086  * Verify that the client's verification callback is able to fail SSL
1087  * connection establishment.
1088  */
1089 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1090   EventBase eventBase;
1091   auto clientCtx = std::make_shared<SSLContext>();
1092   auto dfServerCtx = std::make_shared<SSLContext>();
1093
1094   int fds[2];
1095   getfds(fds);
1096   getctx(clientCtx, dfServerCtx);
1097
1098   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1099   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1100
1101   AsyncSSLSocket::UniquePtr clientSock(
1102     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1103   AsyncSSLSocket::UniquePtr serverSock(
1104     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1105
1106   SSLHandshakeClient client(std::move(clientSock), true, false);
1107   clientCtx->loadTrustedCertificates(testCA);
1108
1109   SSLHandshakeServer server(std::move(serverSock), true, true);
1110
1111   eventBase.loop();
1112
1113   EXPECT_TRUE(client.handshakeVerify_);
1114   EXPECT_TRUE(!client.handshakeSuccess_);
1115   EXPECT_TRUE(client.handshakeError_);
1116   EXPECT_LE(0, client.handshakeTime.count());
1117   EXPECT_TRUE(!server.handshakeVerify_);
1118   EXPECT_TRUE(!server.handshakeSuccess_);
1119   EXPECT_TRUE(server.handshakeError_);
1120   EXPECT_LE(0, server.handshakeTime.count());
1121 }
1122
1123 /**
1124  * Verify that the options in SSLContext can be overridden in
1125  * sslConnect/Accept.i.e specifying that no validation should be performed
1126  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1127  * the validation callback.
1128  */
1129 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1130   EventBase eventBase;
1131   auto clientCtx = std::make_shared<SSLContext>();
1132   auto dfServerCtx = std::make_shared<SSLContext>();
1133
1134   int fds[2];
1135   getfds(fds);
1136   getctx(clientCtx, dfServerCtx);
1137
1138   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1139   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1140
1141   AsyncSSLSocket::UniquePtr clientSock(
1142     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1143   AsyncSSLSocket::UniquePtr serverSock(
1144     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1145
1146   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1147   clientCtx->loadTrustedCertificates(testCA);
1148
1149   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1150
1151   eventBase.loop();
1152
1153   EXPECT_TRUE(!client.handshakeVerify_);
1154   EXPECT_TRUE(client.handshakeSuccess_);
1155   EXPECT_TRUE(!client.handshakeError_);
1156   EXPECT_LE(0, client.handshakeTime.count());
1157   EXPECT_TRUE(!server.handshakeVerify_);
1158   EXPECT_TRUE(server.handshakeSuccess_);
1159   EXPECT_TRUE(!server.handshakeError_);
1160   EXPECT_LE(0, server.handshakeTime.count());
1161 }
1162
1163 /**
1164  * Verify that the options in SSLContext can be overridden in
1165  * sslConnect/Accept. Enable verification even if context says otherwise.
1166  * Test requireClientCert with client cert
1167  */
1168 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1169   EventBase eventBase;
1170   auto clientCtx = std::make_shared<SSLContext>();
1171   auto serverCtx = std::make_shared<SSLContext>();
1172   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1173   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1174   serverCtx->loadPrivateKey(testKey);
1175   serverCtx->loadCertificate(testCert);
1176   serverCtx->loadTrustedCertificates(testCA);
1177   serverCtx->loadClientCAList(testCA);
1178
1179   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1180   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1181   clientCtx->loadPrivateKey(testKey);
1182   clientCtx->loadCertificate(testCert);
1183   clientCtx->loadTrustedCertificates(testCA);
1184
1185   int fds[2];
1186   getfds(fds);
1187
1188   AsyncSSLSocket::UniquePtr clientSock(
1189       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1190   AsyncSSLSocket::UniquePtr serverSock(
1191       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1192
1193   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1194   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1195
1196   eventBase.loop();
1197
1198   EXPECT_TRUE(client.handshakeVerify_);
1199   EXPECT_TRUE(client.handshakeSuccess_);
1200   EXPECT_FALSE(client.handshakeError_);
1201   EXPECT_LE(0, client.handshakeTime.count());
1202   EXPECT_TRUE(server.handshakeVerify_);
1203   EXPECT_TRUE(server.handshakeSuccess_);
1204   EXPECT_FALSE(server.handshakeError_);
1205   EXPECT_LE(0, server.handshakeTime.count());
1206 }
1207
1208 /**
1209  * Verify that the client's verification callback is able to override
1210  * the preverification failure and allow a successful connection.
1211  */
1212 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1213   EventBase eventBase;
1214   auto clientCtx = std::make_shared<SSLContext>();
1215   auto dfServerCtx = std::make_shared<SSLContext>();
1216
1217   int fds[2];
1218   getfds(fds);
1219   getctx(clientCtx, dfServerCtx);
1220
1221   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1222   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1223
1224   AsyncSSLSocket::UniquePtr clientSock(
1225     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1226   AsyncSSLSocket::UniquePtr serverSock(
1227     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1228
1229   SSLHandshakeClient client(std::move(clientSock), false, true);
1230   SSLHandshakeServer server(std::move(serverSock), true, true);
1231
1232   eventBase.loop();
1233
1234   EXPECT_TRUE(client.handshakeVerify_);
1235   EXPECT_TRUE(client.handshakeSuccess_);
1236   EXPECT_TRUE(!client.handshakeError_);
1237   EXPECT_LE(0, client.handshakeTime.count());
1238   EXPECT_TRUE(!server.handshakeVerify_);
1239   EXPECT_TRUE(server.handshakeSuccess_);
1240   EXPECT_TRUE(!server.handshakeError_);
1241   EXPECT_LE(0, server.handshakeTime.count());
1242 }
1243
1244 /**
1245  * Verify that specifying that no validation should be performed allows an
1246  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1247  * callback.
1248  */
1249 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1250   EventBase eventBase;
1251   auto clientCtx = std::make_shared<SSLContext>();
1252   auto dfServerCtx = std::make_shared<SSLContext>();
1253
1254   int fds[2];
1255   getfds(fds);
1256   getctx(clientCtx, dfServerCtx);
1257
1258   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1259   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1260
1261   AsyncSSLSocket::UniquePtr clientSock(
1262     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1263   AsyncSSLSocket::UniquePtr serverSock(
1264     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1265
1266   SSLHandshakeClient client(std::move(clientSock), false, false);
1267   SSLHandshakeServer server(std::move(serverSock), false, false);
1268
1269   eventBase.loop();
1270
1271   EXPECT_TRUE(!client.handshakeVerify_);
1272   EXPECT_TRUE(client.handshakeSuccess_);
1273   EXPECT_TRUE(!client.handshakeError_);
1274   EXPECT_LE(0, client.handshakeTime.count());
1275   EXPECT_TRUE(!server.handshakeVerify_);
1276   EXPECT_TRUE(server.handshakeSuccess_);
1277   EXPECT_TRUE(!server.handshakeError_);
1278   EXPECT_LE(0, server.handshakeTime.count());
1279 }
1280
1281 /**
1282  * Test requireClientCert with client cert
1283  */
1284 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1285   EventBase eventBase;
1286   auto clientCtx = std::make_shared<SSLContext>();
1287   auto serverCtx = std::make_shared<SSLContext>();
1288   serverCtx->setVerificationOption(
1289       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1290   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1291   serverCtx->loadPrivateKey(testKey);
1292   serverCtx->loadCertificate(testCert);
1293   serverCtx->loadTrustedCertificates(testCA);
1294   serverCtx->loadClientCAList(testCA);
1295
1296   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1297   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1298   clientCtx->loadPrivateKey(testKey);
1299   clientCtx->loadCertificate(testCert);
1300   clientCtx->loadTrustedCertificates(testCA);
1301
1302   int fds[2];
1303   getfds(fds);
1304
1305   AsyncSSLSocket::UniquePtr clientSock(
1306       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1307   AsyncSSLSocket::UniquePtr serverSock(
1308       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1309
1310   SSLHandshakeClient client(std::move(clientSock), true, true);
1311   SSLHandshakeServer server(std::move(serverSock), true, true);
1312
1313   eventBase.loop();
1314
1315   EXPECT_TRUE(client.handshakeVerify_);
1316   EXPECT_TRUE(client.handshakeSuccess_);
1317   EXPECT_FALSE(client.handshakeError_);
1318   EXPECT_LE(0, client.handshakeTime.count());
1319   EXPECT_TRUE(server.handshakeVerify_);
1320   EXPECT_TRUE(server.handshakeSuccess_);
1321   EXPECT_FALSE(server.handshakeError_);
1322   EXPECT_LE(0, server.handshakeTime.count());
1323 }
1324
1325
1326 /**
1327  * Test requireClientCert with no client cert
1328  */
1329 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1330   EventBase eventBase;
1331   auto clientCtx = std::make_shared<SSLContext>();
1332   auto serverCtx = std::make_shared<SSLContext>();
1333   serverCtx->setVerificationOption(
1334       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1335   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1336   serverCtx->loadPrivateKey(testKey);
1337   serverCtx->loadCertificate(testCert);
1338   serverCtx->loadTrustedCertificates(testCA);
1339   serverCtx->loadClientCAList(testCA);
1340   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1341   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1342
1343   int fds[2];
1344   getfds(fds);
1345
1346   AsyncSSLSocket::UniquePtr clientSock(
1347       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1348   AsyncSSLSocket::UniquePtr serverSock(
1349       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1350
1351   SSLHandshakeClient client(std::move(clientSock), false, false);
1352   SSLHandshakeServer server(std::move(serverSock), false, false);
1353
1354   eventBase.loop();
1355
1356   EXPECT_FALSE(server.handshakeVerify_);
1357   EXPECT_FALSE(server.handshakeSuccess_);
1358   EXPECT_TRUE(server.handshakeError_);
1359   EXPECT_LE(0, client.handshakeTime.count());
1360   EXPECT_LE(0, server.handshakeTime.count());
1361 }
1362
1363 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1364   EventBase eb;
1365
1366   // Set up SSL context.
1367   auto sslContext = std::make_shared<SSLContext>();
1368   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1369
1370   // create SSL socket
1371   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1372
1373   EXPECT_EQ(1500, socket->getMinWriteSize());
1374
1375   socket->setMinWriteSize(0);
1376   EXPECT_EQ(0, socket->getMinWriteSize());
1377   socket->setMinWriteSize(50000);
1378   EXPECT_EQ(50000, socket->getMinWriteSize());
1379 }
1380
1381 class ReadCallbackTerminator : public ReadCallback {
1382  public:
1383   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1384       : ReadCallback(wcb)
1385       , base_(base) {}
1386
1387   // Do not write data back, terminate the loop.
1388   void readDataAvailable(size_t len) noexcept override {
1389     std::cerr << "readDataAvailable, len " << len << std::endl;
1390
1391     currentBuffer.length = len;
1392
1393     buffers.push_back(currentBuffer);
1394     currentBuffer.reset();
1395     state = STATE_SUCCEEDED;
1396
1397     socket_->setReadCB(nullptr);
1398     base_->terminateLoopSoon();
1399   }
1400  private:
1401   EventBase* base_;
1402 };
1403
1404
1405 /**
1406  * Test a full unencrypted codepath
1407  */
1408 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1409   EventBase base;
1410
1411   auto clientCtx = std::make_shared<folly::SSLContext>();
1412   auto serverCtx = std::make_shared<folly::SSLContext>();
1413   int fds[2];
1414   getfds(fds);
1415   getctx(clientCtx, serverCtx);
1416   auto client = AsyncSSLSocket::newSocket(
1417                   clientCtx, &base, fds[0], false, true);
1418   auto server = AsyncSSLSocket::newSocket(
1419                   serverCtx, &base, fds[1], true, true);
1420
1421   ReadCallbackTerminator readCallback(&base, nullptr);
1422   server->setReadCB(&readCallback);
1423   readCallback.setSocket(server);
1424
1425   uint8_t buf[128];
1426   memset(buf, 'a', sizeof(buf));
1427   client->write(nullptr, buf, sizeof(buf));
1428
1429   // Check that bytes are unencrypted
1430   char c;
1431   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1432   EXPECT_EQ('a', c);
1433
1434   EventBaseAborter eba(&base, 3000);
1435   base.loop();
1436
1437   EXPECT_EQ(1, readCallback.buffers.size());
1438   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1439
1440   server->setReadCB(&readCallback);
1441
1442   // Unencrypted
1443   server->sslAccept(nullptr);
1444   client->sslConn(nullptr);
1445
1446   // Do NOT wait for handshake, writing should be queued and happen after
1447
1448   client->write(nullptr, buf, sizeof(buf));
1449
1450   // Check that bytes are *not* unencrypted
1451   char c2;
1452   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1453   EXPECT_NE('a', c2);
1454
1455
1456   base.loop();
1457
1458   EXPECT_EQ(2, readCallback.buffers.size());
1459   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1460 }
1461
1462 } // namespace
1463
1464 ///////////////////////////////////////////////////////////////////////////
1465 // init_unit_test_suite
1466 ///////////////////////////////////////////////////////////////////////////
1467 namespace {
1468 struct Initializer {
1469   Initializer() {
1470     signal(SIGPIPE, SIG_IGN);
1471   }
1472 };
1473 Initializer initializer;
1474 } // anonymous