Strict validation for certs
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <fstream>
28 #include <gtest/gtest.h>
29 #include <iostream>
30 #include <list>
31 #include <set>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <openssl/bio.h>
35 #include <poll.h>
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <netinet/tcp.h>
39 #include <folly/io/Cursor.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 namespace folly {
49 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
50 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
51 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
52
53 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
54 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
55 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
56
57 constexpr size_t SSLClient::kMaxReadBufferSz;
58 constexpr size_t SSLClient::kMaxReadsPerEvent;
59
60 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
61     : ctx_(new folly::SSLContext),
62       acb_(acb),
63       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
64   // Set up the SSL context
65   ctx_->loadCertificate(testCert);
66   ctx_->loadPrivateKey(testKey);
67   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
68
69   acb_->ctx_ = ctx_;
70   acb_->base_ = &evb_;
71
72   //set up the listening socket
73   socket_->bind(0);
74   socket_->getAddress(&address_);
75   socket_->listen(100);
76   socket_->addAcceptCallback(acb_, &evb_);
77   socket_->startAccepting();
78
79   int ret = pthread_create(&thread_, nullptr, Main, this);
80   assert(ret == 0);
81   (void)ret;
82
83   std::cerr << "Accepting connections on " << address_ << std::endl;
84 }
85
86 void getfds(int fds[2]) {
87   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
88     FAIL() << "failed to create socketpair: " << strerror(errno);
89   }
90   for (int idx = 0; idx < 2; ++idx) {
91     int flags = fcntl(fds[idx], F_GETFL, 0);
92     if (flags == -1) {
93       FAIL() << "failed to get flags for socket " << idx << ": "
94              << strerror(errno);
95     }
96     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
97       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
98              << strerror(errno);
99     }
100   }
101 }
102
103 void getctx(
104   std::shared_ptr<folly::SSLContext> clientCtx,
105   std::shared_ptr<folly::SSLContext> serverCtx) {
106   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
107
108   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
109   serverCtx->loadCertificate(
110       testCert);
111   serverCtx->loadPrivateKey(
112       testKey);
113 }
114
115 void sslsocketpair(
116   EventBase* eventBase,
117   AsyncSSLSocket::UniquePtr* clientSock,
118   AsyncSSLSocket::UniquePtr* serverSock) {
119   auto clientCtx = std::make_shared<folly::SSLContext>();
120   auto serverCtx = std::make_shared<folly::SSLContext>();
121   int fds[2];
122   getfds(fds);
123   getctx(clientCtx, serverCtx);
124   clientSock->reset(new AsyncSSLSocket(
125                       clientCtx, eventBase, fds[0], false));
126   serverSock->reset(new AsyncSSLSocket(
127                       serverCtx, eventBase, fds[1], true));
128
129   // (*clientSock)->setSendTimeout(100);
130   // (*serverSock)->setSendTimeout(100);
131 }
132
133 // client protocol filters
134 bool clientProtoFilterPickPony(unsigned char** client,
135   unsigned int* client_len, const unsigned char*, unsigned int ) {
136   //the protocol string in length prefixed byte string. the
137   //length byte is not included in the length
138   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
139   *client = p;
140   *client_len = 7;
141   return true;
142 }
143
144 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
145   const unsigned char*, unsigned int) {
146   return false;
147 }
148
149 std::string getFileAsBuf(const char* fileName) {
150   std::string buffer;
151   folly::readFile(fileName, buffer);
152   return buffer;
153 }
154
155 std::string getCommonName(X509* cert) {
156   X509_NAME* subject = X509_get_subject_name(cert);
157   std::string cn;
158   cn.resize(ub_common_name);
159   X509_NAME_get_text_by_NID(
160       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
161   return cn;
162 }
163
164 /**
165  * Test connecting to, writing to, reading from, and closing the
166  * connection to the SSL server.
167  */
168 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
169   // Start listening on a local port
170   WriteCallbackBase writeCallback;
171   ReadCallback readCallback(&writeCallback);
172   HandshakeCallback handshakeCallback(&readCallback);
173   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
174   TestSSLServer server(&acceptCallback);
175
176   // Set up SSL context.
177   std::shared_ptr<SSLContext> sslContext(new SSLContext());
178   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
179   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
180   //sslContext->authenticate(true, false);
181
182   // connect
183   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
184                                                  sslContext);
185   socket->open();
186
187   // write()
188   uint8_t buf[128];
189   memset(buf, 'a', sizeof(buf));
190   socket->write(buf, sizeof(buf));
191
192   // read()
193   uint8_t readbuf[128];
194   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
195   EXPECT_EQ(bytesRead, 128);
196   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
197
198   // close()
199   socket->close();
200
201   cerr << "ConnectWriteReadClose test completed" << endl;
202 }
203
204 /**
205  * Negative test for handshakeError().
206  */
207 TEST(AsyncSSLSocketTest, HandshakeError) {
208   // Start listening on a local port
209   WriteCallbackBase writeCallback;
210   ReadCallback readCallback(&writeCallback);
211   HandshakeCallback handshakeCallback(&readCallback);
212   HandshakeErrorCallback acceptCallback(&handshakeCallback);
213   TestSSLServer server(&acceptCallback);
214
215   // Set up SSL context.
216   std::shared_ptr<SSLContext> sslContext(new SSLContext());
217   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
218
219   // connect
220   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
221                                                  sslContext);
222   // read()
223   bool ex = false;
224   try {
225     socket->open();
226
227     uint8_t readbuf[128];
228     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
229     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
230   } catch (AsyncSocketException &e) {
231     ex = true;
232   }
233   EXPECT_TRUE(ex);
234
235   // close()
236   socket->close();
237   cerr << "HandshakeError test completed" << endl;
238 }
239
240 /**
241  * Negative test for readError().
242  */
243 TEST(AsyncSSLSocketTest, ReadError) {
244   // Start listening on a local port
245   WriteCallbackBase writeCallback;
246   ReadErrorCallback readCallback(&writeCallback);
247   HandshakeCallback handshakeCallback(&readCallback);
248   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
249   TestSSLServer server(&acceptCallback);
250
251   // Set up SSL context.
252   std::shared_ptr<SSLContext> sslContext(new SSLContext());
253   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
254
255   // connect
256   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
257                                                  sslContext);
258   socket->open();
259
260   // write something to trigger ssl handshake
261   uint8_t buf[128];
262   memset(buf, 'a', sizeof(buf));
263   socket->write(buf, sizeof(buf));
264
265   socket->close();
266   cerr << "ReadError test completed" << endl;
267 }
268
269 /**
270  * Negative test for writeError().
271  */
272 TEST(AsyncSSLSocketTest, WriteError) {
273   // Start listening on a local port
274   WriteCallbackBase writeCallback;
275   WriteErrorCallback readCallback(&writeCallback);
276   HandshakeCallback handshakeCallback(&readCallback);
277   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
278   TestSSLServer server(&acceptCallback);
279
280   // Set up SSL context.
281   std::shared_ptr<SSLContext> sslContext(new SSLContext());
282   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
283
284   // connect
285   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
286                                                  sslContext);
287   socket->open();
288
289   // write something to trigger ssl handshake
290   uint8_t buf[128];
291   memset(buf, 'a', sizeof(buf));
292   socket->write(buf, sizeof(buf));
293
294   socket->close();
295   cerr << "WriteError test completed" << endl;
296 }
297
298 /**
299  * Test a socket with TCP_NODELAY unset.
300  */
301 TEST(AsyncSSLSocketTest, SocketWithDelay) {
302   // Start listening on a local port
303   WriteCallbackBase writeCallback;
304   ReadCallback readCallback(&writeCallback);
305   HandshakeCallback handshakeCallback(&readCallback);
306   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
307   TestSSLServer server(&acceptCallback);
308
309   // Set up SSL context.
310   std::shared_ptr<SSLContext> sslContext(new SSLContext());
311   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
312
313   // connect
314   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
315                                                  sslContext);
316   socket->open();
317
318   // write()
319   uint8_t buf[128];
320   memset(buf, 'a', sizeof(buf));
321   socket->write(buf, sizeof(buf));
322
323   // read()
324   uint8_t readbuf[128];
325   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
326   EXPECT_EQ(bytesRead, 128);
327   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
328
329   // close()
330   socket->close();
331
332   cerr << "SocketWithDelay test completed" << endl;
333 }
334
335 using NextProtocolTypePair =
336     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
337
338 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
339   // For matching protos
340  public:
341   void SetUp() override { getctx(clientCtx, serverCtx); }
342
343   void connect(bool unset = false) {
344     getfds(fds);
345
346     if (unset) {
347       // unsetting NPN for any of [client, server] is enough to make NPN not
348       // work
349       clientCtx->unsetNextProtocols();
350     }
351
352     AsyncSSLSocket::UniquePtr clientSock(
353       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
354     AsyncSSLSocket::UniquePtr serverSock(
355       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
356     client = folly::make_unique<NpnClient>(std::move(clientSock));
357     server = folly::make_unique<NpnServer>(std::move(serverSock));
358
359     eventBase.loop();
360   }
361
362   void expectProtocol(const std::string& proto) {
363     EXPECT_NE(client->nextProtoLength, 0);
364     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
365     EXPECT_EQ(
366         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
367         0);
368     string selected((const char*)client->nextProto, client->nextProtoLength);
369     EXPECT_EQ(proto, selected);
370   }
371
372   void expectNoProtocol() {
373     EXPECT_EQ(client->nextProtoLength, 0);
374     EXPECT_EQ(server->nextProtoLength, 0);
375     EXPECT_EQ(client->nextProto, nullptr);
376     EXPECT_EQ(server->nextProto, nullptr);
377   }
378
379   void expectProtocolType() {
380     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
381         GetParam().second == SSLContext::NextProtocolType::ANY) {
382       EXPECT_EQ(client->protocolType, server->protocolType);
383     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
384                GetParam().second == SSLContext::NextProtocolType::ANY) {
385       // Well not much we can say
386     } else {
387       expectProtocolType(GetParam());
388     }
389   }
390
391   void expectProtocolType(NextProtocolTypePair expected) {
392     EXPECT_EQ(client->protocolType, expected.first);
393     EXPECT_EQ(server->protocolType, expected.second);
394   }
395
396   EventBase eventBase;
397   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
398   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
399   int fds[2];
400   std::unique_ptr<NpnClient> client;
401   std::unique_ptr<NpnServer> server;
402 };
403
404 class NextProtocolNPNOnlyTest : public NextProtocolTest {
405   // For mismatching protos
406 };
407
408 class NextProtocolMismatchTest : public NextProtocolTest {
409   // For mismatching protos
410 };
411
412 TEST_P(NextProtocolTest, NpnTestOverlap) {
413   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
414   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
415                                         GetParam().second);
416
417   connect();
418
419   expectProtocol("baz");
420   expectProtocolType();
421 }
422
423 TEST_P(NextProtocolTest, NpnTestUnset) {
424   // Identical to above test, except that we want unset NPN before
425   // looping.
426   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
427   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
428                                         GetParam().second);
429
430   connect(true /* unset */);
431
432   // if alpn negotiation fails, type will appear as npn
433   expectNoProtocol();
434   EXPECT_EQ(client->protocolType, server->protocolType);
435 }
436
437 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
438   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
439   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
440                                         GetParam().second);
441
442   connect();
443
444   expectNoProtocol();
445   expectProtocolType(
446       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
447 }
448
449 TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
450   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
451   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
452                                         GetParam().second);
453
454   connect();
455
456   expectProtocol("blub");
457   expectProtocolType();
458 }
459
460 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
461   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
462   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
463   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
464                                         GetParam().second);
465
466   connect();
467
468   expectProtocol("ponies");
469   expectProtocolType();
470 }
471
472 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
473   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
474   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
475   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
476                                         GetParam().second);
477
478   connect();
479
480   expectProtocol("blub");
481   expectProtocolType();
482 }
483
484 TEST_P(NextProtocolTest, RandomizedNpnTest) {
485   // Probability that this test will fail is 2^-64, which could be considered
486   // as negligible.
487   const int kTries = 64;
488
489   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
490                                         GetParam().first);
491   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
492                                                   GetParam().second);
493
494   std::set<string> selectedProtocols;
495   for (int i = 0; i < kTries; ++i) {
496     connect();
497
498     EXPECT_NE(client->nextProtoLength, 0);
499     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
500     EXPECT_EQ(
501         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
502         0);
503     string selected((const char*)client->nextProto, client->nextProtoLength);
504     selectedProtocols.insert(selected);
505     expectProtocolType();
506   }
507   EXPECT_EQ(selectedProtocols.size(), 2);
508 }
509
510 INSTANTIATE_TEST_CASE_P(
511     AsyncSSLSocketTest,
512     NextProtocolTest,
513     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
514                                            SSLContext::NextProtocolType::NPN),
515 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
516                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
517                                            SSLContext::NextProtocolType::ALPN),
518 #endif
519                       NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
520                                            SSLContext::NextProtocolType::ANY),
521 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
522                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
523                                            SSLContext::NextProtocolType::ANY),
524 #endif
525                       NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
526                                            SSLContext::NextProtocolType::ANY)));
527
528 INSTANTIATE_TEST_CASE_P(
529     AsyncSSLSocketTest,
530     NextProtocolNPNOnlyTest,
531     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
532                                            SSLContext::NextProtocolType::NPN)));
533
534 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
535 INSTANTIATE_TEST_CASE_P(
536     AsyncSSLSocketTest,
537     NextProtocolMismatchTest,
538     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
539                                            SSLContext::NextProtocolType::ALPN),
540                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
541                                            SSLContext::NextProtocolType::NPN)));
542 #endif
543
544 #ifndef OPENSSL_NO_TLSEXT
545 /**
546  * 1. Client sends TLSEXT_HOSTNAME in client hello.
547  * 2. Server found a match SSL_CTX and use this SSL_CTX to
548  *    continue the SSL handshake.
549  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
550  */
551 TEST(AsyncSSLSocketTest, SNITestMatch) {
552   EventBase eventBase;
553   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
554   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
555   // Use the same SSLContext to continue the handshake after
556   // tlsext_hostname match.
557   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
558   const std::string serverName("xyz.newdev.facebook.com");
559   int fds[2];
560   getfds(fds);
561   getctx(clientCtx, dfServerCtx);
562
563   AsyncSSLSocket::UniquePtr clientSock(
564     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
565   AsyncSSLSocket::UniquePtr serverSock(
566     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
567   SNIClient client(std::move(clientSock));
568   SNIServer server(std::move(serverSock),
569                    dfServerCtx,
570                    hskServerCtx,
571                    serverName);
572
573   eventBase.loop();
574
575   EXPECT_TRUE(client.serverNameMatch);
576   EXPECT_TRUE(server.serverNameMatch);
577 }
578
579 /**
580  * 1. Client sends TLSEXT_HOSTNAME in client hello.
581  * 2. Server cannot find a matching SSL_CTX and continue to use
582  *    the current SSL_CTX to do the handshake.
583  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
584  */
585 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
586   EventBase eventBase;
587   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
588   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
589   // Use the same SSLContext to continue the handshake after
590   // tlsext_hostname match.
591   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
592   const std::string clientRequestingServerName("foo.com");
593   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
594
595   int fds[2];
596   getfds(fds);
597   getctx(clientCtx, dfServerCtx);
598
599   AsyncSSLSocket::UniquePtr clientSock(
600     new AsyncSSLSocket(clientCtx,
601                         &eventBase,
602                         fds[0],
603                         clientRequestingServerName));
604   AsyncSSLSocket::UniquePtr serverSock(
605     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
606   SNIClient client(std::move(clientSock));
607   SNIServer server(std::move(serverSock),
608                    dfServerCtx,
609                    hskServerCtx,
610                    serverExpectedServerName);
611
612   eventBase.loop();
613
614   EXPECT_TRUE(!client.serverNameMatch);
615   EXPECT_TRUE(!server.serverNameMatch);
616 }
617 /**
618  * 1. Client sends TLSEXT_HOSTNAME in client hello.
619  * 2. We then change the serverName.
620  * 3. We expect that we get 'false' as the result for serNameMatch.
621  */
622
623 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
624    EventBase eventBase;
625   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
626   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
627   // Use the same SSLContext to continue the handshake after
628   // tlsext_hostname match.
629   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
630   const std::string serverName("xyz.newdev.facebook.com");
631   int fds[2];
632   getfds(fds);
633   getctx(clientCtx, dfServerCtx);
634
635   AsyncSSLSocket::UniquePtr clientSock(
636     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
637   //Change the server name
638   std::string newName("new.com");
639   clientSock->setServerName(newName);
640   AsyncSSLSocket::UniquePtr serverSock(
641     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
642   SNIClient client(std::move(clientSock));
643   SNIServer server(std::move(serverSock),
644                    dfServerCtx,
645                    hskServerCtx,
646                    serverName);
647
648   eventBase.loop();
649
650   EXPECT_TRUE(!client.serverNameMatch);
651 }
652
653 /**
654  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
655  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
656  */
657 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
658   EventBase eventBase;
659   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
660   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
661   // Use the same SSLContext to continue the handshake after
662   // tlsext_hostname match.
663   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
664   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
665
666   int fds[2];
667   getfds(fds);
668   getctx(clientCtx, dfServerCtx);
669
670   AsyncSSLSocket::UniquePtr clientSock(
671     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
672   AsyncSSLSocket::UniquePtr serverSock(
673     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
674   SNIClient client(std::move(clientSock));
675   SNIServer server(std::move(serverSock),
676                    dfServerCtx,
677                    hskServerCtx,
678                    serverExpectedServerName);
679
680   eventBase.loop();
681
682   EXPECT_TRUE(!client.serverNameMatch);
683   EXPECT_TRUE(!server.serverNameMatch);
684 }
685
686 #endif
687 /**
688  * Test SSL client socket
689  */
690 TEST(AsyncSSLSocketTest, SSLClientTest) {
691   // Start listening on a local port
692   WriteCallbackBase writeCallback;
693   ReadCallback readCallback(&writeCallback);
694   HandshakeCallback handshakeCallback(&readCallback);
695   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
696   TestSSLServer server(&acceptCallback);
697
698   // Set up SSL client
699   EventBase eventBase;
700   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
701
702   client->connect();
703   EventBaseAborter eba(&eventBase, 3000);
704   eventBase.loop();
705
706   EXPECT_EQ(client->getMiss(), 1);
707   EXPECT_EQ(client->getHit(), 0);
708
709   cerr << "SSLClientTest test completed" << endl;
710 }
711
712
713 /**
714  * Test SSL client socket session re-use
715  */
716 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
717   // Start listening on a local port
718   WriteCallbackBase writeCallback;
719   ReadCallback readCallback(&writeCallback);
720   HandshakeCallback handshakeCallback(&readCallback);
721   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
722   TestSSLServer server(&acceptCallback);
723
724   // Set up SSL client
725   EventBase eventBase;
726   auto client =
727       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
728
729   client->connect();
730   EventBaseAborter eba(&eventBase, 3000);
731   eventBase.loop();
732
733   EXPECT_EQ(client->getMiss(), 1);
734   EXPECT_EQ(client->getHit(), 9);
735
736   cerr << "SSLClientTestReuse test completed" << endl;
737 }
738
739 /**
740  * Test SSL client socket timeout
741  */
742 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
743   // Start listening on a local port
744   EmptyReadCallback readCallback;
745   HandshakeCallback handshakeCallback(&readCallback,
746                                       HandshakeCallback::EXPECT_ERROR);
747   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
748   TestSSLServer server(&acceptCallback);
749
750   // Set up SSL client
751   EventBase eventBase;
752   auto client =
753       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
754   client->connect(true /* write before connect completes */);
755   EventBaseAborter eba(&eventBase, 3000);
756   eventBase.loop();
757
758   usleep(100000);
759   // This is checking that the connectError callback precedes any queued
760   // writeError callbacks.  This matches AsyncSocket's behavior
761   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
762   EXPECT_EQ(client->getErrors(), 1);
763   EXPECT_EQ(client->getMiss(), 0);
764   EXPECT_EQ(client->getHit(), 0);
765
766   cerr << "SSLClientTimeoutTest test completed" << endl;
767 }
768
769
770 /**
771  * Test SSL server async cache
772  */
773 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
774   // Start listening on a local port
775   WriteCallbackBase writeCallback;
776   ReadCallback readCallback(&writeCallback);
777   HandshakeCallback handshakeCallback(&readCallback);
778   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
779   TestSSLAsyncCacheServer server(&acceptCallback);
780
781   // Set up SSL client
782   EventBase eventBase;
783   auto client =
784       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
785
786   client->connect();
787   EventBaseAborter eba(&eventBase, 3000);
788   eventBase.loop();
789
790   EXPECT_EQ(server.getAsyncCallbacks(), 18);
791   EXPECT_EQ(server.getAsyncLookups(), 9);
792   EXPECT_EQ(client->getMiss(), 10);
793   EXPECT_EQ(client->getHit(), 0);
794
795   cerr << "SSLServerAsyncCacheTest test completed" << endl;
796 }
797
798
799 /**
800  * Test SSL server accept timeout with cache path
801  */
802 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
803   // Start listening on a local port
804   WriteCallbackBase writeCallback;
805   ReadCallback readCallback(&writeCallback);
806   EmptyReadCallback clientReadCallback;
807   HandshakeCallback handshakeCallback(&readCallback);
808   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
809   TestSSLAsyncCacheServer server(&acceptCallback);
810
811   // Set up SSL client
812   EventBase eventBase;
813   // only do a TCP connect
814   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
815   sock->connect(nullptr, server.getAddress());
816   clientReadCallback.tcpSocket_ = sock;
817   sock->setReadCB(&clientReadCallback);
818
819   EventBaseAborter eba(&eventBase, 3000);
820   eventBase.loop();
821
822   EXPECT_EQ(readCallback.state, STATE_WAITING);
823
824   cerr << "SSLServerTimeoutTest test completed" << endl;
825 }
826
827 /**
828  * Test SSL server accept timeout with cache path
829  */
830 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
831   // Start listening on a local port
832   WriteCallbackBase writeCallback;
833   ReadCallback readCallback(&writeCallback);
834   HandshakeCallback handshakeCallback(&readCallback);
835   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
836   TestSSLAsyncCacheServer server(&acceptCallback);
837
838   // Set up SSL client
839   EventBase eventBase;
840   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
841
842   client->connect();
843   EventBaseAborter eba(&eventBase, 3000);
844   eventBase.loop();
845
846   EXPECT_EQ(server.getAsyncCallbacks(), 1);
847   EXPECT_EQ(server.getAsyncLookups(), 1);
848   EXPECT_EQ(client->getErrors(), 1);
849   EXPECT_EQ(client->getMiss(), 1);
850   EXPECT_EQ(client->getHit(), 0);
851
852   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
853 }
854
855 /**
856  * Test SSL server accept timeout with cache path
857  */
858 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
859   // Start listening on a local port
860   WriteCallbackBase writeCallback;
861   ReadCallback readCallback(&writeCallback);
862   HandshakeCallback handshakeCallback(&readCallback,
863                                       HandshakeCallback::EXPECT_ERROR);
864   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
865   TestSSLAsyncCacheServer server(&acceptCallback, 500);
866
867   // Set up SSL client
868   EventBase eventBase;
869   auto client =
870       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
871
872   client->connect();
873   EventBaseAborter eba(&eventBase, 3000);
874   eventBase.loop();
875
876   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
877       handshakeCallback.closeSocket();});
878   // give time for the cache lookup to come back and find it closed
879   handshakeCallback.waitForHandshake();
880
881   EXPECT_EQ(server.getAsyncCallbacks(), 1);
882   EXPECT_EQ(server.getAsyncLookups(), 1);
883   EXPECT_EQ(client->getErrors(), 1);
884   EXPECT_EQ(client->getMiss(), 1);
885   EXPECT_EQ(client->getHit(), 0);
886
887   cerr << "SSLServerCacheCloseTest test completed" << endl;
888 }
889
890 /**
891  * Verify Client Ciphers obtained using SSL MSG Callback.
892  */
893 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
894   EventBase eventBase;
895   auto clientCtx = std::make_shared<SSLContext>();
896   auto serverCtx = std::make_shared<SSLContext>();
897   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
898   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
899   serverCtx->loadPrivateKey(testKey);
900   serverCtx->loadCertificate(testCert);
901   serverCtx->loadTrustedCertificates(testCA);
902   serverCtx->loadClientCAList(testCA);
903
904   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
905   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
906   clientCtx->loadPrivateKey(testKey);
907   clientCtx->loadCertificate(testCert);
908   clientCtx->loadTrustedCertificates(testCA);
909
910   int fds[2];
911   getfds(fds);
912
913   AsyncSSLSocket::UniquePtr clientSock(
914       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
915   AsyncSSLSocket::UniquePtr serverSock(
916       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
917
918   SSLHandshakeClient client(std::move(clientSock), true, true);
919   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
920
921   eventBase.loop();
922
923   EXPECT_EQ(server.clientCiphers_,
924             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
925   EXPECT_TRUE(client.handshakeVerify_);
926   EXPECT_TRUE(client.handshakeSuccess_);
927   EXPECT_TRUE(!client.handshakeError_);
928   EXPECT_TRUE(server.handshakeVerify_);
929   EXPECT_TRUE(server.handshakeSuccess_);
930   EXPECT_TRUE(!server.handshakeError_);
931 }
932
933 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
934   EventBase eventBase;
935   auto ctx = std::make_shared<SSLContext>();
936
937   int fds[2];
938   getfds(fds);
939
940   int bufLen = 42;
941   uint8_t majorVersion = 18;
942   uint8_t minorVersion = 25;
943
944   // Create callback buf
945   auto buf = IOBuf::create(bufLen);
946   buf->append(bufLen);
947   folly::io::RWPrivateCursor cursor(buf.get());
948   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
949   cursor.write<uint16_t>(0);
950   cursor.write<uint8_t>(38);
951   cursor.write<uint8_t>(majorVersion);
952   cursor.write<uint8_t>(minorVersion);
953   cursor.skip(32);
954   cursor.write<uint32_t>(0);
955
956   SSL* ssl = ctx->createSSL();
957   SCOPE_EXIT { SSL_free(ssl); };
958   AsyncSSLSocket::UniquePtr sock(
959       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
960   sock->enableClientHelloParsing();
961
962   // Test client hello parsing in one packet
963   AsyncSSLSocket::clientHelloParsingCallback(
964       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
965   buf.reset();
966
967   auto parsedClientHello = sock->getClientHelloInfo();
968   EXPECT_TRUE(parsedClientHello != nullptr);
969   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
970   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
971 }
972
973 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
974   EventBase eventBase;
975   auto ctx = std::make_shared<SSLContext>();
976
977   int fds[2];
978   getfds(fds);
979
980   int bufLen = 42;
981   uint8_t majorVersion = 18;
982   uint8_t minorVersion = 25;
983
984   // Create callback buf
985   auto buf = IOBuf::create(bufLen);
986   buf->append(bufLen);
987   folly::io::RWPrivateCursor cursor(buf.get());
988   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
989   cursor.write<uint16_t>(0);
990   cursor.write<uint8_t>(38);
991   cursor.write<uint8_t>(majorVersion);
992   cursor.write<uint8_t>(minorVersion);
993   cursor.skip(32);
994   cursor.write<uint32_t>(0);
995
996   SSL* ssl = ctx->createSSL();
997   SCOPE_EXIT { SSL_free(ssl); };
998   AsyncSSLSocket::UniquePtr sock(
999       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1000   sock->enableClientHelloParsing();
1001
1002   // Test parsing with two packets with first packet size < 3
1003   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1004   AsyncSSLSocket::clientHelloParsingCallback(
1005       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1006       ssl, sock.get());
1007   bufCopy.reset();
1008   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1009   AsyncSSLSocket::clientHelloParsingCallback(
1010       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1011       ssl, sock.get());
1012   bufCopy.reset();
1013
1014   auto parsedClientHello = sock->getClientHelloInfo();
1015   EXPECT_TRUE(parsedClientHello != nullptr);
1016   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1017   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1018 }
1019
1020 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1021   EventBase eventBase;
1022   auto ctx = std::make_shared<SSLContext>();
1023
1024   int fds[2];
1025   getfds(fds);
1026
1027   int bufLen = 42;
1028   uint8_t majorVersion = 18;
1029   uint8_t minorVersion = 25;
1030
1031   // Create callback buf
1032   auto buf = IOBuf::create(bufLen);
1033   buf->append(bufLen);
1034   folly::io::RWPrivateCursor cursor(buf.get());
1035   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1036   cursor.write<uint16_t>(0);
1037   cursor.write<uint8_t>(38);
1038   cursor.write<uint8_t>(majorVersion);
1039   cursor.write<uint8_t>(minorVersion);
1040   cursor.skip(32);
1041   cursor.write<uint32_t>(0);
1042
1043   SSL* ssl = ctx->createSSL();
1044   SCOPE_EXIT { SSL_free(ssl); };
1045   AsyncSSLSocket::UniquePtr sock(
1046       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1047   sock->enableClientHelloParsing();
1048
1049   // Test parsing with multiple small packets
1050   for (uint64_t i = 0; i < buf->length(); i += 3) {
1051     auto bufCopy = folly::IOBuf::copyBuffer(
1052         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1053     AsyncSSLSocket::clientHelloParsingCallback(
1054         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1055         ssl, sock.get());
1056     bufCopy.reset();
1057   }
1058
1059   auto parsedClientHello = sock->getClientHelloInfo();
1060   EXPECT_TRUE(parsedClientHello != nullptr);
1061   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1062   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1063 }
1064
1065 /**
1066  * Verify sucessful behavior of SSL certificate validation.
1067  */
1068 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1069   EventBase eventBase;
1070   auto clientCtx = std::make_shared<SSLContext>();
1071   auto dfServerCtx = std::make_shared<SSLContext>();
1072
1073   int fds[2];
1074   getfds(fds);
1075   getctx(clientCtx, dfServerCtx);
1076
1077   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1078   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1079
1080   AsyncSSLSocket::UniquePtr clientSock(
1081     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1082   AsyncSSLSocket::UniquePtr serverSock(
1083     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1084
1085   SSLHandshakeClient client(std::move(clientSock), true, true);
1086   clientCtx->loadTrustedCertificates(testCA);
1087
1088   SSLHandshakeServer server(std::move(serverSock), true, true);
1089
1090   eventBase.loop();
1091
1092   EXPECT_TRUE(client.handshakeVerify_);
1093   EXPECT_TRUE(client.handshakeSuccess_);
1094   EXPECT_TRUE(!client.handshakeError_);
1095   EXPECT_LE(0, client.handshakeTime.count());
1096   EXPECT_TRUE(!server.handshakeVerify_);
1097   EXPECT_TRUE(server.handshakeSuccess_);
1098   EXPECT_TRUE(!server.handshakeError_);
1099   EXPECT_LE(0, server.handshakeTime.count());
1100 }
1101
1102 /**
1103  * Verify that the client's verification callback is able to fail SSL
1104  * connection establishment.
1105  */
1106 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1107   EventBase eventBase;
1108   auto clientCtx = std::make_shared<SSLContext>();
1109   auto dfServerCtx = std::make_shared<SSLContext>();
1110
1111   int fds[2];
1112   getfds(fds);
1113   getctx(clientCtx, dfServerCtx);
1114
1115   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1116   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1117
1118   AsyncSSLSocket::UniquePtr clientSock(
1119     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1120   AsyncSSLSocket::UniquePtr serverSock(
1121     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1122
1123   SSLHandshakeClient client(std::move(clientSock), true, false);
1124   clientCtx->loadTrustedCertificates(testCA);
1125
1126   SSLHandshakeServer server(std::move(serverSock), true, true);
1127
1128   eventBase.loop();
1129
1130   EXPECT_TRUE(client.handshakeVerify_);
1131   EXPECT_TRUE(!client.handshakeSuccess_);
1132   EXPECT_TRUE(client.handshakeError_);
1133   EXPECT_LE(0, client.handshakeTime.count());
1134   EXPECT_TRUE(!server.handshakeVerify_);
1135   EXPECT_TRUE(!server.handshakeSuccess_);
1136   EXPECT_TRUE(server.handshakeError_);
1137   EXPECT_LE(0, server.handshakeTime.count());
1138 }
1139
1140 /**
1141  * Verify that the options in SSLContext can be overridden in
1142  * sslConnect/Accept.i.e specifying that no validation should be performed
1143  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1144  * the validation callback.
1145  */
1146 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1147   EventBase eventBase;
1148   auto clientCtx = std::make_shared<SSLContext>();
1149   auto dfServerCtx = std::make_shared<SSLContext>();
1150
1151   int fds[2];
1152   getfds(fds);
1153   getctx(clientCtx, dfServerCtx);
1154
1155   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1156   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1157
1158   AsyncSSLSocket::UniquePtr clientSock(
1159     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1160   AsyncSSLSocket::UniquePtr serverSock(
1161     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1162
1163   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1164   clientCtx->loadTrustedCertificates(testCA);
1165
1166   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1167
1168   eventBase.loop();
1169
1170   EXPECT_TRUE(!client.handshakeVerify_);
1171   EXPECT_TRUE(client.handshakeSuccess_);
1172   EXPECT_TRUE(!client.handshakeError_);
1173   EXPECT_LE(0, client.handshakeTime.count());
1174   EXPECT_TRUE(!server.handshakeVerify_);
1175   EXPECT_TRUE(server.handshakeSuccess_);
1176   EXPECT_TRUE(!server.handshakeError_);
1177   EXPECT_LE(0, server.handshakeTime.count());
1178 }
1179
1180 /**
1181  * Verify that the options in SSLContext can be overridden in
1182  * sslConnect/Accept. Enable verification even if context says otherwise.
1183  * Test requireClientCert with client cert
1184  */
1185 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1186   EventBase eventBase;
1187   auto clientCtx = std::make_shared<SSLContext>();
1188   auto serverCtx = std::make_shared<SSLContext>();
1189   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1190   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1191   serverCtx->loadPrivateKey(testKey);
1192   serverCtx->loadCertificate(testCert);
1193   serverCtx->loadTrustedCertificates(testCA);
1194   serverCtx->loadClientCAList(testCA);
1195
1196   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1197   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1198   clientCtx->loadPrivateKey(testKey);
1199   clientCtx->loadCertificate(testCert);
1200   clientCtx->loadTrustedCertificates(testCA);
1201
1202   int fds[2];
1203   getfds(fds);
1204
1205   AsyncSSLSocket::UniquePtr clientSock(
1206       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1207   AsyncSSLSocket::UniquePtr serverSock(
1208       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1209
1210   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1211   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1212
1213   eventBase.loop();
1214
1215   EXPECT_TRUE(client.handshakeVerify_);
1216   EXPECT_TRUE(client.handshakeSuccess_);
1217   EXPECT_FALSE(client.handshakeError_);
1218   EXPECT_LE(0, client.handshakeTime.count());
1219   EXPECT_TRUE(server.handshakeVerify_);
1220   EXPECT_TRUE(server.handshakeSuccess_);
1221   EXPECT_FALSE(server.handshakeError_);
1222   EXPECT_LE(0, server.handshakeTime.count());
1223 }
1224
1225 /**
1226  * Verify that the client's verification callback is able to override
1227  * the preverification failure and allow a successful connection.
1228  */
1229 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1230   EventBase eventBase;
1231   auto clientCtx = std::make_shared<SSLContext>();
1232   auto dfServerCtx = std::make_shared<SSLContext>();
1233
1234   int fds[2];
1235   getfds(fds);
1236   getctx(clientCtx, dfServerCtx);
1237
1238   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1239   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1240
1241   AsyncSSLSocket::UniquePtr clientSock(
1242     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1243   AsyncSSLSocket::UniquePtr serverSock(
1244     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1245
1246   SSLHandshakeClient client(std::move(clientSock), false, true);
1247   SSLHandshakeServer server(std::move(serverSock), true, true);
1248
1249   eventBase.loop();
1250
1251   EXPECT_TRUE(client.handshakeVerify_);
1252   EXPECT_TRUE(client.handshakeSuccess_);
1253   EXPECT_TRUE(!client.handshakeError_);
1254   EXPECT_LE(0, client.handshakeTime.count());
1255   EXPECT_TRUE(!server.handshakeVerify_);
1256   EXPECT_TRUE(server.handshakeSuccess_);
1257   EXPECT_TRUE(!server.handshakeError_);
1258   EXPECT_LE(0, server.handshakeTime.count());
1259 }
1260
1261 /**
1262  * Verify that specifying that no validation should be performed allows an
1263  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1264  * callback.
1265  */
1266 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1267   EventBase eventBase;
1268   auto clientCtx = std::make_shared<SSLContext>();
1269   auto dfServerCtx = std::make_shared<SSLContext>();
1270
1271   int fds[2];
1272   getfds(fds);
1273   getctx(clientCtx, dfServerCtx);
1274
1275   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1276   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1277
1278   AsyncSSLSocket::UniquePtr clientSock(
1279     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1280   AsyncSSLSocket::UniquePtr serverSock(
1281     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1282
1283   SSLHandshakeClient client(std::move(clientSock), false, false);
1284   SSLHandshakeServer server(std::move(serverSock), false, false);
1285
1286   eventBase.loop();
1287
1288   EXPECT_TRUE(!client.handshakeVerify_);
1289   EXPECT_TRUE(client.handshakeSuccess_);
1290   EXPECT_TRUE(!client.handshakeError_);
1291   EXPECT_LE(0, client.handshakeTime.count());
1292   EXPECT_TRUE(!server.handshakeVerify_);
1293   EXPECT_TRUE(server.handshakeSuccess_);
1294   EXPECT_TRUE(!server.handshakeError_);
1295   EXPECT_LE(0, server.handshakeTime.count());
1296 }
1297
1298 /**
1299  * Test requireClientCert with client cert
1300  */
1301 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1302   EventBase eventBase;
1303   auto clientCtx = std::make_shared<SSLContext>();
1304   auto serverCtx = std::make_shared<SSLContext>();
1305   serverCtx->setVerificationOption(
1306       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1307   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1308   serverCtx->loadPrivateKey(testKey);
1309   serverCtx->loadCertificate(testCert);
1310   serverCtx->loadTrustedCertificates(testCA);
1311   serverCtx->loadClientCAList(testCA);
1312
1313   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1314   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1315   clientCtx->loadPrivateKey(testKey);
1316   clientCtx->loadCertificate(testCert);
1317   clientCtx->loadTrustedCertificates(testCA);
1318
1319   int fds[2];
1320   getfds(fds);
1321
1322   AsyncSSLSocket::UniquePtr clientSock(
1323       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1324   AsyncSSLSocket::UniquePtr serverSock(
1325       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1326
1327   SSLHandshakeClient client(std::move(clientSock), true, true);
1328   SSLHandshakeServer server(std::move(serverSock), true, true);
1329
1330   eventBase.loop();
1331
1332   EXPECT_TRUE(client.handshakeVerify_);
1333   EXPECT_TRUE(client.handshakeSuccess_);
1334   EXPECT_FALSE(client.handshakeError_);
1335   EXPECT_LE(0, client.handshakeTime.count());
1336   EXPECT_TRUE(server.handshakeVerify_);
1337   EXPECT_TRUE(server.handshakeSuccess_);
1338   EXPECT_FALSE(server.handshakeError_);
1339   EXPECT_LE(0, server.handshakeTime.count());
1340 }
1341
1342
1343 /**
1344  * Test requireClientCert with no client cert
1345  */
1346 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1347   EventBase eventBase;
1348   auto clientCtx = std::make_shared<SSLContext>();
1349   auto serverCtx = std::make_shared<SSLContext>();
1350   serverCtx->setVerificationOption(
1351       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1352   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1353   serverCtx->loadPrivateKey(testKey);
1354   serverCtx->loadCertificate(testCert);
1355   serverCtx->loadTrustedCertificates(testCA);
1356   serverCtx->loadClientCAList(testCA);
1357   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1358   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1359
1360   int fds[2];
1361   getfds(fds);
1362
1363   AsyncSSLSocket::UniquePtr clientSock(
1364       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1365   AsyncSSLSocket::UniquePtr serverSock(
1366       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1367
1368   SSLHandshakeClient client(std::move(clientSock), false, false);
1369   SSLHandshakeServer server(std::move(serverSock), false, false);
1370
1371   eventBase.loop();
1372
1373   EXPECT_FALSE(server.handshakeVerify_);
1374   EXPECT_FALSE(server.handshakeSuccess_);
1375   EXPECT_TRUE(server.handshakeError_);
1376   EXPECT_LE(0, client.handshakeTime.count());
1377   EXPECT_LE(0, server.handshakeTime.count());
1378 }
1379
1380 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1381   auto cert = getFileAsBuf(testCert);
1382   auto key = getFileAsBuf(testKey);
1383
1384   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1385   BIO_write(certBio.get(), cert.data(), cert.size());
1386   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1387   BIO_write(keyBio.get(), key.data(), key.size());
1388
1389   // Create SSL structs from buffers to get properties
1390   ssl::X509UniquePtr certStruct(
1391       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1392   ssl::EvpPkeyUniquePtr keyStruct(
1393       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1394   certBio = nullptr;
1395   keyBio = nullptr;
1396
1397   auto origCommonName = getCommonName(certStruct.get());
1398   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1399   certStruct = nullptr;
1400   keyStruct = nullptr;
1401
1402   auto ctx = std::make_shared<SSLContext>();
1403   ctx->loadPrivateKeyFromBufferPEM(key);
1404   ctx->loadCertificateFromBufferPEM(cert);
1405   ctx->loadTrustedCertificates(testCA);
1406
1407   ssl::SSLUniquePtr ssl(ctx->createSSL());
1408
1409   auto newCert = SSL_get_certificate(ssl.get());
1410   auto newKey = SSL_get_privatekey(ssl.get());
1411
1412   // Get properties from SSL struct
1413   auto newCommonName = getCommonName(newCert);
1414   auto newKeySize = EVP_PKEY_bits(newKey);
1415
1416   // Check that the key and cert have the expected properties
1417   EXPECT_EQ(origCommonName, newCommonName);
1418   EXPECT_EQ(origKeySize, newKeySize);
1419 }
1420
1421 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1422   EventBase eb;
1423
1424   // Set up SSL context.
1425   auto sslContext = std::make_shared<SSLContext>();
1426   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1427
1428   // create SSL socket
1429   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1430
1431   EXPECT_EQ(1500, socket->getMinWriteSize());
1432
1433   socket->setMinWriteSize(0);
1434   EXPECT_EQ(0, socket->getMinWriteSize());
1435   socket->setMinWriteSize(50000);
1436   EXPECT_EQ(50000, socket->getMinWriteSize());
1437 }
1438
1439 class ReadCallbackTerminator : public ReadCallback {
1440  public:
1441   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1442       : ReadCallback(wcb)
1443       , base_(base) {}
1444
1445   // Do not write data back, terminate the loop.
1446   void readDataAvailable(size_t len) noexcept override {
1447     std::cerr << "readDataAvailable, len " << len << std::endl;
1448
1449     currentBuffer.length = len;
1450
1451     buffers.push_back(currentBuffer);
1452     currentBuffer.reset();
1453     state = STATE_SUCCEEDED;
1454
1455     socket_->setReadCB(nullptr);
1456     base_->terminateLoopSoon();
1457   }
1458  private:
1459   EventBase* base_;
1460 };
1461
1462
1463 /**
1464  * Test a full unencrypted codepath
1465  */
1466 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1467   EventBase base;
1468
1469   auto clientCtx = std::make_shared<folly::SSLContext>();
1470   auto serverCtx = std::make_shared<folly::SSLContext>();
1471   int fds[2];
1472   getfds(fds);
1473   getctx(clientCtx, serverCtx);
1474   auto client = AsyncSSLSocket::newSocket(
1475                   clientCtx, &base, fds[0], false, true);
1476   auto server = AsyncSSLSocket::newSocket(
1477                   serverCtx, &base, fds[1], true, true);
1478
1479   ReadCallbackTerminator readCallback(&base, nullptr);
1480   server->setReadCB(&readCallback);
1481   readCallback.setSocket(server);
1482
1483   uint8_t buf[128];
1484   memset(buf, 'a', sizeof(buf));
1485   client->write(nullptr, buf, sizeof(buf));
1486
1487   // Check that bytes are unencrypted
1488   char c;
1489   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1490   EXPECT_EQ('a', c);
1491
1492   EventBaseAborter eba(&base, 3000);
1493   base.loop();
1494
1495   EXPECT_EQ(1, readCallback.buffers.size());
1496   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1497
1498   server->setReadCB(&readCallback);
1499
1500   // Unencrypted
1501   server->sslAccept(nullptr);
1502   client->sslConn(nullptr);
1503
1504   // Do NOT wait for handshake, writing should be queued and happen after
1505
1506   client->write(nullptr, buf, sizeof(buf));
1507
1508   // Check that bytes are *not* unencrypted
1509   char c2;
1510   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1511   EXPECT_NE('a', c2);
1512
1513
1514   base.loop();
1515
1516   EXPECT_EQ(2, readCallback.buffers.size());
1517   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1518 }
1519
1520 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1521   // Start listening on a local port
1522   WriteCallbackBase writeCallback;
1523   WriteErrorCallback readCallback(&writeCallback);
1524   HandshakeCallback handshakeCallback(&readCallback,
1525                                       HandshakeCallback::EXPECT_ERROR);
1526   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1527   TestSSLServer server(&acceptCallback);
1528
1529   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1530   socket->open();
1531   uint8_t buf[3] = {0x16, 0x03, 0x01};
1532   socket->write(buf, sizeof(buf));
1533   socket->closeWithReset();
1534
1535   handshakeCallback.waitForHandshake();
1536   EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
1537             std::string::npos);
1538   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1539 }
1540
1541 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1542   // Start listening on a local port
1543   WriteCallbackBase writeCallback;
1544   WriteErrorCallback readCallback(&writeCallback);
1545   HandshakeCallback handshakeCallback(&readCallback,
1546                                       HandshakeCallback::EXPECT_ERROR);
1547   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1548   TestSSLServer server(&acceptCallback);
1549
1550   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1551   socket->open();
1552   uint8_t buf[3] = {0x16, 0x03, 0x01};
1553   socket->write(buf, sizeof(buf));
1554   socket->close();
1555
1556   handshakeCallback.waitForHandshake();
1557   EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
1558             std::string::npos);
1559   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1560 }
1561
1562 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1563   // Start listening on a local port
1564   WriteCallbackBase writeCallback;
1565   WriteErrorCallback readCallback(&writeCallback);
1566   HandshakeCallback handshakeCallback(&readCallback,
1567                                       HandshakeCallback::EXPECT_ERROR);
1568   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1569   TestSSLServer server(&acceptCallback);
1570
1571   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1572   socket->open();
1573   uint8_t buf[256] = {0x16, 0x03};
1574   memset(buf + 2, 'a', sizeof(buf) - 2);
1575   socket->write(buf, sizeof(buf));
1576   socket->close();
1577
1578   handshakeCallback.waitForHandshake();
1579   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1580             std::string::npos);
1581   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1582             std::string::npos);
1583 }
1584
1585 } // namespace
1586
1587 ///////////////////////////////////////////////////////////////////////////
1588 // init_unit_test_suite
1589 ///////////////////////////////////////////////////////////////////////////
1590 namespace {
1591 struct Initializer {
1592   Initializer() {
1593     signal(SIGPIPE, SIG_IGN);
1594   }
1595 };
1596 Initializer initializer;
1597 } // anonymous