5bb0a82c1a804d63df9c10f97ac67d6ea47820e4
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <fstream>
28 #include <gtest/gtest.h>
29 #include <iostream>
30 #include <list>
31 #include <set>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <openssl/bio.h>
35 #include <poll.h>
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <netinet/tcp.h>
39 #include <folly/io/Cursor.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 namespace folly {
49 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
50 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
51 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
52
53 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
54 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
55 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
56
57 constexpr size_t SSLClient::kMaxReadBufferSz;
58 constexpr size_t SSLClient::kMaxReadsPerEvent;
59
60 inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
61 using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
62
63 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
64     : ctx_(new folly::SSLContext),
65       acb_(acb),
66       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
67   // Set up the SSL context
68   ctx_->loadCertificate(testCert);
69   ctx_->loadPrivateKey(testKey);
70   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
71
72   acb_->ctx_ = ctx_;
73   acb_->base_ = &evb_;
74
75   //set up the listening socket
76   socket_->bind(0);
77   socket_->getAddress(&address_);
78   socket_->listen(100);
79   socket_->addAcceptCallback(acb_, &evb_);
80   socket_->startAccepting();
81
82   int ret = pthread_create(&thread_, nullptr, Main, this);
83   assert(ret == 0);
84   (void)ret;
85
86   std::cerr << "Accepting connections on " << address_ << std::endl;
87 }
88
89 void getfds(int fds[2]) {
90   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
91     FAIL() << "failed to create socketpair: " << strerror(errno);
92   }
93   for (int idx = 0; idx < 2; ++idx) {
94     int flags = fcntl(fds[idx], F_GETFL, 0);
95     if (flags == -1) {
96       FAIL() << "failed to get flags for socket " << idx << ": "
97              << strerror(errno);
98     }
99     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
100       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
101              << strerror(errno);
102     }
103   }
104 }
105
106 void getctx(
107   std::shared_ptr<folly::SSLContext> clientCtx,
108   std::shared_ptr<folly::SSLContext> serverCtx) {
109   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
110
111   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
112   serverCtx->loadCertificate(
113       testCert);
114   serverCtx->loadPrivateKey(
115       testKey);
116 }
117
118 void sslsocketpair(
119   EventBase* eventBase,
120   AsyncSSLSocket::UniquePtr* clientSock,
121   AsyncSSLSocket::UniquePtr* serverSock) {
122   auto clientCtx = std::make_shared<folly::SSLContext>();
123   auto serverCtx = std::make_shared<folly::SSLContext>();
124   int fds[2];
125   getfds(fds);
126   getctx(clientCtx, serverCtx);
127   clientSock->reset(new AsyncSSLSocket(
128                       clientCtx, eventBase, fds[0], false));
129   serverSock->reset(new AsyncSSLSocket(
130                       serverCtx, eventBase, fds[1], true));
131
132   // (*clientSock)->setSendTimeout(100);
133   // (*serverSock)->setSendTimeout(100);
134 }
135
136 // client protocol filters
137 bool clientProtoFilterPickPony(unsigned char** client,
138   unsigned int* client_len, const unsigned char*, unsigned int ) {
139   //the protocol string in length prefixed byte string. the
140   //length byte is not included in the length
141   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
142   *client = p;
143   *client_len = 7;
144   return true;
145 }
146
147 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
148   const unsigned char*, unsigned int) {
149   return false;
150 }
151
152 std::string getFileAsBuf(const char* fileName) {
153   std::string buffer;
154   folly::readFile(fileName, buffer);
155   return buffer;
156 }
157
158 std::string getCommonName(X509* cert) {
159   X509_NAME* subject = X509_get_subject_name(cert);
160   std::string cn;
161   cn.resize(ub_common_name);
162   X509_NAME_get_text_by_NID(
163       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
164   return cn;
165 }
166
167 /**
168  * Test connecting to, writing to, reading from, and closing the
169  * connection to the SSL server.
170  */
171 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
172   // Start listening on a local port
173   WriteCallbackBase writeCallback;
174   ReadCallback readCallback(&writeCallback);
175   HandshakeCallback handshakeCallback(&readCallback);
176   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
177   TestSSLServer server(&acceptCallback);
178
179   // Set up SSL context.
180   std::shared_ptr<SSLContext> sslContext(new SSLContext());
181   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
182   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
183   //sslContext->authenticate(true, false);
184
185   // connect
186   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
187                                                  sslContext);
188   socket->open();
189
190   // write()
191   uint8_t buf[128];
192   memset(buf, 'a', sizeof(buf));
193   socket->write(buf, sizeof(buf));
194
195   // read()
196   uint8_t readbuf[128];
197   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
198   EXPECT_EQ(bytesRead, 128);
199   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
200
201   // close()
202   socket->close();
203
204   cerr << "ConnectWriteReadClose test completed" << endl;
205 }
206
207 /**
208  * Negative test for handshakeError().
209  */
210 TEST(AsyncSSLSocketTest, HandshakeError) {
211   // Start listening on a local port
212   WriteCallbackBase writeCallback;
213   ReadCallback readCallback(&writeCallback);
214   HandshakeCallback handshakeCallback(&readCallback);
215   HandshakeErrorCallback acceptCallback(&handshakeCallback);
216   TestSSLServer server(&acceptCallback);
217
218   // Set up SSL context.
219   std::shared_ptr<SSLContext> sslContext(new SSLContext());
220   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
221
222   // connect
223   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
224                                                  sslContext);
225   // read()
226   bool ex = false;
227   try {
228     socket->open();
229
230     uint8_t readbuf[128];
231     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
232     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
233   } catch (AsyncSocketException &e) {
234     ex = true;
235   }
236   EXPECT_TRUE(ex);
237
238   // close()
239   socket->close();
240   cerr << "HandshakeError test completed" << endl;
241 }
242
243 /**
244  * Negative test for readError().
245  */
246 TEST(AsyncSSLSocketTest, ReadError) {
247   // Start listening on a local port
248   WriteCallbackBase writeCallback;
249   ReadErrorCallback readCallback(&writeCallback);
250   HandshakeCallback handshakeCallback(&readCallback);
251   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
252   TestSSLServer server(&acceptCallback);
253
254   // Set up SSL context.
255   std::shared_ptr<SSLContext> sslContext(new SSLContext());
256   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
257
258   // connect
259   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
260                                                  sslContext);
261   socket->open();
262
263   // write something to trigger ssl handshake
264   uint8_t buf[128];
265   memset(buf, 'a', sizeof(buf));
266   socket->write(buf, sizeof(buf));
267
268   socket->close();
269   cerr << "ReadError test completed" << endl;
270 }
271
272 /**
273  * Negative test for writeError().
274  */
275 TEST(AsyncSSLSocketTest, WriteError) {
276   // Start listening on a local port
277   WriteCallbackBase writeCallback;
278   WriteErrorCallback readCallback(&writeCallback);
279   HandshakeCallback handshakeCallback(&readCallback);
280   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
281   TestSSLServer server(&acceptCallback);
282
283   // Set up SSL context.
284   std::shared_ptr<SSLContext> sslContext(new SSLContext());
285   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
286
287   // connect
288   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
289                                                  sslContext);
290   socket->open();
291
292   // write something to trigger ssl handshake
293   uint8_t buf[128];
294   memset(buf, 'a', sizeof(buf));
295   socket->write(buf, sizeof(buf));
296
297   socket->close();
298   cerr << "WriteError test completed" << endl;
299 }
300
301 /**
302  * Test a socket with TCP_NODELAY unset.
303  */
304 TEST(AsyncSSLSocketTest, SocketWithDelay) {
305   // Start listening on a local port
306   WriteCallbackBase writeCallback;
307   ReadCallback readCallback(&writeCallback);
308   HandshakeCallback handshakeCallback(&readCallback);
309   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
310   TestSSLServer server(&acceptCallback);
311
312   // Set up SSL context.
313   std::shared_ptr<SSLContext> sslContext(new SSLContext());
314   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
315
316   // connect
317   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
318                                                  sslContext);
319   socket->open();
320
321   // write()
322   uint8_t buf[128];
323   memset(buf, 'a', sizeof(buf));
324   socket->write(buf, sizeof(buf));
325
326   // read()
327   uint8_t readbuf[128];
328   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
329   EXPECT_EQ(bytesRead, 128);
330   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
331
332   // close()
333   socket->close();
334
335   cerr << "SocketWithDelay test completed" << endl;
336 }
337
338 using NextProtocolTypePair =
339     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
340
341 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
342   // For matching protos
343  public:
344   void SetUp() override { getctx(clientCtx, serverCtx); }
345
346   void connect(bool unset = false) {
347     getfds(fds);
348
349     if (unset) {
350       // unsetting NPN for any of [client, server] is enough to make NPN not
351       // work
352       clientCtx->unsetNextProtocols();
353     }
354
355     AsyncSSLSocket::UniquePtr clientSock(
356       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
357     AsyncSSLSocket::UniquePtr serverSock(
358       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
359     client = folly::make_unique<NpnClient>(std::move(clientSock));
360     server = folly::make_unique<NpnServer>(std::move(serverSock));
361
362     eventBase.loop();
363   }
364
365   void expectProtocol(const std::string& proto) {
366     EXPECT_NE(client->nextProtoLength, 0);
367     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
368     EXPECT_EQ(
369         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
370         0);
371     string selected((const char*)client->nextProto, client->nextProtoLength);
372     EXPECT_EQ(proto, selected);
373   }
374
375   void expectNoProtocol() {
376     EXPECT_EQ(client->nextProtoLength, 0);
377     EXPECT_EQ(server->nextProtoLength, 0);
378     EXPECT_EQ(client->nextProto, nullptr);
379     EXPECT_EQ(server->nextProto, nullptr);
380   }
381
382   void expectProtocolType() {
383     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
384         GetParam().second == SSLContext::NextProtocolType::ANY) {
385       EXPECT_EQ(client->protocolType, server->protocolType);
386     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
387                GetParam().second == SSLContext::NextProtocolType::ANY) {
388       // Well not much we can say
389     } else {
390       expectProtocolType(GetParam());
391     }
392   }
393
394   void expectProtocolType(NextProtocolTypePair expected) {
395     EXPECT_EQ(client->protocolType, expected.first);
396     EXPECT_EQ(server->protocolType, expected.second);
397   }
398
399   EventBase eventBase;
400   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
401   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
402   int fds[2];
403   std::unique_ptr<NpnClient> client;
404   std::unique_ptr<NpnServer> server;
405 };
406
407 class NextProtocolNPNOnlyTest : public NextProtocolTest {
408   // For mismatching protos
409 };
410
411 class NextProtocolMismatchTest : public NextProtocolTest {
412   // For mismatching protos
413 };
414
415 TEST_P(NextProtocolTest, NpnTestOverlap) {
416   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
417   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
418                                         GetParam().second);
419
420   connect();
421
422   expectProtocol("baz");
423   expectProtocolType();
424 }
425
426 TEST_P(NextProtocolTest, NpnTestUnset) {
427   // Identical to above test, except that we want unset NPN before
428   // looping.
429   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
430   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
431                                         GetParam().second);
432
433   connect(true /* unset */);
434
435   // if alpn negotiation fails, type will appear as npn
436   expectNoProtocol();
437   EXPECT_EQ(client->protocolType, server->protocolType);
438 }
439
440 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
441   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
442   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
443                                         GetParam().second);
444
445   connect();
446
447   expectNoProtocol();
448   expectProtocolType(
449       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
450 }
451
452 TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
453   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
454   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
455                                         GetParam().second);
456
457   connect();
458
459   expectProtocol("blub");
460   expectProtocolType();
461 }
462
463 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
464   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
465   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
466   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
467                                         GetParam().second);
468
469   connect();
470
471   expectProtocol("ponies");
472   expectProtocolType();
473 }
474
475 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
476   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
477   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
478   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
479                                         GetParam().second);
480
481   connect();
482
483   expectProtocol("blub");
484   expectProtocolType();
485 }
486
487 TEST_P(NextProtocolTest, RandomizedNpnTest) {
488   // Probability that this test will fail is 2^-64, which could be considered
489   // as negligible.
490   const int kTries = 64;
491
492   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
493                                         GetParam().first);
494   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
495                                                   GetParam().second);
496
497   std::set<string> selectedProtocols;
498   for (int i = 0; i < kTries; ++i) {
499     connect();
500
501     EXPECT_NE(client->nextProtoLength, 0);
502     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
503     EXPECT_EQ(
504         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
505         0);
506     string selected((const char*)client->nextProto, client->nextProtoLength);
507     selectedProtocols.insert(selected);
508     expectProtocolType();
509   }
510   EXPECT_EQ(selectedProtocols.size(), 2);
511 }
512
513 INSTANTIATE_TEST_CASE_P(
514     AsyncSSLSocketTest,
515     NextProtocolTest,
516     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
517                                            SSLContext::NextProtocolType::NPN),
518 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
519                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
520                                            SSLContext::NextProtocolType::ALPN),
521 #endif
522                       NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
523                                            SSLContext::NextProtocolType::ANY),
524 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
525                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
526                                            SSLContext::NextProtocolType::ANY),
527 #endif
528                       NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
529                                            SSLContext::NextProtocolType::ANY)));
530
531 INSTANTIATE_TEST_CASE_P(
532     AsyncSSLSocketTest,
533     NextProtocolNPNOnlyTest,
534     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
535                                            SSLContext::NextProtocolType::NPN)));
536
537 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
538 INSTANTIATE_TEST_CASE_P(
539     AsyncSSLSocketTest,
540     NextProtocolMismatchTest,
541     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
542                                            SSLContext::NextProtocolType::ALPN),
543                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
544                                            SSLContext::NextProtocolType::NPN)));
545 #endif
546
547 #ifndef OPENSSL_NO_TLSEXT
548 /**
549  * 1. Client sends TLSEXT_HOSTNAME in client hello.
550  * 2. Server found a match SSL_CTX and use this SSL_CTX to
551  *    continue the SSL handshake.
552  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
553  */
554 TEST(AsyncSSLSocketTest, SNITestMatch) {
555   EventBase eventBase;
556   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
557   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
558   // Use the same SSLContext to continue the handshake after
559   // tlsext_hostname match.
560   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
561   const std::string serverName("xyz.newdev.facebook.com");
562   int fds[2];
563   getfds(fds);
564   getctx(clientCtx, dfServerCtx);
565
566   AsyncSSLSocket::UniquePtr clientSock(
567     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
568   AsyncSSLSocket::UniquePtr serverSock(
569     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
570   SNIClient client(std::move(clientSock));
571   SNIServer server(std::move(serverSock),
572                    dfServerCtx,
573                    hskServerCtx,
574                    serverName);
575
576   eventBase.loop();
577
578   EXPECT_TRUE(client.serverNameMatch);
579   EXPECT_TRUE(server.serverNameMatch);
580 }
581
582 /**
583  * 1. Client sends TLSEXT_HOSTNAME in client hello.
584  * 2. Server cannot find a matching SSL_CTX and continue to use
585  *    the current SSL_CTX to do the handshake.
586  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
587  */
588 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
589   EventBase eventBase;
590   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
591   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
592   // Use the same SSLContext to continue the handshake after
593   // tlsext_hostname match.
594   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
595   const std::string clientRequestingServerName("foo.com");
596   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
597
598   int fds[2];
599   getfds(fds);
600   getctx(clientCtx, dfServerCtx);
601
602   AsyncSSLSocket::UniquePtr clientSock(
603     new AsyncSSLSocket(clientCtx,
604                         &eventBase,
605                         fds[0],
606                         clientRequestingServerName));
607   AsyncSSLSocket::UniquePtr serverSock(
608     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
609   SNIClient client(std::move(clientSock));
610   SNIServer server(std::move(serverSock),
611                    dfServerCtx,
612                    hskServerCtx,
613                    serverExpectedServerName);
614
615   eventBase.loop();
616
617   EXPECT_TRUE(!client.serverNameMatch);
618   EXPECT_TRUE(!server.serverNameMatch);
619 }
620 /**
621  * 1. Client sends TLSEXT_HOSTNAME in client hello.
622  * 2. We then change the serverName.
623  * 3. We expect that we get 'false' as the result for serNameMatch.
624  */
625
626 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
627    EventBase eventBase;
628   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
629   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
630   // Use the same SSLContext to continue the handshake after
631   // tlsext_hostname match.
632   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
633   const std::string serverName("xyz.newdev.facebook.com");
634   int fds[2];
635   getfds(fds);
636   getctx(clientCtx, dfServerCtx);
637
638   AsyncSSLSocket::UniquePtr clientSock(
639     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
640   //Change the server name
641   std::string newName("new.com");
642   clientSock->setServerName(newName);
643   AsyncSSLSocket::UniquePtr serverSock(
644     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
645   SNIClient client(std::move(clientSock));
646   SNIServer server(std::move(serverSock),
647                    dfServerCtx,
648                    hskServerCtx,
649                    serverName);
650
651   eventBase.loop();
652
653   EXPECT_TRUE(!client.serverNameMatch);
654 }
655
656 /**
657  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
658  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
659  */
660 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
661   EventBase eventBase;
662   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
663   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
664   // Use the same SSLContext to continue the handshake after
665   // tlsext_hostname match.
666   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
667   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
668
669   int fds[2];
670   getfds(fds);
671   getctx(clientCtx, dfServerCtx);
672
673   AsyncSSLSocket::UniquePtr clientSock(
674     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
675   AsyncSSLSocket::UniquePtr serverSock(
676     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
677   SNIClient client(std::move(clientSock));
678   SNIServer server(std::move(serverSock),
679                    dfServerCtx,
680                    hskServerCtx,
681                    serverExpectedServerName);
682
683   eventBase.loop();
684
685   EXPECT_TRUE(!client.serverNameMatch);
686   EXPECT_TRUE(!server.serverNameMatch);
687 }
688
689 #endif
690 /**
691  * Test SSL client socket
692  */
693 TEST(AsyncSSLSocketTest, SSLClientTest) {
694   // Start listening on a local port
695   WriteCallbackBase writeCallback;
696   ReadCallback readCallback(&writeCallback);
697   HandshakeCallback handshakeCallback(&readCallback);
698   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
699   TestSSLServer server(&acceptCallback);
700
701   // Set up SSL client
702   EventBase eventBase;
703   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
704
705   client->connect();
706   EventBaseAborter eba(&eventBase, 3000);
707   eventBase.loop();
708
709   EXPECT_EQ(client->getMiss(), 1);
710   EXPECT_EQ(client->getHit(), 0);
711
712   cerr << "SSLClientTest test completed" << endl;
713 }
714
715
716 /**
717  * Test SSL client socket session re-use
718  */
719 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
720   // Start listening on a local port
721   WriteCallbackBase writeCallback;
722   ReadCallback readCallback(&writeCallback);
723   HandshakeCallback handshakeCallback(&readCallback);
724   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
725   TestSSLServer server(&acceptCallback);
726
727   // Set up SSL client
728   EventBase eventBase;
729   auto client =
730       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
731
732   client->connect();
733   EventBaseAborter eba(&eventBase, 3000);
734   eventBase.loop();
735
736   EXPECT_EQ(client->getMiss(), 1);
737   EXPECT_EQ(client->getHit(), 9);
738
739   cerr << "SSLClientTestReuse test completed" << endl;
740 }
741
742 /**
743  * Test SSL client socket timeout
744  */
745 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
746   // Start listening on a local port
747   EmptyReadCallback readCallback;
748   HandshakeCallback handshakeCallback(&readCallback,
749                                       HandshakeCallback::EXPECT_ERROR);
750   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
751   TestSSLServer server(&acceptCallback);
752
753   // Set up SSL client
754   EventBase eventBase;
755   auto client =
756       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
757   client->connect(true /* write before connect completes */);
758   EventBaseAborter eba(&eventBase, 3000);
759   eventBase.loop();
760
761   usleep(100000);
762   // This is checking that the connectError callback precedes any queued
763   // writeError callbacks.  This matches AsyncSocket's behavior
764   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
765   EXPECT_EQ(client->getErrors(), 1);
766   EXPECT_EQ(client->getMiss(), 0);
767   EXPECT_EQ(client->getHit(), 0);
768
769   cerr << "SSLClientTimeoutTest test completed" << endl;
770 }
771
772
773 /**
774  * Test SSL server async cache
775  */
776 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
777   // Start listening on a local port
778   WriteCallbackBase writeCallback;
779   ReadCallback readCallback(&writeCallback);
780   HandshakeCallback handshakeCallback(&readCallback);
781   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
782   TestSSLAsyncCacheServer server(&acceptCallback);
783
784   // Set up SSL client
785   EventBase eventBase;
786   auto client =
787       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
788
789   client->connect();
790   EventBaseAborter eba(&eventBase, 3000);
791   eventBase.loop();
792
793   EXPECT_EQ(server.getAsyncCallbacks(), 18);
794   EXPECT_EQ(server.getAsyncLookups(), 9);
795   EXPECT_EQ(client->getMiss(), 10);
796   EXPECT_EQ(client->getHit(), 0);
797
798   cerr << "SSLServerAsyncCacheTest test completed" << endl;
799 }
800
801
802 /**
803  * Test SSL server accept timeout with cache path
804  */
805 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
806   // Start listening on a local port
807   WriteCallbackBase writeCallback;
808   ReadCallback readCallback(&writeCallback);
809   EmptyReadCallback clientReadCallback;
810   HandshakeCallback handshakeCallback(&readCallback);
811   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
812   TestSSLAsyncCacheServer server(&acceptCallback);
813
814   // Set up SSL client
815   EventBase eventBase;
816   // only do a TCP connect
817   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
818   sock->connect(nullptr, server.getAddress());
819   clientReadCallback.tcpSocket_ = sock;
820   sock->setReadCB(&clientReadCallback);
821
822   EventBaseAborter eba(&eventBase, 3000);
823   eventBase.loop();
824
825   EXPECT_EQ(readCallback.state, STATE_WAITING);
826
827   cerr << "SSLServerTimeoutTest test completed" << endl;
828 }
829
830 /**
831  * Test SSL server accept timeout with cache path
832  */
833 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
834   // Start listening on a local port
835   WriteCallbackBase writeCallback;
836   ReadCallback readCallback(&writeCallback);
837   HandshakeCallback handshakeCallback(&readCallback);
838   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
839   TestSSLAsyncCacheServer server(&acceptCallback);
840
841   // Set up SSL client
842   EventBase eventBase;
843   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
844
845   client->connect();
846   EventBaseAborter eba(&eventBase, 3000);
847   eventBase.loop();
848
849   EXPECT_EQ(server.getAsyncCallbacks(), 1);
850   EXPECT_EQ(server.getAsyncLookups(), 1);
851   EXPECT_EQ(client->getErrors(), 1);
852   EXPECT_EQ(client->getMiss(), 1);
853   EXPECT_EQ(client->getHit(), 0);
854
855   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
856 }
857
858 /**
859  * Test SSL server accept timeout with cache path
860  */
861 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
862   // Start listening on a local port
863   WriteCallbackBase writeCallback;
864   ReadCallback readCallback(&writeCallback);
865   HandshakeCallback handshakeCallback(&readCallback,
866                                       HandshakeCallback::EXPECT_ERROR);
867   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
868   TestSSLAsyncCacheServer server(&acceptCallback, 500);
869
870   // Set up SSL client
871   EventBase eventBase;
872   auto client =
873       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
874
875   client->connect();
876   EventBaseAborter eba(&eventBase, 3000);
877   eventBase.loop();
878
879   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
880       handshakeCallback.closeSocket();});
881   // give time for the cache lookup to come back and find it closed
882   usleep(500000);
883
884   EXPECT_EQ(server.getAsyncCallbacks(), 1);
885   EXPECT_EQ(server.getAsyncLookups(), 1);
886   EXPECT_EQ(client->getErrors(), 1);
887   EXPECT_EQ(client->getMiss(), 1);
888   EXPECT_EQ(client->getHit(), 0);
889
890   cerr << "SSLServerCacheCloseTest test completed" << endl;
891 }
892
893 /**
894  * Verify Client Ciphers obtained using SSL MSG Callback.
895  */
896 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
897   EventBase eventBase;
898   auto clientCtx = std::make_shared<SSLContext>();
899   auto serverCtx = std::make_shared<SSLContext>();
900   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
901   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
902   serverCtx->loadPrivateKey(testKey);
903   serverCtx->loadCertificate(testCert);
904   serverCtx->loadTrustedCertificates(testCA);
905   serverCtx->loadClientCAList(testCA);
906
907   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
908   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
909   clientCtx->loadPrivateKey(testKey);
910   clientCtx->loadCertificate(testCert);
911   clientCtx->loadTrustedCertificates(testCA);
912
913   int fds[2];
914   getfds(fds);
915
916   AsyncSSLSocket::UniquePtr clientSock(
917       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
918   AsyncSSLSocket::UniquePtr serverSock(
919       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
920
921   SSLHandshakeClient client(std::move(clientSock), true, true);
922   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
923
924   eventBase.loop();
925
926   EXPECT_EQ(server.clientCiphers_,
927             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
928   EXPECT_TRUE(client.handshakeVerify_);
929   EXPECT_TRUE(client.handshakeSuccess_);
930   EXPECT_TRUE(!client.handshakeError_);
931   EXPECT_TRUE(server.handshakeVerify_);
932   EXPECT_TRUE(server.handshakeSuccess_);
933   EXPECT_TRUE(!server.handshakeError_);
934 }
935
936 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
937   EventBase eventBase;
938   auto ctx = std::make_shared<SSLContext>();
939
940   int fds[2];
941   getfds(fds);
942
943   int bufLen = 42;
944   uint8_t majorVersion = 18;
945   uint8_t minorVersion = 25;
946
947   // Create callback buf
948   auto buf = IOBuf::create(bufLen);
949   buf->append(bufLen);
950   folly::io::RWPrivateCursor cursor(buf.get());
951   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
952   cursor.write<uint16_t>(0);
953   cursor.write<uint8_t>(38);
954   cursor.write<uint8_t>(majorVersion);
955   cursor.write<uint8_t>(minorVersion);
956   cursor.skip(32);
957   cursor.write<uint32_t>(0);
958
959   SSL* ssl = ctx->createSSL();
960   SCOPE_EXIT { SSL_free(ssl); };
961   AsyncSSLSocket::UniquePtr sock(
962       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
963   sock->enableClientHelloParsing();
964
965   // Test client hello parsing in one packet
966   AsyncSSLSocket::clientHelloParsingCallback(
967       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
968   buf.reset();
969
970   auto parsedClientHello = sock->getClientHelloInfo();
971   EXPECT_TRUE(parsedClientHello != nullptr);
972   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
973   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
974 }
975
976 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
977   EventBase eventBase;
978   auto ctx = std::make_shared<SSLContext>();
979
980   int fds[2];
981   getfds(fds);
982
983   int bufLen = 42;
984   uint8_t majorVersion = 18;
985   uint8_t minorVersion = 25;
986
987   // Create callback buf
988   auto buf = IOBuf::create(bufLen);
989   buf->append(bufLen);
990   folly::io::RWPrivateCursor cursor(buf.get());
991   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
992   cursor.write<uint16_t>(0);
993   cursor.write<uint8_t>(38);
994   cursor.write<uint8_t>(majorVersion);
995   cursor.write<uint8_t>(minorVersion);
996   cursor.skip(32);
997   cursor.write<uint32_t>(0);
998
999   SSL* ssl = ctx->createSSL();
1000   SCOPE_EXIT { SSL_free(ssl); };
1001   AsyncSSLSocket::UniquePtr sock(
1002       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1003   sock->enableClientHelloParsing();
1004
1005   // Test parsing with two packets with first packet size < 3
1006   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1007   AsyncSSLSocket::clientHelloParsingCallback(
1008       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1009       ssl, sock.get());
1010   bufCopy.reset();
1011   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1012   AsyncSSLSocket::clientHelloParsingCallback(
1013       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1014       ssl, sock.get());
1015   bufCopy.reset();
1016
1017   auto parsedClientHello = sock->getClientHelloInfo();
1018   EXPECT_TRUE(parsedClientHello != nullptr);
1019   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1020   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1021 }
1022
1023 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1024   EventBase eventBase;
1025   auto ctx = std::make_shared<SSLContext>();
1026
1027   int fds[2];
1028   getfds(fds);
1029
1030   int bufLen = 42;
1031   uint8_t majorVersion = 18;
1032   uint8_t minorVersion = 25;
1033
1034   // Create callback buf
1035   auto buf = IOBuf::create(bufLen);
1036   buf->append(bufLen);
1037   folly::io::RWPrivateCursor cursor(buf.get());
1038   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1039   cursor.write<uint16_t>(0);
1040   cursor.write<uint8_t>(38);
1041   cursor.write<uint8_t>(majorVersion);
1042   cursor.write<uint8_t>(minorVersion);
1043   cursor.skip(32);
1044   cursor.write<uint32_t>(0);
1045
1046   SSL* ssl = ctx->createSSL();
1047   SCOPE_EXIT { SSL_free(ssl); };
1048   AsyncSSLSocket::UniquePtr sock(
1049       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1050   sock->enableClientHelloParsing();
1051
1052   // Test parsing with multiple small packets
1053   for (uint64_t i = 0; i < buf->length(); i += 3) {
1054     auto bufCopy = folly::IOBuf::copyBuffer(
1055         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1056     AsyncSSLSocket::clientHelloParsingCallback(
1057         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1058         ssl, sock.get());
1059     bufCopy.reset();
1060   }
1061
1062   auto parsedClientHello = sock->getClientHelloInfo();
1063   EXPECT_TRUE(parsedClientHello != nullptr);
1064   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1065   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1066 }
1067
1068 /**
1069  * Verify sucessful behavior of SSL certificate validation.
1070  */
1071 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1072   EventBase eventBase;
1073   auto clientCtx = std::make_shared<SSLContext>();
1074   auto dfServerCtx = std::make_shared<SSLContext>();
1075
1076   int fds[2];
1077   getfds(fds);
1078   getctx(clientCtx, dfServerCtx);
1079
1080   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1081   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1082
1083   AsyncSSLSocket::UniquePtr clientSock(
1084     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1085   AsyncSSLSocket::UniquePtr serverSock(
1086     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1087
1088   SSLHandshakeClient client(std::move(clientSock), true, true);
1089   clientCtx->loadTrustedCertificates(testCA);
1090
1091   SSLHandshakeServer server(std::move(serverSock), true, true);
1092
1093   eventBase.loop();
1094
1095   EXPECT_TRUE(client.handshakeVerify_);
1096   EXPECT_TRUE(client.handshakeSuccess_);
1097   EXPECT_TRUE(!client.handshakeError_);
1098   EXPECT_LE(0, client.handshakeTime.count());
1099   EXPECT_TRUE(!server.handshakeVerify_);
1100   EXPECT_TRUE(server.handshakeSuccess_);
1101   EXPECT_TRUE(!server.handshakeError_);
1102   EXPECT_LE(0, server.handshakeTime.count());
1103 }
1104
1105 /**
1106  * Verify that the client's verification callback is able to fail SSL
1107  * connection establishment.
1108  */
1109 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1110   EventBase eventBase;
1111   auto clientCtx = std::make_shared<SSLContext>();
1112   auto dfServerCtx = std::make_shared<SSLContext>();
1113
1114   int fds[2];
1115   getfds(fds);
1116   getctx(clientCtx, dfServerCtx);
1117
1118   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1119   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1120
1121   AsyncSSLSocket::UniquePtr clientSock(
1122     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1123   AsyncSSLSocket::UniquePtr serverSock(
1124     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1125
1126   SSLHandshakeClient client(std::move(clientSock), true, false);
1127   clientCtx->loadTrustedCertificates(testCA);
1128
1129   SSLHandshakeServer server(std::move(serverSock), true, true);
1130
1131   eventBase.loop();
1132
1133   EXPECT_TRUE(client.handshakeVerify_);
1134   EXPECT_TRUE(!client.handshakeSuccess_);
1135   EXPECT_TRUE(client.handshakeError_);
1136   EXPECT_LE(0, client.handshakeTime.count());
1137   EXPECT_TRUE(!server.handshakeVerify_);
1138   EXPECT_TRUE(!server.handshakeSuccess_);
1139   EXPECT_TRUE(server.handshakeError_);
1140   EXPECT_LE(0, server.handshakeTime.count());
1141 }
1142
1143 /**
1144  * Verify that the options in SSLContext can be overridden in
1145  * sslConnect/Accept.i.e specifying that no validation should be performed
1146  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1147  * the validation callback.
1148  */
1149 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1150   EventBase eventBase;
1151   auto clientCtx = std::make_shared<SSLContext>();
1152   auto dfServerCtx = std::make_shared<SSLContext>();
1153
1154   int fds[2];
1155   getfds(fds);
1156   getctx(clientCtx, dfServerCtx);
1157
1158   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1159   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1160
1161   AsyncSSLSocket::UniquePtr clientSock(
1162     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1163   AsyncSSLSocket::UniquePtr serverSock(
1164     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1165
1166   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1167   clientCtx->loadTrustedCertificates(testCA);
1168
1169   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1170
1171   eventBase.loop();
1172
1173   EXPECT_TRUE(!client.handshakeVerify_);
1174   EXPECT_TRUE(client.handshakeSuccess_);
1175   EXPECT_TRUE(!client.handshakeError_);
1176   EXPECT_LE(0, client.handshakeTime.count());
1177   EXPECT_TRUE(!server.handshakeVerify_);
1178   EXPECT_TRUE(server.handshakeSuccess_);
1179   EXPECT_TRUE(!server.handshakeError_);
1180   EXPECT_LE(0, server.handshakeTime.count());
1181 }
1182
1183 /**
1184  * Verify that the options in SSLContext can be overridden in
1185  * sslConnect/Accept. Enable verification even if context says otherwise.
1186  * Test requireClientCert with client cert
1187  */
1188 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1189   EventBase eventBase;
1190   auto clientCtx = std::make_shared<SSLContext>();
1191   auto serverCtx = std::make_shared<SSLContext>();
1192   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1193   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1194   serverCtx->loadPrivateKey(testKey);
1195   serverCtx->loadCertificate(testCert);
1196   serverCtx->loadTrustedCertificates(testCA);
1197   serverCtx->loadClientCAList(testCA);
1198
1199   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1200   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1201   clientCtx->loadPrivateKey(testKey);
1202   clientCtx->loadCertificate(testCert);
1203   clientCtx->loadTrustedCertificates(testCA);
1204
1205   int fds[2];
1206   getfds(fds);
1207
1208   AsyncSSLSocket::UniquePtr clientSock(
1209       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1210   AsyncSSLSocket::UniquePtr serverSock(
1211       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1212
1213   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1214   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1215
1216   eventBase.loop();
1217
1218   EXPECT_TRUE(client.handshakeVerify_);
1219   EXPECT_TRUE(client.handshakeSuccess_);
1220   EXPECT_FALSE(client.handshakeError_);
1221   EXPECT_LE(0, client.handshakeTime.count());
1222   EXPECT_TRUE(server.handshakeVerify_);
1223   EXPECT_TRUE(server.handshakeSuccess_);
1224   EXPECT_FALSE(server.handshakeError_);
1225   EXPECT_LE(0, server.handshakeTime.count());
1226 }
1227
1228 /**
1229  * Verify that the client's verification callback is able to override
1230  * the preverification failure and allow a successful connection.
1231  */
1232 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1233   EventBase eventBase;
1234   auto clientCtx = std::make_shared<SSLContext>();
1235   auto dfServerCtx = std::make_shared<SSLContext>();
1236
1237   int fds[2];
1238   getfds(fds);
1239   getctx(clientCtx, dfServerCtx);
1240
1241   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1242   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1243
1244   AsyncSSLSocket::UniquePtr clientSock(
1245     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1246   AsyncSSLSocket::UniquePtr serverSock(
1247     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1248
1249   SSLHandshakeClient client(std::move(clientSock), false, true);
1250   SSLHandshakeServer server(std::move(serverSock), true, true);
1251
1252   eventBase.loop();
1253
1254   EXPECT_TRUE(client.handshakeVerify_);
1255   EXPECT_TRUE(client.handshakeSuccess_);
1256   EXPECT_TRUE(!client.handshakeError_);
1257   EXPECT_LE(0, client.handshakeTime.count());
1258   EXPECT_TRUE(!server.handshakeVerify_);
1259   EXPECT_TRUE(server.handshakeSuccess_);
1260   EXPECT_TRUE(!server.handshakeError_);
1261   EXPECT_LE(0, server.handshakeTime.count());
1262 }
1263
1264 /**
1265  * Verify that specifying that no validation should be performed allows an
1266  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1267  * callback.
1268  */
1269 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1270   EventBase eventBase;
1271   auto clientCtx = std::make_shared<SSLContext>();
1272   auto dfServerCtx = std::make_shared<SSLContext>();
1273
1274   int fds[2];
1275   getfds(fds);
1276   getctx(clientCtx, dfServerCtx);
1277
1278   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1279   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1280
1281   AsyncSSLSocket::UniquePtr clientSock(
1282     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1283   AsyncSSLSocket::UniquePtr serverSock(
1284     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1285
1286   SSLHandshakeClient client(std::move(clientSock), false, false);
1287   SSLHandshakeServer server(std::move(serverSock), false, false);
1288
1289   eventBase.loop();
1290
1291   EXPECT_TRUE(!client.handshakeVerify_);
1292   EXPECT_TRUE(client.handshakeSuccess_);
1293   EXPECT_TRUE(!client.handshakeError_);
1294   EXPECT_LE(0, client.handshakeTime.count());
1295   EXPECT_TRUE(!server.handshakeVerify_);
1296   EXPECT_TRUE(server.handshakeSuccess_);
1297   EXPECT_TRUE(!server.handshakeError_);
1298   EXPECT_LE(0, server.handshakeTime.count());
1299 }
1300
1301 /**
1302  * Test requireClientCert with client cert
1303  */
1304 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1305   EventBase eventBase;
1306   auto clientCtx = std::make_shared<SSLContext>();
1307   auto serverCtx = std::make_shared<SSLContext>();
1308   serverCtx->setVerificationOption(
1309       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1310   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1311   serverCtx->loadPrivateKey(testKey);
1312   serverCtx->loadCertificate(testCert);
1313   serverCtx->loadTrustedCertificates(testCA);
1314   serverCtx->loadClientCAList(testCA);
1315
1316   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1317   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1318   clientCtx->loadPrivateKey(testKey);
1319   clientCtx->loadCertificate(testCert);
1320   clientCtx->loadTrustedCertificates(testCA);
1321
1322   int fds[2];
1323   getfds(fds);
1324
1325   AsyncSSLSocket::UniquePtr clientSock(
1326       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1327   AsyncSSLSocket::UniquePtr serverSock(
1328       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1329
1330   SSLHandshakeClient client(std::move(clientSock), true, true);
1331   SSLHandshakeServer server(std::move(serverSock), true, true);
1332
1333   eventBase.loop();
1334
1335   EXPECT_TRUE(client.handshakeVerify_);
1336   EXPECT_TRUE(client.handshakeSuccess_);
1337   EXPECT_FALSE(client.handshakeError_);
1338   EXPECT_LE(0, client.handshakeTime.count());
1339   EXPECT_TRUE(server.handshakeVerify_);
1340   EXPECT_TRUE(server.handshakeSuccess_);
1341   EXPECT_FALSE(server.handshakeError_);
1342   EXPECT_LE(0, server.handshakeTime.count());
1343 }
1344
1345
1346 /**
1347  * Test requireClientCert with no client cert
1348  */
1349 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1350   EventBase eventBase;
1351   auto clientCtx = std::make_shared<SSLContext>();
1352   auto serverCtx = std::make_shared<SSLContext>();
1353   serverCtx->setVerificationOption(
1354       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1355   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1356   serverCtx->loadPrivateKey(testKey);
1357   serverCtx->loadCertificate(testCert);
1358   serverCtx->loadTrustedCertificates(testCA);
1359   serverCtx->loadClientCAList(testCA);
1360   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1361   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1362
1363   int fds[2];
1364   getfds(fds);
1365
1366   AsyncSSLSocket::UniquePtr clientSock(
1367       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1368   AsyncSSLSocket::UniquePtr serverSock(
1369       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1370
1371   SSLHandshakeClient client(std::move(clientSock), false, false);
1372   SSLHandshakeServer server(std::move(serverSock), false, false);
1373
1374   eventBase.loop();
1375
1376   EXPECT_FALSE(server.handshakeVerify_);
1377   EXPECT_FALSE(server.handshakeSuccess_);
1378   EXPECT_TRUE(server.handshakeError_);
1379   EXPECT_LE(0, client.handshakeTime.count());
1380   EXPECT_LE(0, server.handshakeTime.count());
1381 }
1382
1383 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1384   auto cert = getFileAsBuf(testCert);
1385   auto key = getFileAsBuf(testKey);
1386
1387   std::unique_ptr<BIO, BIO_deleter> certBio(BIO_new(BIO_s_mem()));
1388   BIO_write(certBio.get(), cert.data(), cert.size());
1389   std::unique_ptr<BIO, BIO_deleter> keyBio(BIO_new(BIO_s_mem()));
1390   BIO_write(keyBio.get(), key.data(), key.size());
1391
1392   // Create SSL structs from buffers to get properties
1393   X509_UniquePtr certStruct(
1394       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1395   EVP_PKEY_UniquePtr keyStruct(
1396       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1397   certBio = nullptr;
1398   keyBio = nullptr;
1399
1400   auto origCommonName = getCommonName(certStruct.get());
1401   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1402   certStruct = nullptr;
1403   keyStruct = nullptr;
1404
1405   auto ctx = std::make_shared<SSLContext>();
1406   ctx->loadPrivateKeyFromBufferPEM(key);
1407   ctx->loadCertificateFromBufferPEM(cert);
1408   ctx->loadTrustedCertificates(testCA);
1409
1410   SSL_UniquePtr ssl(ctx->createSSL());
1411
1412   auto newCert = SSL_get_certificate(ssl.get());
1413   auto newKey = SSL_get_privatekey(ssl.get());
1414
1415   // Get properties from SSL struct
1416   auto newCommonName = getCommonName(newCert);
1417   auto newKeySize = EVP_PKEY_bits(newKey);
1418
1419   // Check that the key and cert have the expected properties
1420   EXPECT_EQ(origCommonName, newCommonName);
1421   EXPECT_EQ(origKeySize, newKeySize);
1422 }
1423
1424 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1425   EventBase eb;
1426
1427   // Set up SSL context.
1428   auto sslContext = std::make_shared<SSLContext>();
1429   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1430
1431   // create SSL socket
1432   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1433
1434   EXPECT_EQ(1500, socket->getMinWriteSize());
1435
1436   socket->setMinWriteSize(0);
1437   EXPECT_EQ(0, socket->getMinWriteSize());
1438   socket->setMinWriteSize(50000);
1439   EXPECT_EQ(50000, socket->getMinWriteSize());
1440 }
1441
1442 class ReadCallbackTerminator : public ReadCallback {
1443  public:
1444   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1445       : ReadCallback(wcb)
1446       , base_(base) {}
1447
1448   // Do not write data back, terminate the loop.
1449   void readDataAvailable(size_t len) noexcept override {
1450     std::cerr << "readDataAvailable, len " << len << std::endl;
1451
1452     currentBuffer.length = len;
1453
1454     buffers.push_back(currentBuffer);
1455     currentBuffer.reset();
1456     state = STATE_SUCCEEDED;
1457
1458     socket_->setReadCB(nullptr);
1459     base_->terminateLoopSoon();
1460   }
1461  private:
1462   EventBase* base_;
1463 };
1464
1465
1466 /**
1467  * Test a full unencrypted codepath
1468  */
1469 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1470   EventBase base;
1471
1472   auto clientCtx = std::make_shared<folly::SSLContext>();
1473   auto serverCtx = std::make_shared<folly::SSLContext>();
1474   int fds[2];
1475   getfds(fds);
1476   getctx(clientCtx, serverCtx);
1477   auto client = AsyncSSLSocket::newSocket(
1478                   clientCtx, &base, fds[0], false, true);
1479   auto server = AsyncSSLSocket::newSocket(
1480                   serverCtx, &base, fds[1], true, true);
1481
1482   ReadCallbackTerminator readCallback(&base, nullptr);
1483   server->setReadCB(&readCallback);
1484   readCallback.setSocket(server);
1485
1486   uint8_t buf[128];
1487   memset(buf, 'a', sizeof(buf));
1488   client->write(nullptr, buf, sizeof(buf));
1489
1490   // Check that bytes are unencrypted
1491   char c;
1492   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1493   EXPECT_EQ('a', c);
1494
1495   EventBaseAborter eba(&base, 3000);
1496   base.loop();
1497
1498   EXPECT_EQ(1, readCallback.buffers.size());
1499   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1500
1501   server->setReadCB(&readCallback);
1502
1503   // Unencrypted
1504   server->sslAccept(nullptr);
1505   client->sslConn(nullptr);
1506
1507   // Do NOT wait for handshake, writing should be queued and happen after
1508
1509   client->write(nullptr, buf, sizeof(buf));
1510
1511   // Check that bytes are *not* unencrypted
1512   char c2;
1513   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1514   EXPECT_NE('a', c2);
1515
1516
1517   base.loop();
1518
1519   EXPECT_EQ(2, readCallback.buffers.size());
1520   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1521 }
1522
1523 } // namespace
1524
1525 ///////////////////////////////////////////////////////////////////////////
1526 // init_unit_test_suite
1527 ///////////////////////////////////////////////////////////////////////////
1528 namespace {
1529 struct Initializer {
1530   Initializer() {
1531     signal(SIGPIPE, SIG_IGN);
1532   }
1533 };
1534 Initializer initializer;
1535 } // anonymous