Allow SSLContext to read certificates and keys from memory
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19 #include <pthread.h>
20
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/SocketAddress.h>
24
25 #include <folly/io/async/test/BlockingSocket.h>
26
27 #include <fstream>
28 #include <gtest/gtest.h>
29 #include <iostream>
30 #include <list>
31 #include <set>
32 #include <unistd.h>
33 #include <fcntl.h>
34 #include <openssl/bio.h>
35 #include <poll.h>
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <netinet/tcp.h>
39 #include <folly/io/Cursor.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 namespace folly {
49 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
50 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
51 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
52
53 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
54 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
55 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
56
57 constexpr size_t SSLClient::kMaxReadBufferSz;
58 constexpr size_t SSLClient::kMaxReadsPerEvent;
59
60 inline void BIO_free_fb(BIO* bio) { CHECK_EQ(1, BIO_free(bio)); }
61 using BIO_deleter = folly::static_function_deleter<BIO, &BIO_free_fb>;
62 using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
63 using SSL_deleter = folly::static_function_deleter<SSL, &SSL_free>;
64 using EVP_PKEY_deleter =
65     folly::static_function_deleter<EVP_PKEY, &EVP_PKEY_free>;
66
67 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb)
68     : ctx_(new folly::SSLContext),
69       acb_(acb),
70       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
71   // Set up the SSL context
72   ctx_->loadCertificate(testCert);
73   ctx_->loadPrivateKey(testKey);
74   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
75
76   acb_->ctx_ = ctx_;
77   acb_->base_ = &evb_;
78
79   //set up the listening socket
80   socket_->bind(0);
81   socket_->getAddress(&address_);
82   socket_->listen(100);
83   socket_->addAcceptCallback(acb_, &evb_);
84   socket_->startAccepting();
85
86   int ret = pthread_create(&thread_, nullptr, Main, this);
87   assert(ret == 0);
88   (void)ret;
89
90   std::cerr << "Accepting connections on " << address_ << std::endl;
91 }
92
93 void getfds(int fds[2]) {
94   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
95     FAIL() << "failed to create socketpair: " << strerror(errno);
96   }
97   for (int idx = 0; idx < 2; ++idx) {
98     int flags = fcntl(fds[idx], F_GETFL, 0);
99     if (flags == -1) {
100       FAIL() << "failed to get flags for socket " << idx << ": "
101              << strerror(errno);
102     }
103     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
104       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
105              << strerror(errno);
106     }
107   }
108 }
109
110 void getctx(
111   std::shared_ptr<folly::SSLContext> clientCtx,
112   std::shared_ptr<folly::SSLContext> serverCtx) {
113   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
114
115   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
116   serverCtx->loadCertificate(
117       testCert);
118   serverCtx->loadPrivateKey(
119       testKey);
120 }
121
122 void sslsocketpair(
123   EventBase* eventBase,
124   AsyncSSLSocket::UniquePtr* clientSock,
125   AsyncSSLSocket::UniquePtr* serverSock) {
126   auto clientCtx = std::make_shared<folly::SSLContext>();
127   auto serverCtx = std::make_shared<folly::SSLContext>();
128   int fds[2];
129   getfds(fds);
130   getctx(clientCtx, serverCtx);
131   clientSock->reset(new AsyncSSLSocket(
132                       clientCtx, eventBase, fds[0], false));
133   serverSock->reset(new AsyncSSLSocket(
134                       serverCtx, eventBase, fds[1], true));
135
136   // (*clientSock)->setSendTimeout(100);
137   // (*serverSock)->setSendTimeout(100);
138 }
139
140 // client protocol filters
141 bool clientProtoFilterPickPony(unsigned char** client,
142   unsigned int* client_len, const unsigned char*, unsigned int ) {
143   //the protocol string in length prefixed byte string. the
144   //length byte is not included in the length
145   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
146   *client = p;
147   *client_len = 7;
148   return true;
149 }
150
151 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
152   const unsigned char*, unsigned int) {
153   return false;
154 }
155
156 std::string getFileAsBuf(const char* fileName) {
157   std::string buffer;
158   folly::readFile(fileName, buffer);
159   return buffer;
160 }
161
162 std::string getCommonName(X509* cert) {
163   X509_NAME* subject = X509_get_subject_name(cert);
164   std::string cn;
165   cn.resize(ub_common_name);
166   X509_NAME_get_text_by_NID(
167       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
168   return cn;
169 }
170
171 /**
172  * Test connecting to, writing to, reading from, and closing the
173  * connection to the SSL server.
174  */
175 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
176   // Start listening on a local port
177   WriteCallbackBase writeCallback;
178   ReadCallback readCallback(&writeCallback);
179   HandshakeCallback handshakeCallback(&readCallback);
180   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
181   TestSSLServer server(&acceptCallback);
182
183   // Set up SSL context.
184   std::shared_ptr<SSLContext> sslContext(new SSLContext());
185   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
186   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
187   //sslContext->authenticate(true, false);
188
189   // connect
190   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
191                                                  sslContext);
192   socket->open();
193
194   // write()
195   uint8_t buf[128];
196   memset(buf, 'a', sizeof(buf));
197   socket->write(buf, sizeof(buf));
198
199   // read()
200   uint8_t readbuf[128];
201   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
202   EXPECT_EQ(bytesRead, 128);
203   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
204
205   // close()
206   socket->close();
207
208   cerr << "ConnectWriteReadClose test completed" << endl;
209 }
210
211 /**
212  * Negative test for handshakeError().
213  */
214 TEST(AsyncSSLSocketTest, HandshakeError) {
215   // Start listening on a local port
216   WriteCallbackBase writeCallback;
217   ReadCallback readCallback(&writeCallback);
218   HandshakeCallback handshakeCallback(&readCallback);
219   HandshakeErrorCallback acceptCallback(&handshakeCallback);
220   TestSSLServer server(&acceptCallback);
221
222   // Set up SSL context.
223   std::shared_ptr<SSLContext> sslContext(new SSLContext());
224   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
225
226   // connect
227   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
228                                                  sslContext);
229   // read()
230   bool ex = false;
231   try {
232     socket->open();
233
234     uint8_t readbuf[128];
235     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
236     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
237   } catch (AsyncSocketException &e) {
238     ex = true;
239   }
240   EXPECT_TRUE(ex);
241
242   // close()
243   socket->close();
244   cerr << "HandshakeError test completed" << endl;
245 }
246
247 /**
248  * Negative test for readError().
249  */
250 TEST(AsyncSSLSocketTest, ReadError) {
251   // Start listening on a local port
252   WriteCallbackBase writeCallback;
253   ReadErrorCallback readCallback(&writeCallback);
254   HandshakeCallback handshakeCallback(&readCallback);
255   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
256   TestSSLServer server(&acceptCallback);
257
258   // Set up SSL context.
259   std::shared_ptr<SSLContext> sslContext(new SSLContext());
260   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
261
262   // connect
263   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
264                                                  sslContext);
265   socket->open();
266
267   // write something to trigger ssl handshake
268   uint8_t buf[128];
269   memset(buf, 'a', sizeof(buf));
270   socket->write(buf, sizeof(buf));
271
272   socket->close();
273   cerr << "ReadError test completed" << endl;
274 }
275
276 /**
277  * Negative test for writeError().
278  */
279 TEST(AsyncSSLSocketTest, WriteError) {
280   // Start listening on a local port
281   WriteCallbackBase writeCallback;
282   WriteErrorCallback readCallback(&writeCallback);
283   HandshakeCallback handshakeCallback(&readCallback);
284   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
285   TestSSLServer server(&acceptCallback);
286
287   // Set up SSL context.
288   std::shared_ptr<SSLContext> sslContext(new SSLContext());
289   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
290
291   // connect
292   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
293                                                  sslContext);
294   socket->open();
295
296   // write something to trigger ssl handshake
297   uint8_t buf[128];
298   memset(buf, 'a', sizeof(buf));
299   socket->write(buf, sizeof(buf));
300
301   socket->close();
302   cerr << "WriteError test completed" << endl;
303 }
304
305 /**
306  * Test a socket with TCP_NODELAY unset.
307  */
308 TEST(AsyncSSLSocketTest, SocketWithDelay) {
309   // Start listening on a local port
310   WriteCallbackBase writeCallback;
311   ReadCallback readCallback(&writeCallback);
312   HandshakeCallback handshakeCallback(&readCallback);
313   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
314   TestSSLServer server(&acceptCallback);
315
316   // Set up SSL context.
317   std::shared_ptr<SSLContext> sslContext(new SSLContext());
318   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
319
320   // connect
321   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
322                                                  sslContext);
323   socket->open();
324
325   // write()
326   uint8_t buf[128];
327   memset(buf, 'a', sizeof(buf));
328   socket->write(buf, sizeof(buf));
329
330   // read()
331   uint8_t readbuf[128];
332   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
333   EXPECT_EQ(bytesRead, 128);
334   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
335
336   // close()
337   socket->close();
338
339   cerr << "SocketWithDelay test completed" << endl;
340 }
341
342 using NextProtocolTypePair =
343     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
344
345 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
346   // For matching protos
347  public:
348   void SetUp() override { getctx(clientCtx, serverCtx); }
349
350   void connect(bool unset = false) {
351     getfds(fds);
352
353     if (unset) {
354       // unsetting NPN for any of [client, server] is enough to make NPN not
355       // work
356       clientCtx->unsetNextProtocols();
357     }
358
359     AsyncSSLSocket::UniquePtr clientSock(
360       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
361     AsyncSSLSocket::UniquePtr serverSock(
362       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
363     client = folly::make_unique<NpnClient>(std::move(clientSock));
364     server = folly::make_unique<NpnServer>(std::move(serverSock));
365
366     eventBase.loop();
367   }
368
369   void expectProtocol(const std::string& proto) {
370     EXPECT_NE(client->nextProtoLength, 0);
371     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
372     EXPECT_EQ(
373         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
374         0);
375     string selected((const char*)client->nextProto, client->nextProtoLength);
376     EXPECT_EQ(proto, selected);
377   }
378
379   void expectNoProtocol() {
380     EXPECT_EQ(client->nextProtoLength, 0);
381     EXPECT_EQ(server->nextProtoLength, 0);
382     EXPECT_EQ(client->nextProto, nullptr);
383     EXPECT_EQ(server->nextProto, nullptr);
384   }
385
386   void expectProtocolType() {
387     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
388         GetParam().second == SSLContext::NextProtocolType::ANY) {
389       EXPECT_EQ(client->protocolType, server->protocolType);
390     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
391                GetParam().second == SSLContext::NextProtocolType::ANY) {
392       // Well not much we can say
393     } else {
394       expectProtocolType(GetParam());
395     }
396   }
397
398   void expectProtocolType(NextProtocolTypePair expected) {
399     EXPECT_EQ(client->protocolType, expected.first);
400     EXPECT_EQ(server->protocolType, expected.second);
401   }
402
403   EventBase eventBase;
404   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
405   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
406   int fds[2];
407   std::unique_ptr<NpnClient> client;
408   std::unique_ptr<NpnServer> server;
409 };
410
411 class NextProtocolNPNOnlyTest : public NextProtocolTest {
412   // For mismatching protos
413 };
414
415 class NextProtocolMismatchTest : public NextProtocolTest {
416   // For mismatching protos
417 };
418
419 TEST_P(NextProtocolTest, NpnTestOverlap) {
420   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
421   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
422                                         GetParam().second);
423
424   connect();
425
426   expectProtocol("baz");
427   expectProtocolType();
428 }
429
430 TEST_P(NextProtocolTest, NpnTestUnset) {
431   // Identical to above test, except that we want unset NPN before
432   // looping.
433   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
434   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
435                                         GetParam().second);
436
437   connect(true /* unset */);
438
439   // if alpn negotiation fails, type will appear as npn
440   expectNoProtocol();
441   EXPECT_EQ(client->protocolType, server->protocolType);
442 }
443
444 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
445   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
446   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
447                                         GetParam().second);
448
449   connect();
450
451   expectNoProtocol();
452   expectProtocolType(
453       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
454 }
455
456 TEST_P(NextProtocolNPNOnlyTest, NpnTestNoOverlap) {
457   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
458   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
459                                         GetParam().second);
460
461   connect();
462
463   expectProtocol("blub");
464   expectProtocolType();
465 }
466
467 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
468   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
469   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
470   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
471                                         GetParam().second);
472
473   connect();
474
475   expectProtocol("ponies");
476   expectProtocolType();
477 }
478
479 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
480   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
481   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
482   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
483                                         GetParam().second);
484
485   connect();
486
487   expectProtocol("blub");
488   expectProtocolType();
489 }
490
491 TEST_P(NextProtocolTest, RandomizedNpnTest) {
492   // Probability that this test will fail is 2^-64, which could be considered
493   // as negligible.
494   const int kTries = 64;
495
496   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
497                                         GetParam().first);
498   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
499                                                   GetParam().second);
500
501   std::set<string> selectedProtocols;
502   for (int i = 0; i < kTries; ++i) {
503     connect();
504
505     EXPECT_NE(client->nextProtoLength, 0);
506     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
507     EXPECT_EQ(
508         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
509         0);
510     string selected((const char*)client->nextProto, client->nextProtoLength);
511     selectedProtocols.insert(selected);
512     expectProtocolType();
513   }
514   EXPECT_EQ(selectedProtocols.size(), 2);
515 }
516
517 INSTANTIATE_TEST_CASE_P(
518     AsyncSSLSocketTest,
519     NextProtocolTest,
520     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
521                                            SSLContext::NextProtocolType::NPN),
522 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
523                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
524                                            SSLContext::NextProtocolType::ALPN),
525 #endif
526                       NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
527                                            SSLContext::NextProtocolType::ANY),
528 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
529                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
530                                            SSLContext::NextProtocolType::ANY),
531 #endif
532                       NextProtocolTypePair(SSLContext::NextProtocolType::ANY,
533                                            SSLContext::NextProtocolType::ANY)));
534
535 INSTANTIATE_TEST_CASE_P(
536     AsyncSSLSocketTest,
537     NextProtocolNPNOnlyTest,
538     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
539                                            SSLContext::NextProtocolType::NPN)));
540
541 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
542 INSTANTIATE_TEST_CASE_P(
543     AsyncSSLSocketTest,
544     NextProtocolMismatchTest,
545     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
546                                            SSLContext::NextProtocolType::ALPN),
547                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
548                                            SSLContext::NextProtocolType::NPN)));
549 #endif
550
551 #ifndef OPENSSL_NO_TLSEXT
552 /**
553  * 1. Client sends TLSEXT_HOSTNAME in client hello.
554  * 2. Server found a match SSL_CTX and use this SSL_CTX to
555  *    continue the SSL handshake.
556  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
557  */
558 TEST(AsyncSSLSocketTest, SNITestMatch) {
559   EventBase eventBase;
560   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
561   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
562   // Use the same SSLContext to continue the handshake after
563   // tlsext_hostname match.
564   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
565   const std::string serverName("xyz.newdev.facebook.com");
566   int fds[2];
567   getfds(fds);
568   getctx(clientCtx, dfServerCtx);
569
570   AsyncSSLSocket::UniquePtr clientSock(
571     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
572   AsyncSSLSocket::UniquePtr serverSock(
573     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
574   SNIClient client(std::move(clientSock));
575   SNIServer server(std::move(serverSock),
576                    dfServerCtx,
577                    hskServerCtx,
578                    serverName);
579
580   eventBase.loop();
581
582   EXPECT_TRUE(client.serverNameMatch);
583   EXPECT_TRUE(server.serverNameMatch);
584 }
585
586 /**
587  * 1. Client sends TLSEXT_HOSTNAME in client hello.
588  * 2. Server cannot find a matching SSL_CTX and continue to use
589  *    the current SSL_CTX to do the handshake.
590  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
591  */
592 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
593   EventBase eventBase;
594   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
595   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
596   // Use the same SSLContext to continue the handshake after
597   // tlsext_hostname match.
598   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
599   const std::string clientRequestingServerName("foo.com");
600   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
601
602   int fds[2];
603   getfds(fds);
604   getctx(clientCtx, dfServerCtx);
605
606   AsyncSSLSocket::UniquePtr clientSock(
607     new AsyncSSLSocket(clientCtx,
608                         &eventBase,
609                         fds[0],
610                         clientRequestingServerName));
611   AsyncSSLSocket::UniquePtr serverSock(
612     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
613   SNIClient client(std::move(clientSock));
614   SNIServer server(std::move(serverSock),
615                    dfServerCtx,
616                    hskServerCtx,
617                    serverExpectedServerName);
618
619   eventBase.loop();
620
621   EXPECT_TRUE(!client.serverNameMatch);
622   EXPECT_TRUE(!server.serverNameMatch);
623 }
624 /**
625  * 1. Client sends TLSEXT_HOSTNAME in client hello.
626  * 2. We then change the serverName.
627  * 3. We expect that we get 'false' as the result for serNameMatch.
628  */
629
630 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
631    EventBase eventBase;
632   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
633   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
634   // Use the same SSLContext to continue the handshake after
635   // tlsext_hostname match.
636   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
637   const std::string serverName("xyz.newdev.facebook.com");
638   int fds[2];
639   getfds(fds);
640   getctx(clientCtx, dfServerCtx);
641
642   AsyncSSLSocket::UniquePtr clientSock(
643     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
644   //Change the server name
645   std::string newName("new.com");
646   clientSock->setServerName(newName);
647   AsyncSSLSocket::UniquePtr serverSock(
648     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
649   SNIClient client(std::move(clientSock));
650   SNIServer server(std::move(serverSock),
651                    dfServerCtx,
652                    hskServerCtx,
653                    serverName);
654
655   eventBase.loop();
656
657   EXPECT_TRUE(!client.serverNameMatch);
658 }
659
660 /**
661  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
662  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
663  */
664 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
665   EventBase eventBase;
666   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
667   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
668   // Use the same SSLContext to continue the handshake after
669   // tlsext_hostname match.
670   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
671   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
672
673   int fds[2];
674   getfds(fds);
675   getctx(clientCtx, dfServerCtx);
676
677   AsyncSSLSocket::UniquePtr clientSock(
678     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
679   AsyncSSLSocket::UniquePtr serverSock(
680     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
681   SNIClient client(std::move(clientSock));
682   SNIServer server(std::move(serverSock),
683                    dfServerCtx,
684                    hskServerCtx,
685                    serverExpectedServerName);
686
687   eventBase.loop();
688
689   EXPECT_TRUE(!client.serverNameMatch);
690   EXPECT_TRUE(!server.serverNameMatch);
691 }
692
693 #endif
694 /**
695  * Test SSL client socket
696  */
697 TEST(AsyncSSLSocketTest, SSLClientTest) {
698   // Start listening on a local port
699   WriteCallbackBase writeCallback;
700   ReadCallback readCallback(&writeCallback);
701   HandshakeCallback handshakeCallback(&readCallback);
702   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
703   TestSSLServer server(&acceptCallback);
704
705   // Set up SSL client
706   EventBase eventBase;
707   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
708
709   client->connect();
710   EventBaseAborter eba(&eventBase, 3000);
711   eventBase.loop();
712
713   EXPECT_EQ(client->getMiss(), 1);
714   EXPECT_EQ(client->getHit(), 0);
715
716   cerr << "SSLClientTest test completed" << endl;
717 }
718
719
720 /**
721  * Test SSL client socket session re-use
722  */
723 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
724   // Start listening on a local port
725   WriteCallbackBase writeCallback;
726   ReadCallback readCallback(&writeCallback);
727   HandshakeCallback handshakeCallback(&readCallback);
728   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
729   TestSSLServer server(&acceptCallback);
730
731   // Set up SSL client
732   EventBase eventBase;
733   auto client =
734       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
735
736   client->connect();
737   EventBaseAborter eba(&eventBase, 3000);
738   eventBase.loop();
739
740   EXPECT_EQ(client->getMiss(), 1);
741   EXPECT_EQ(client->getHit(), 9);
742
743   cerr << "SSLClientTestReuse test completed" << endl;
744 }
745
746 /**
747  * Test SSL client socket timeout
748  */
749 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
750   // Start listening on a local port
751   EmptyReadCallback readCallback;
752   HandshakeCallback handshakeCallback(&readCallback,
753                                       HandshakeCallback::EXPECT_ERROR);
754   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
755   TestSSLServer server(&acceptCallback);
756
757   // Set up SSL client
758   EventBase eventBase;
759   auto client =
760       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
761   client->connect(true /* write before connect completes */);
762   EventBaseAborter eba(&eventBase, 3000);
763   eventBase.loop();
764
765   usleep(100000);
766   // This is checking that the connectError callback precedes any queued
767   // writeError callbacks.  This matches AsyncSocket's behavior
768   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
769   EXPECT_EQ(client->getErrors(), 1);
770   EXPECT_EQ(client->getMiss(), 0);
771   EXPECT_EQ(client->getHit(), 0);
772
773   cerr << "SSLClientTimeoutTest test completed" << endl;
774 }
775
776
777 /**
778  * Test SSL server async cache
779  */
780 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
781   // Start listening on a local port
782   WriteCallbackBase writeCallback;
783   ReadCallback readCallback(&writeCallback);
784   HandshakeCallback handshakeCallback(&readCallback);
785   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
786   TestSSLAsyncCacheServer server(&acceptCallback);
787
788   // Set up SSL client
789   EventBase eventBase;
790   auto client =
791       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
792
793   client->connect();
794   EventBaseAborter eba(&eventBase, 3000);
795   eventBase.loop();
796
797   EXPECT_EQ(server.getAsyncCallbacks(), 18);
798   EXPECT_EQ(server.getAsyncLookups(), 9);
799   EXPECT_EQ(client->getMiss(), 10);
800   EXPECT_EQ(client->getHit(), 0);
801
802   cerr << "SSLServerAsyncCacheTest test completed" << endl;
803 }
804
805
806 /**
807  * Test SSL server accept timeout with cache path
808  */
809 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
810   // Start listening on a local port
811   WriteCallbackBase writeCallback;
812   ReadCallback readCallback(&writeCallback);
813   EmptyReadCallback clientReadCallback;
814   HandshakeCallback handshakeCallback(&readCallback);
815   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
816   TestSSLAsyncCacheServer server(&acceptCallback);
817
818   // Set up SSL client
819   EventBase eventBase;
820   // only do a TCP connect
821   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
822   sock->connect(nullptr, server.getAddress());
823   clientReadCallback.tcpSocket_ = sock;
824   sock->setReadCB(&clientReadCallback);
825
826   EventBaseAborter eba(&eventBase, 3000);
827   eventBase.loop();
828
829   EXPECT_EQ(readCallback.state, STATE_WAITING);
830
831   cerr << "SSLServerTimeoutTest test completed" << endl;
832 }
833
834 /**
835  * Test SSL server accept timeout with cache path
836  */
837 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
838   // Start listening on a local port
839   WriteCallbackBase writeCallback;
840   ReadCallback readCallback(&writeCallback);
841   HandshakeCallback handshakeCallback(&readCallback);
842   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
843   TestSSLAsyncCacheServer server(&acceptCallback);
844
845   // Set up SSL client
846   EventBase eventBase;
847   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
848
849   client->connect();
850   EventBaseAborter eba(&eventBase, 3000);
851   eventBase.loop();
852
853   EXPECT_EQ(server.getAsyncCallbacks(), 1);
854   EXPECT_EQ(server.getAsyncLookups(), 1);
855   EXPECT_EQ(client->getErrors(), 1);
856   EXPECT_EQ(client->getMiss(), 1);
857   EXPECT_EQ(client->getHit(), 0);
858
859   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
860 }
861
862 /**
863  * Test SSL server accept timeout with cache path
864  */
865 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
866   // Start listening on a local port
867   WriteCallbackBase writeCallback;
868   ReadCallback readCallback(&writeCallback);
869   HandshakeCallback handshakeCallback(&readCallback,
870                                       HandshakeCallback::EXPECT_ERROR);
871   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
872   TestSSLAsyncCacheServer server(&acceptCallback, 500);
873
874   // Set up SSL client
875   EventBase eventBase;
876   auto client =
877       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
878
879   client->connect();
880   EventBaseAborter eba(&eventBase, 3000);
881   eventBase.loop();
882
883   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
884       handshakeCallback.closeSocket();});
885   // give time for the cache lookup to come back and find it closed
886   usleep(500000);
887
888   EXPECT_EQ(server.getAsyncCallbacks(), 1);
889   EXPECT_EQ(server.getAsyncLookups(), 1);
890   EXPECT_EQ(client->getErrors(), 1);
891   EXPECT_EQ(client->getMiss(), 1);
892   EXPECT_EQ(client->getHit(), 0);
893
894   cerr << "SSLServerCacheCloseTest test completed" << endl;
895 }
896
897 /**
898  * Verify Client Ciphers obtained using SSL MSG Callback.
899  */
900 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
901   EventBase eventBase;
902   auto clientCtx = std::make_shared<SSLContext>();
903   auto serverCtx = std::make_shared<SSLContext>();
904   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
905   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
906   serverCtx->loadPrivateKey(testKey);
907   serverCtx->loadCertificate(testCert);
908   serverCtx->loadTrustedCertificates(testCA);
909   serverCtx->loadClientCAList(testCA);
910
911   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
912   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
913   clientCtx->loadPrivateKey(testKey);
914   clientCtx->loadCertificate(testCert);
915   clientCtx->loadTrustedCertificates(testCA);
916
917   int fds[2];
918   getfds(fds);
919
920   AsyncSSLSocket::UniquePtr clientSock(
921       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
922   AsyncSSLSocket::UniquePtr serverSock(
923       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
924
925   SSLHandshakeClient client(std::move(clientSock), true, true);
926   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
927
928   eventBase.loop();
929
930   EXPECT_EQ(server.clientCiphers_,
931             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
932   EXPECT_TRUE(client.handshakeVerify_);
933   EXPECT_TRUE(client.handshakeSuccess_);
934   EXPECT_TRUE(!client.handshakeError_);
935   EXPECT_TRUE(server.handshakeVerify_);
936   EXPECT_TRUE(server.handshakeSuccess_);
937   EXPECT_TRUE(!server.handshakeError_);
938 }
939
940 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
941   EventBase eventBase;
942   auto ctx = std::make_shared<SSLContext>();
943
944   int fds[2];
945   getfds(fds);
946
947   int bufLen = 42;
948   uint8_t majorVersion = 18;
949   uint8_t minorVersion = 25;
950
951   // Create callback buf
952   auto buf = IOBuf::create(bufLen);
953   buf->append(bufLen);
954   folly::io::RWPrivateCursor cursor(buf.get());
955   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
956   cursor.write<uint16_t>(0);
957   cursor.write<uint8_t>(38);
958   cursor.write<uint8_t>(majorVersion);
959   cursor.write<uint8_t>(minorVersion);
960   cursor.skip(32);
961   cursor.write<uint32_t>(0);
962
963   SSL* ssl = ctx->createSSL();
964   SCOPE_EXIT { SSL_free(ssl); };
965   AsyncSSLSocket::UniquePtr sock(
966       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
967   sock->enableClientHelloParsing();
968
969   // Test client hello parsing in one packet
970   AsyncSSLSocket::clientHelloParsingCallback(
971       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
972   buf.reset();
973
974   auto parsedClientHello = sock->getClientHelloInfo();
975   EXPECT_TRUE(parsedClientHello != nullptr);
976   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
977   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
978 }
979
980 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
981   EventBase eventBase;
982   auto ctx = std::make_shared<SSLContext>();
983
984   int fds[2];
985   getfds(fds);
986
987   int bufLen = 42;
988   uint8_t majorVersion = 18;
989   uint8_t minorVersion = 25;
990
991   // Create callback buf
992   auto buf = IOBuf::create(bufLen);
993   buf->append(bufLen);
994   folly::io::RWPrivateCursor cursor(buf.get());
995   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
996   cursor.write<uint16_t>(0);
997   cursor.write<uint8_t>(38);
998   cursor.write<uint8_t>(majorVersion);
999   cursor.write<uint8_t>(minorVersion);
1000   cursor.skip(32);
1001   cursor.write<uint32_t>(0);
1002
1003   SSL* ssl = ctx->createSSL();
1004   SCOPE_EXIT { SSL_free(ssl); };
1005   AsyncSSLSocket::UniquePtr sock(
1006       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1007   sock->enableClientHelloParsing();
1008
1009   // Test parsing with two packets with first packet size < 3
1010   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1011   AsyncSSLSocket::clientHelloParsingCallback(
1012       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1013       ssl, sock.get());
1014   bufCopy.reset();
1015   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1016   AsyncSSLSocket::clientHelloParsingCallback(
1017       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1018       ssl, sock.get());
1019   bufCopy.reset();
1020
1021   auto parsedClientHello = sock->getClientHelloInfo();
1022   EXPECT_TRUE(parsedClientHello != nullptr);
1023   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1024   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1025 }
1026
1027 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1028   EventBase eventBase;
1029   auto ctx = std::make_shared<SSLContext>();
1030
1031   int fds[2];
1032   getfds(fds);
1033
1034   int bufLen = 42;
1035   uint8_t majorVersion = 18;
1036   uint8_t minorVersion = 25;
1037
1038   // Create callback buf
1039   auto buf = IOBuf::create(bufLen);
1040   buf->append(bufLen);
1041   folly::io::RWPrivateCursor cursor(buf.get());
1042   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1043   cursor.write<uint16_t>(0);
1044   cursor.write<uint8_t>(38);
1045   cursor.write<uint8_t>(majorVersion);
1046   cursor.write<uint8_t>(minorVersion);
1047   cursor.skip(32);
1048   cursor.write<uint32_t>(0);
1049
1050   SSL* ssl = ctx->createSSL();
1051   SCOPE_EXIT { SSL_free(ssl); };
1052   AsyncSSLSocket::UniquePtr sock(
1053       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1054   sock->enableClientHelloParsing();
1055
1056   // Test parsing with multiple small packets
1057   for (uint64_t i = 0; i < buf->length(); i += 3) {
1058     auto bufCopy = folly::IOBuf::copyBuffer(
1059         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1060     AsyncSSLSocket::clientHelloParsingCallback(
1061         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1062         ssl, sock.get());
1063     bufCopy.reset();
1064   }
1065
1066   auto parsedClientHello = sock->getClientHelloInfo();
1067   EXPECT_TRUE(parsedClientHello != nullptr);
1068   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1069   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1070 }
1071
1072 /**
1073  * Verify sucessful behavior of SSL certificate validation.
1074  */
1075 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1076   EventBase eventBase;
1077   auto clientCtx = std::make_shared<SSLContext>();
1078   auto dfServerCtx = std::make_shared<SSLContext>();
1079
1080   int fds[2];
1081   getfds(fds);
1082   getctx(clientCtx, dfServerCtx);
1083
1084   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1085   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1086
1087   AsyncSSLSocket::UniquePtr clientSock(
1088     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1089   AsyncSSLSocket::UniquePtr serverSock(
1090     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1091
1092   SSLHandshakeClient client(std::move(clientSock), true, true);
1093   clientCtx->loadTrustedCertificates(testCA);
1094
1095   SSLHandshakeServer server(std::move(serverSock), true, true);
1096
1097   eventBase.loop();
1098
1099   EXPECT_TRUE(client.handshakeVerify_);
1100   EXPECT_TRUE(client.handshakeSuccess_);
1101   EXPECT_TRUE(!client.handshakeError_);
1102   EXPECT_LE(0, client.handshakeTime.count());
1103   EXPECT_TRUE(!server.handshakeVerify_);
1104   EXPECT_TRUE(server.handshakeSuccess_);
1105   EXPECT_TRUE(!server.handshakeError_);
1106   EXPECT_LE(0, server.handshakeTime.count());
1107 }
1108
1109 /**
1110  * Verify that the client's verification callback is able to fail SSL
1111  * connection establishment.
1112  */
1113 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1114   EventBase eventBase;
1115   auto clientCtx = std::make_shared<SSLContext>();
1116   auto dfServerCtx = std::make_shared<SSLContext>();
1117
1118   int fds[2];
1119   getfds(fds);
1120   getctx(clientCtx, dfServerCtx);
1121
1122   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1123   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1124
1125   AsyncSSLSocket::UniquePtr clientSock(
1126     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1127   AsyncSSLSocket::UniquePtr serverSock(
1128     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1129
1130   SSLHandshakeClient client(std::move(clientSock), true, false);
1131   clientCtx->loadTrustedCertificates(testCA);
1132
1133   SSLHandshakeServer server(std::move(serverSock), true, true);
1134
1135   eventBase.loop();
1136
1137   EXPECT_TRUE(client.handshakeVerify_);
1138   EXPECT_TRUE(!client.handshakeSuccess_);
1139   EXPECT_TRUE(client.handshakeError_);
1140   EXPECT_LE(0, client.handshakeTime.count());
1141   EXPECT_TRUE(!server.handshakeVerify_);
1142   EXPECT_TRUE(!server.handshakeSuccess_);
1143   EXPECT_TRUE(server.handshakeError_);
1144   EXPECT_LE(0, server.handshakeTime.count());
1145 }
1146
1147 /**
1148  * Verify that the options in SSLContext can be overridden in
1149  * sslConnect/Accept.i.e specifying that no validation should be performed
1150  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1151  * the validation callback.
1152  */
1153 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1154   EventBase eventBase;
1155   auto clientCtx = std::make_shared<SSLContext>();
1156   auto dfServerCtx = std::make_shared<SSLContext>();
1157
1158   int fds[2];
1159   getfds(fds);
1160   getctx(clientCtx, dfServerCtx);
1161
1162   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1163   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1164
1165   AsyncSSLSocket::UniquePtr clientSock(
1166     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1167   AsyncSSLSocket::UniquePtr serverSock(
1168     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1169
1170   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1171   clientCtx->loadTrustedCertificates(testCA);
1172
1173   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1174
1175   eventBase.loop();
1176
1177   EXPECT_TRUE(!client.handshakeVerify_);
1178   EXPECT_TRUE(client.handshakeSuccess_);
1179   EXPECT_TRUE(!client.handshakeError_);
1180   EXPECT_LE(0, client.handshakeTime.count());
1181   EXPECT_TRUE(!server.handshakeVerify_);
1182   EXPECT_TRUE(server.handshakeSuccess_);
1183   EXPECT_TRUE(!server.handshakeError_);
1184   EXPECT_LE(0, server.handshakeTime.count());
1185 }
1186
1187 /**
1188  * Verify that the options in SSLContext can be overridden in
1189  * sslConnect/Accept. Enable verification even if context says otherwise.
1190  * Test requireClientCert with client cert
1191  */
1192 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1193   EventBase eventBase;
1194   auto clientCtx = std::make_shared<SSLContext>();
1195   auto serverCtx = std::make_shared<SSLContext>();
1196   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1197   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1198   serverCtx->loadPrivateKey(testKey);
1199   serverCtx->loadCertificate(testCert);
1200   serverCtx->loadTrustedCertificates(testCA);
1201   serverCtx->loadClientCAList(testCA);
1202
1203   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1204   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1205   clientCtx->loadPrivateKey(testKey);
1206   clientCtx->loadCertificate(testCert);
1207   clientCtx->loadTrustedCertificates(testCA);
1208
1209   int fds[2];
1210   getfds(fds);
1211
1212   AsyncSSLSocket::UniquePtr clientSock(
1213       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1214   AsyncSSLSocket::UniquePtr serverSock(
1215       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1216
1217   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1218   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1219
1220   eventBase.loop();
1221
1222   EXPECT_TRUE(client.handshakeVerify_);
1223   EXPECT_TRUE(client.handshakeSuccess_);
1224   EXPECT_FALSE(client.handshakeError_);
1225   EXPECT_LE(0, client.handshakeTime.count());
1226   EXPECT_TRUE(server.handshakeVerify_);
1227   EXPECT_TRUE(server.handshakeSuccess_);
1228   EXPECT_FALSE(server.handshakeError_);
1229   EXPECT_LE(0, server.handshakeTime.count());
1230 }
1231
1232 /**
1233  * Verify that the client's verification callback is able to override
1234  * the preverification failure and allow a successful connection.
1235  */
1236 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1237   EventBase eventBase;
1238   auto clientCtx = std::make_shared<SSLContext>();
1239   auto dfServerCtx = std::make_shared<SSLContext>();
1240
1241   int fds[2];
1242   getfds(fds);
1243   getctx(clientCtx, dfServerCtx);
1244
1245   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1246   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1247
1248   AsyncSSLSocket::UniquePtr clientSock(
1249     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1250   AsyncSSLSocket::UniquePtr serverSock(
1251     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1252
1253   SSLHandshakeClient client(std::move(clientSock), false, true);
1254   SSLHandshakeServer server(std::move(serverSock), true, true);
1255
1256   eventBase.loop();
1257
1258   EXPECT_TRUE(client.handshakeVerify_);
1259   EXPECT_TRUE(client.handshakeSuccess_);
1260   EXPECT_TRUE(!client.handshakeError_);
1261   EXPECT_LE(0, client.handshakeTime.count());
1262   EXPECT_TRUE(!server.handshakeVerify_);
1263   EXPECT_TRUE(server.handshakeSuccess_);
1264   EXPECT_TRUE(!server.handshakeError_);
1265   EXPECT_LE(0, server.handshakeTime.count());
1266 }
1267
1268 /**
1269  * Verify that specifying that no validation should be performed allows an
1270  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1271  * callback.
1272  */
1273 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1274   EventBase eventBase;
1275   auto clientCtx = std::make_shared<SSLContext>();
1276   auto dfServerCtx = std::make_shared<SSLContext>();
1277
1278   int fds[2];
1279   getfds(fds);
1280   getctx(clientCtx, dfServerCtx);
1281
1282   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1283   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1284
1285   AsyncSSLSocket::UniquePtr clientSock(
1286     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1287   AsyncSSLSocket::UniquePtr serverSock(
1288     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1289
1290   SSLHandshakeClient client(std::move(clientSock), false, false);
1291   SSLHandshakeServer server(std::move(serverSock), false, false);
1292
1293   eventBase.loop();
1294
1295   EXPECT_TRUE(!client.handshakeVerify_);
1296   EXPECT_TRUE(client.handshakeSuccess_);
1297   EXPECT_TRUE(!client.handshakeError_);
1298   EXPECT_LE(0, client.handshakeTime.count());
1299   EXPECT_TRUE(!server.handshakeVerify_);
1300   EXPECT_TRUE(server.handshakeSuccess_);
1301   EXPECT_TRUE(!server.handshakeError_);
1302   EXPECT_LE(0, server.handshakeTime.count());
1303 }
1304
1305 /**
1306  * Test requireClientCert with client cert
1307  */
1308 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1309   EventBase eventBase;
1310   auto clientCtx = std::make_shared<SSLContext>();
1311   auto serverCtx = std::make_shared<SSLContext>();
1312   serverCtx->setVerificationOption(
1313       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1314   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1315   serverCtx->loadPrivateKey(testKey);
1316   serverCtx->loadCertificate(testCert);
1317   serverCtx->loadTrustedCertificates(testCA);
1318   serverCtx->loadClientCAList(testCA);
1319
1320   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1321   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1322   clientCtx->loadPrivateKey(testKey);
1323   clientCtx->loadCertificate(testCert);
1324   clientCtx->loadTrustedCertificates(testCA);
1325
1326   int fds[2];
1327   getfds(fds);
1328
1329   AsyncSSLSocket::UniquePtr clientSock(
1330       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1331   AsyncSSLSocket::UniquePtr serverSock(
1332       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1333
1334   SSLHandshakeClient client(std::move(clientSock), true, true);
1335   SSLHandshakeServer server(std::move(serverSock), true, true);
1336
1337   eventBase.loop();
1338
1339   EXPECT_TRUE(client.handshakeVerify_);
1340   EXPECT_TRUE(client.handshakeSuccess_);
1341   EXPECT_FALSE(client.handshakeError_);
1342   EXPECT_LE(0, client.handshakeTime.count());
1343   EXPECT_TRUE(server.handshakeVerify_);
1344   EXPECT_TRUE(server.handshakeSuccess_);
1345   EXPECT_FALSE(server.handshakeError_);
1346   EXPECT_LE(0, server.handshakeTime.count());
1347 }
1348
1349
1350 /**
1351  * Test requireClientCert with no client cert
1352  */
1353 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1354   EventBase eventBase;
1355   auto clientCtx = std::make_shared<SSLContext>();
1356   auto serverCtx = std::make_shared<SSLContext>();
1357   serverCtx->setVerificationOption(
1358       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1359   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1360   serverCtx->loadPrivateKey(testKey);
1361   serverCtx->loadCertificate(testCert);
1362   serverCtx->loadTrustedCertificates(testCA);
1363   serverCtx->loadClientCAList(testCA);
1364   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1365   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1366
1367   int fds[2];
1368   getfds(fds);
1369
1370   AsyncSSLSocket::UniquePtr clientSock(
1371       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1372   AsyncSSLSocket::UniquePtr serverSock(
1373       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1374
1375   SSLHandshakeClient client(std::move(clientSock), false, false);
1376   SSLHandshakeServer server(std::move(serverSock), false, false);
1377
1378   eventBase.loop();
1379
1380   EXPECT_FALSE(server.handshakeVerify_);
1381   EXPECT_FALSE(server.handshakeSuccess_);
1382   EXPECT_TRUE(server.handshakeError_);
1383   EXPECT_LE(0, client.handshakeTime.count());
1384   EXPECT_LE(0, server.handshakeTime.count());
1385 }
1386
1387 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1388   auto cert = getFileAsBuf(testCert);
1389   auto key = getFileAsBuf(testKey);
1390
1391   std::unique_ptr<BIO, BIO_deleter> certBio(BIO_new(BIO_s_mem()));
1392   BIO_write(certBio.get(), cert.data(), cert.size());
1393   std::unique_ptr<BIO, BIO_deleter> keyBio(BIO_new(BIO_s_mem()));
1394   BIO_write(keyBio.get(), key.data(), key.size());
1395
1396   // Create SSL structs from buffers to get properties
1397   std::unique_ptr<X509, X509_deleter> certStruct(
1398       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1399   std::unique_ptr<EVP_PKEY, EVP_PKEY_deleter> keyStruct(
1400       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1401   certBio = nullptr;
1402   keyBio = nullptr;
1403
1404   auto origCommonName = getCommonName(certStruct.get());
1405   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1406   certStruct = nullptr;
1407   keyStruct = nullptr;
1408
1409   auto ctx = std::make_shared<SSLContext>();
1410   ctx->loadPrivateKeyFromBufferPEM(key);
1411   ctx->loadCertificateFromBufferPEM(cert);
1412   ctx->loadTrustedCertificates(testCA);
1413
1414   std::unique_ptr<SSL, SSL_deleter> ssl(ctx->createSSL());
1415
1416   auto newCert = SSL_get_certificate(ssl.get());
1417   auto newKey = SSL_get_privatekey(ssl.get());
1418
1419   // Get properties from SSL struct
1420   auto newCommonName = getCommonName(newCert);
1421   auto newKeySize = EVP_PKEY_bits(newKey);
1422
1423   // Check that the key and cert have the expected properties
1424   EXPECT_EQ(origCommonName, newCommonName);
1425   EXPECT_EQ(origKeySize, newKeySize);
1426 }
1427
1428 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1429   EventBase eb;
1430
1431   // Set up SSL context.
1432   auto sslContext = std::make_shared<SSLContext>();
1433   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1434
1435   // create SSL socket
1436   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1437
1438   EXPECT_EQ(1500, socket->getMinWriteSize());
1439
1440   socket->setMinWriteSize(0);
1441   EXPECT_EQ(0, socket->getMinWriteSize());
1442   socket->setMinWriteSize(50000);
1443   EXPECT_EQ(50000, socket->getMinWriteSize());
1444 }
1445
1446 class ReadCallbackTerminator : public ReadCallback {
1447  public:
1448   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1449       : ReadCallback(wcb)
1450       , base_(base) {}
1451
1452   // Do not write data back, terminate the loop.
1453   void readDataAvailable(size_t len) noexcept override {
1454     std::cerr << "readDataAvailable, len " << len << std::endl;
1455
1456     currentBuffer.length = len;
1457
1458     buffers.push_back(currentBuffer);
1459     currentBuffer.reset();
1460     state = STATE_SUCCEEDED;
1461
1462     socket_->setReadCB(nullptr);
1463     base_->terminateLoopSoon();
1464   }
1465  private:
1466   EventBase* base_;
1467 };
1468
1469
1470 /**
1471  * Test a full unencrypted codepath
1472  */
1473 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1474   EventBase base;
1475
1476   auto clientCtx = std::make_shared<folly::SSLContext>();
1477   auto serverCtx = std::make_shared<folly::SSLContext>();
1478   int fds[2];
1479   getfds(fds);
1480   getctx(clientCtx, serverCtx);
1481   auto client = AsyncSSLSocket::newSocket(
1482                   clientCtx, &base, fds[0], false, true);
1483   auto server = AsyncSSLSocket::newSocket(
1484                   serverCtx, &base, fds[1], true, true);
1485
1486   ReadCallbackTerminator readCallback(&base, nullptr);
1487   server->setReadCB(&readCallback);
1488   readCallback.setSocket(server);
1489
1490   uint8_t buf[128];
1491   memset(buf, 'a', sizeof(buf));
1492   client->write(nullptr, buf, sizeof(buf));
1493
1494   // Check that bytes are unencrypted
1495   char c;
1496   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1497   EXPECT_EQ('a', c);
1498
1499   EventBaseAborter eba(&base, 3000);
1500   base.loop();
1501
1502   EXPECT_EQ(1, readCallback.buffers.size());
1503   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1504
1505   server->setReadCB(&readCallback);
1506
1507   // Unencrypted
1508   server->sslAccept(nullptr);
1509   client->sslConn(nullptr);
1510
1511   // Do NOT wait for handshake, writing should be queued and happen after
1512
1513   client->write(nullptr, buf, sizeof(buf));
1514
1515   // Check that bytes are *not* unencrypted
1516   char c2;
1517   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1518   EXPECT_NE('a', c2);
1519
1520
1521   base.loop();
1522
1523   EXPECT_EQ(2, readCallback.buffers.size());
1524   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1525 }
1526
1527 } // namespace
1528
1529 ///////////////////////////////////////////////////////////////////////////
1530 // init_unit_test_suite
1531 ///////////////////////////////////////////////////////////////////////////
1532 namespace {
1533 struct Initializer {
1534   Initializer() {
1535     signal(SIGPIPE, SIG_IGN);
1536   }
1537 };
1538 Initializer initializer;
1539 } // anonymous