Split out SSL test server for reuse
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19
20 #include <folly/SocketAddress.h>
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/portability/GMock.h>
24 #include <folly/portability/GTest.h>
25 #include <folly/portability/OpenSSL.h>
26 #include <folly/portability/Sockets.h>
27 #include <folly/portability/Unistd.h>
28
29 #include <folly/io/async/test/BlockingSocket.h>
30
31 #include <fcntl.h>
32 #include <folly/io/Cursor.h>
33 #include <openssl/bio.h>
34 #include <sys/types.h>
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 void getfds(int fds[2]) {
59   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
60     FAIL() << "failed to create socketpair: " << strerror(errno);
61   }
62   for (int idx = 0; idx < 2; ++idx) {
63     int flags = fcntl(fds[idx], F_GETFL, 0);
64     if (flags == -1) {
65       FAIL() << "failed to get flags for socket " << idx << ": "
66              << strerror(errno);
67     }
68     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
69       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
70              << strerror(errno);
71     }
72   }
73 }
74
75 void getctx(
76   std::shared_ptr<folly::SSLContext> clientCtx,
77   std::shared_ptr<folly::SSLContext> serverCtx) {
78   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
79
80   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
81   serverCtx->loadCertificate(kTestCert);
82   serverCtx->loadPrivateKey(kTestKey);
83 }
84
85 void sslsocketpair(
86   EventBase* eventBase,
87   AsyncSSLSocket::UniquePtr* clientSock,
88   AsyncSSLSocket::UniquePtr* serverSock) {
89   auto clientCtx = std::make_shared<folly::SSLContext>();
90   auto serverCtx = std::make_shared<folly::SSLContext>();
91   int fds[2];
92   getfds(fds);
93   getctx(clientCtx, serverCtx);
94   clientSock->reset(new AsyncSSLSocket(
95                       clientCtx, eventBase, fds[0], false));
96   serverSock->reset(new AsyncSSLSocket(
97                       serverCtx, eventBase, fds[1], true));
98
99   // (*clientSock)->setSendTimeout(100);
100   // (*serverSock)->setSendTimeout(100);
101 }
102
103 // client protocol filters
104 bool clientProtoFilterPickPony(unsigned char** client,
105   unsigned int* client_len, const unsigned char*, unsigned int ) {
106   //the protocol string in length prefixed byte string. the
107   //length byte is not included in the length
108   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
109   *client = p;
110   *client_len = 7;
111   return true;
112 }
113
114 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
115   const unsigned char*, unsigned int) {
116   return false;
117 }
118
119 std::string getFileAsBuf(const char* fileName) {
120   std::string buffer;
121   folly::readFile(fileName, buffer);
122   return buffer;
123 }
124
125 std::string getCommonName(X509* cert) {
126   X509_NAME* subject = X509_get_subject_name(cert);
127   std::string cn;
128   cn.resize(ub_common_name);
129   X509_NAME_get_text_by_NID(
130       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
131   return cn;
132 }
133
134 /**
135  * Test connecting to, writing to, reading from, and closing the
136  * connection to the SSL server.
137  */
138 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
139   // Start listening on a local port
140   WriteCallbackBase writeCallback;
141   ReadCallback readCallback(&writeCallback);
142   HandshakeCallback handshakeCallback(&readCallback);
143   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
144   TestSSLServer server(&acceptCallback);
145
146   // Set up SSL context.
147   std::shared_ptr<SSLContext> sslContext(new SSLContext());
148   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
149   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
150   //sslContext->authenticate(true, false);
151
152   // connect
153   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
154                                                  sslContext);
155   socket->open();
156
157   // write()
158   uint8_t buf[128];
159   memset(buf, 'a', sizeof(buf));
160   socket->write(buf, sizeof(buf));
161
162   // read()
163   uint8_t readbuf[128];
164   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
165   EXPECT_EQ(bytesRead, 128);
166   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
167
168   // close()
169   socket->close();
170
171   cerr << "ConnectWriteReadClose test completed" << endl;
172 }
173
174 /**
175  * Test reading after server close.
176  */
177 TEST(AsyncSSLSocketTest, ReadAfterClose) {
178   // Start listening on a local port
179   WriteCallbackBase writeCallback;
180   ReadEOFCallback readCallback(&writeCallback);
181   HandshakeCallback handshakeCallback(&readCallback);
182   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
183   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
184
185   // Set up SSL context.
186   auto sslContext = std::make_shared<SSLContext>();
187   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
188
189   auto socket =
190       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
191   socket->open();
192
193   // This should trigger an EOF on the client.
194   auto evb = handshakeCallback.getSocket()->getEventBase();
195   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
196   std::array<uint8_t, 128> readbuf;
197   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
198   EXPECT_EQ(0, bytesRead);
199 }
200
201 /**
202  * Test bad renegotiation
203  */
204 #if !defined(OPENSSL_IS_BORINGSSL)
205 TEST(AsyncSSLSocketTest, Renegotiate) {
206   EventBase eventBase;
207   auto clientCtx = std::make_shared<SSLContext>();
208   auto dfServerCtx = std::make_shared<SSLContext>();
209   std::array<int, 2> fds;
210   getfds(fds.data());
211   getctx(clientCtx, dfServerCtx);
212
213   AsyncSSLSocket::UniquePtr clientSock(
214       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
215   AsyncSSLSocket::UniquePtr serverSock(
216       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
217   SSLHandshakeClient client(std::move(clientSock), true, true);
218   RenegotiatingServer server(std::move(serverSock));
219
220   while (!client.handshakeSuccess_ && !client.handshakeError_) {
221     eventBase.loopOnce();
222   }
223
224   ASSERT_TRUE(client.handshakeSuccess_);
225
226   auto sslSock = std::move(client).moveSocket();
227   sslSock->detachEventBase();
228   // This is nasty, however we don't want to add support for
229   // renegotiation in AsyncSSLSocket.
230   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
231
232   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
233
234   std::thread t([&]() { eventBase.loopForever(); });
235
236   // Trigger the renegotiation.
237   std::array<uint8_t, 128> buf;
238   memset(buf.data(), 'a', buf.size());
239   try {
240     socket->write(buf.data(), buf.size());
241   } catch (AsyncSocketException& e) {
242     LOG(INFO) << "client got error " << e.what();
243   }
244   eventBase.terminateLoopSoon();
245   t.join();
246
247   eventBase.loop();
248   ASSERT_TRUE(server.renegotiationError_);
249 }
250 #endif
251
252 /**
253  * Negative test for handshakeError().
254  */
255 TEST(AsyncSSLSocketTest, HandshakeError) {
256   // Start listening on a local port
257   WriteCallbackBase writeCallback;
258   WriteErrorCallback readCallback(&writeCallback);
259   HandshakeCallback handshakeCallback(&readCallback);
260   HandshakeErrorCallback acceptCallback(&handshakeCallback);
261   TestSSLServer server(&acceptCallback);
262
263   // Set up SSL context.
264   std::shared_ptr<SSLContext> sslContext(new SSLContext());
265   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
266
267   // connect
268   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
269                                                  sslContext);
270   // read()
271   bool ex = false;
272   try {
273     socket->open();
274
275     uint8_t readbuf[128];
276     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
277     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
278   } catch (AsyncSocketException&) {
279     ex = true;
280   }
281   EXPECT_TRUE(ex);
282
283   // close()
284   socket->close();
285   cerr << "HandshakeError test completed" << endl;
286 }
287
288 /**
289  * Negative test for readError().
290  */
291 TEST(AsyncSSLSocketTest, ReadError) {
292   // Start listening on a local port
293   WriteCallbackBase writeCallback;
294   ReadErrorCallback readCallback(&writeCallback);
295   HandshakeCallback handshakeCallback(&readCallback);
296   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
297   TestSSLServer server(&acceptCallback);
298
299   // Set up SSL context.
300   std::shared_ptr<SSLContext> sslContext(new SSLContext());
301   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
302
303   // connect
304   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
305                                                  sslContext);
306   socket->open();
307
308   // write something to trigger ssl handshake
309   uint8_t buf[128];
310   memset(buf, 'a', sizeof(buf));
311   socket->write(buf, sizeof(buf));
312
313   socket->close();
314   cerr << "ReadError test completed" << endl;
315 }
316
317 /**
318  * Negative test for writeError().
319  */
320 TEST(AsyncSSLSocketTest, WriteError) {
321   // Start listening on a local port
322   WriteCallbackBase writeCallback;
323   WriteErrorCallback readCallback(&writeCallback);
324   HandshakeCallback handshakeCallback(&readCallback);
325   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
326   TestSSLServer server(&acceptCallback);
327
328   // Set up SSL context.
329   std::shared_ptr<SSLContext> sslContext(new SSLContext());
330   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
331
332   // connect
333   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
334                                                  sslContext);
335   socket->open();
336
337   // write something to trigger ssl handshake
338   uint8_t buf[128];
339   memset(buf, 'a', sizeof(buf));
340   socket->write(buf, sizeof(buf));
341
342   socket->close();
343   cerr << "WriteError test completed" << endl;
344 }
345
346 /**
347  * Test a socket with TCP_NODELAY unset.
348  */
349 TEST(AsyncSSLSocketTest, SocketWithDelay) {
350   // Start listening on a local port
351   WriteCallbackBase writeCallback;
352   ReadCallback readCallback(&writeCallback);
353   HandshakeCallback handshakeCallback(&readCallback);
354   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
355   TestSSLServer server(&acceptCallback);
356
357   // Set up SSL context.
358   std::shared_ptr<SSLContext> sslContext(new SSLContext());
359   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
360
361   // connect
362   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
363                                                  sslContext);
364   socket->open();
365
366   // write()
367   uint8_t buf[128];
368   memset(buf, 'a', sizeof(buf));
369   socket->write(buf, sizeof(buf));
370
371   // read()
372   uint8_t readbuf[128];
373   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
374   EXPECT_EQ(bytesRead, 128);
375   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
376
377   // close()
378   socket->close();
379
380   cerr << "SocketWithDelay test completed" << endl;
381 }
382
383 using NextProtocolTypePair =
384     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
385
386 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
387   // For matching protos
388  public:
389   void SetUp() override { getctx(clientCtx, serverCtx); }
390
391   void connect(bool unset = false) {
392     getfds(fds);
393
394     if (unset) {
395       // unsetting NPN for any of [client, server] is enough to make NPN not
396       // work
397       clientCtx->unsetNextProtocols();
398     }
399
400     AsyncSSLSocket::UniquePtr clientSock(
401       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
402     AsyncSSLSocket::UniquePtr serverSock(
403       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
404     client = folly::make_unique<NpnClient>(std::move(clientSock));
405     server = folly::make_unique<NpnServer>(std::move(serverSock));
406
407     eventBase.loop();
408   }
409
410   void expectProtocol(const std::string& proto) {
411     EXPECT_NE(client->nextProtoLength, 0);
412     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
413     EXPECT_EQ(
414         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
415         0);
416     string selected((const char*)client->nextProto, client->nextProtoLength);
417     EXPECT_EQ(proto, selected);
418   }
419
420   void expectNoProtocol() {
421     EXPECT_EQ(client->nextProtoLength, 0);
422     EXPECT_EQ(server->nextProtoLength, 0);
423     EXPECT_EQ(client->nextProto, nullptr);
424     EXPECT_EQ(server->nextProto, nullptr);
425   }
426
427   void expectProtocolType() {
428     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
429         GetParam().second == SSLContext::NextProtocolType::ANY) {
430       EXPECT_EQ(client->protocolType, server->protocolType);
431     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
432                GetParam().second == SSLContext::NextProtocolType::ANY) {
433       // Well not much we can say
434     } else {
435       expectProtocolType(GetParam());
436     }
437   }
438
439   void expectProtocolType(NextProtocolTypePair expected) {
440     EXPECT_EQ(client->protocolType, expected.first);
441     EXPECT_EQ(server->protocolType, expected.second);
442   }
443
444   EventBase eventBase;
445   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
446   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
447   int fds[2];
448   std::unique_ptr<NpnClient> client;
449   std::unique_ptr<NpnServer> server;
450 };
451
452 class NextProtocolTLSExtTest : public NextProtocolTest {
453   // For extended TLS protos
454 };
455
456 class NextProtocolNPNOnlyTest : public NextProtocolTest {
457   // For mismatching protos
458 };
459
460 class NextProtocolMismatchTest : public NextProtocolTest {
461   // For mismatching protos
462 };
463
464 TEST_P(NextProtocolTest, NpnTestOverlap) {
465   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
466   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
467                                         GetParam().second);
468
469   connect();
470
471   expectProtocol("baz");
472   expectProtocolType();
473 }
474
475 TEST_P(NextProtocolTest, NpnTestUnset) {
476   // Identical to above test, except that we want unset NPN before
477   // looping.
478   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
479   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
480                                         GetParam().second);
481
482   connect(true /* unset */);
483
484   // if alpn negotiation fails, type will appear as npn
485   expectNoProtocol();
486   EXPECT_EQ(client->protocolType, server->protocolType);
487 }
488
489 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
490   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
491   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
492                                         GetParam().second);
493
494   connect();
495
496   expectNoProtocol();
497   expectProtocolType(
498       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
499 }
500
501 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
502 // will fail on 1.0.2 before that.
503 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
504   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
505   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
506                                         GetParam().second);
507
508   connect();
509
510   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
511       GetParam().second == SSLContext::NextProtocolType::ALPN) {
512     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
513     // mismatch should result in a fatal alert, but this is OpenSSL's current
514     // behavior and we want to know if it changes.
515     expectNoProtocol();
516   }
517 #if defined(OPENSSL_IS_BORINGSSL)
518   // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
519   // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
520   // NPN *and* ALPN
521   else if (
522       GetParam().first == SSLContext::NextProtocolType::ANY &&
523       GetParam().second == SSLContext::NextProtocolType::ANY) {
524     expectNoProtocol();
525   }
526 #endif
527   else {
528     expectProtocol("blub");
529     expectProtocolType(
530         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
531   }
532 }
533
534 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
535   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
536   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
537   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
538                                         GetParam().second);
539
540   connect();
541
542   expectProtocol("ponies");
543   expectProtocolType();
544 }
545
546 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
547   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
548   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
549   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
550                                         GetParam().second);
551
552   connect();
553
554   expectProtocol("blub");
555   expectProtocolType();
556 }
557
558 TEST_P(NextProtocolTest, RandomizedNpnTest) {
559   // Probability that this test will fail is 2^-64, which could be considered
560   // as negligible.
561   const int kTries = 64;
562
563   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
564                                         GetParam().first);
565   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
566                                                   GetParam().second);
567
568   std::set<string> selectedProtocols;
569   for (int i = 0; i < kTries; ++i) {
570     connect();
571
572     EXPECT_NE(client->nextProtoLength, 0);
573     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
574     EXPECT_EQ(
575         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
576         0);
577     string selected((const char*)client->nextProto, client->nextProtoLength);
578     selectedProtocols.insert(selected);
579     expectProtocolType();
580   }
581   EXPECT_EQ(selectedProtocols.size(), 2);
582 }
583
584 INSTANTIATE_TEST_CASE_P(
585     AsyncSSLSocketTest,
586     NextProtocolTest,
587     ::testing::Values(
588         NextProtocolTypePair(
589             SSLContext::NextProtocolType::NPN,
590             SSLContext::NextProtocolType::NPN),
591         NextProtocolTypePair(
592             SSLContext::NextProtocolType::NPN,
593             SSLContext::NextProtocolType::ANY),
594         NextProtocolTypePair(
595             SSLContext::NextProtocolType::ANY,
596             SSLContext::NextProtocolType::ANY)));
597
598 #if FOLLY_OPENSSL_HAS_ALPN
599 INSTANTIATE_TEST_CASE_P(
600     AsyncSSLSocketTest,
601     NextProtocolTLSExtTest,
602     ::testing::Values(
603         NextProtocolTypePair(
604             SSLContext::NextProtocolType::ALPN,
605             SSLContext::NextProtocolType::ALPN),
606         NextProtocolTypePair(
607             SSLContext::NextProtocolType::ALPN,
608             SSLContext::NextProtocolType::ANY),
609         NextProtocolTypePair(
610             SSLContext::NextProtocolType::ANY,
611             SSLContext::NextProtocolType::ALPN)));
612 #endif
613
614 INSTANTIATE_TEST_CASE_P(
615     AsyncSSLSocketTest,
616     NextProtocolNPNOnlyTest,
617     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
618                                            SSLContext::NextProtocolType::NPN)));
619
620 #if FOLLY_OPENSSL_HAS_ALPN
621 INSTANTIATE_TEST_CASE_P(
622     AsyncSSLSocketTest,
623     NextProtocolMismatchTest,
624     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
625                                            SSLContext::NextProtocolType::ALPN),
626                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
627                                            SSLContext::NextProtocolType::NPN)));
628 #endif
629
630 #ifndef OPENSSL_NO_TLSEXT
631 /**
632  * 1. Client sends TLSEXT_HOSTNAME in client hello.
633  * 2. Server found a match SSL_CTX and use this SSL_CTX to
634  *    continue the SSL handshake.
635  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
636  */
637 TEST(AsyncSSLSocketTest, SNITestMatch) {
638   EventBase eventBase;
639   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
640   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
641   // Use the same SSLContext to continue the handshake after
642   // tlsext_hostname match.
643   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
644   const std::string serverName("xyz.newdev.facebook.com");
645   int fds[2];
646   getfds(fds);
647   getctx(clientCtx, dfServerCtx);
648
649   AsyncSSLSocket::UniquePtr clientSock(
650     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
651   AsyncSSLSocket::UniquePtr serverSock(
652     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
653   SNIClient client(std::move(clientSock));
654   SNIServer server(std::move(serverSock),
655                    dfServerCtx,
656                    hskServerCtx,
657                    serverName);
658
659   eventBase.loop();
660
661   EXPECT_TRUE(client.serverNameMatch);
662   EXPECT_TRUE(server.serverNameMatch);
663 }
664
665 /**
666  * 1. Client sends TLSEXT_HOSTNAME in client hello.
667  * 2. Server cannot find a matching SSL_CTX and continue to use
668  *    the current SSL_CTX to do the handshake.
669  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
670  */
671 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
672   EventBase eventBase;
673   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
674   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
675   // Use the same SSLContext to continue the handshake after
676   // tlsext_hostname match.
677   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
678   const std::string clientRequestingServerName("foo.com");
679   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
680
681   int fds[2];
682   getfds(fds);
683   getctx(clientCtx, dfServerCtx);
684
685   AsyncSSLSocket::UniquePtr clientSock(
686     new AsyncSSLSocket(clientCtx,
687                         &eventBase,
688                         fds[0],
689                         clientRequestingServerName));
690   AsyncSSLSocket::UniquePtr serverSock(
691     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
692   SNIClient client(std::move(clientSock));
693   SNIServer server(std::move(serverSock),
694                    dfServerCtx,
695                    hskServerCtx,
696                    serverExpectedServerName);
697
698   eventBase.loop();
699
700   EXPECT_TRUE(!client.serverNameMatch);
701   EXPECT_TRUE(!server.serverNameMatch);
702 }
703 /**
704  * 1. Client sends TLSEXT_HOSTNAME in client hello.
705  * 2. We then change the serverName.
706  * 3. We expect that we get 'false' as the result for serNameMatch.
707  */
708
709 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
710    EventBase eventBase;
711   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
712   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
713   // Use the same SSLContext to continue the handshake after
714   // tlsext_hostname match.
715   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
716   const std::string serverName("xyz.newdev.facebook.com");
717   int fds[2];
718   getfds(fds);
719   getctx(clientCtx, dfServerCtx);
720
721   AsyncSSLSocket::UniquePtr clientSock(
722     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
723   //Change the server name
724   std::string newName("new.com");
725   clientSock->setServerName(newName);
726   AsyncSSLSocket::UniquePtr serverSock(
727     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
728   SNIClient client(std::move(clientSock));
729   SNIServer server(std::move(serverSock),
730                    dfServerCtx,
731                    hskServerCtx,
732                    serverName);
733
734   eventBase.loop();
735
736   EXPECT_TRUE(!client.serverNameMatch);
737 }
738
739 /**
740  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
741  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
742  */
743 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
744   EventBase eventBase;
745   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
746   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
747   // Use the same SSLContext to continue the handshake after
748   // tlsext_hostname match.
749   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
750   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
751
752   int fds[2];
753   getfds(fds);
754   getctx(clientCtx, dfServerCtx);
755
756   AsyncSSLSocket::UniquePtr clientSock(
757     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
758   AsyncSSLSocket::UniquePtr serverSock(
759     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
760   SNIClient client(std::move(clientSock));
761   SNIServer server(std::move(serverSock),
762                    dfServerCtx,
763                    hskServerCtx,
764                    serverExpectedServerName);
765
766   eventBase.loop();
767
768   EXPECT_TRUE(!client.serverNameMatch);
769   EXPECT_TRUE(!server.serverNameMatch);
770 }
771
772 #endif
773 /**
774  * Test SSL client socket
775  */
776 TEST(AsyncSSLSocketTest, SSLClientTest) {
777   // Start listening on a local port
778   WriteCallbackBase writeCallback;
779   ReadCallback readCallback(&writeCallback);
780   HandshakeCallback handshakeCallback(&readCallback);
781   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
782   TestSSLServer server(&acceptCallback);
783
784   // Set up SSL client
785   EventBase eventBase;
786   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
787
788   client->connect();
789   EventBaseAborter eba(&eventBase, 3000);
790   eventBase.loop();
791
792   EXPECT_EQ(client->getMiss(), 1);
793   EXPECT_EQ(client->getHit(), 0);
794
795   cerr << "SSLClientTest test completed" << endl;
796 }
797
798
799 /**
800  * Test SSL client socket session re-use
801  */
802 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
803   // Start listening on a local port
804   WriteCallbackBase writeCallback;
805   ReadCallback readCallback(&writeCallback);
806   HandshakeCallback handshakeCallback(&readCallback);
807   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
808   TestSSLServer server(&acceptCallback);
809
810   // Set up SSL client
811   EventBase eventBase;
812   auto client =
813       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
814
815   client->connect();
816   EventBaseAborter eba(&eventBase, 3000);
817   eventBase.loop();
818
819   EXPECT_EQ(client->getMiss(), 1);
820   EXPECT_EQ(client->getHit(), 9);
821
822   cerr << "SSLClientTestReuse test completed" << endl;
823 }
824
825 /**
826  * Test SSL client socket timeout
827  */
828 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
829   // Start listening on a local port
830   EmptyReadCallback readCallback;
831   HandshakeCallback handshakeCallback(&readCallback,
832                                       HandshakeCallback::EXPECT_ERROR);
833   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
834   TestSSLServer server(&acceptCallback);
835
836   // Set up SSL client
837   EventBase eventBase;
838   auto client =
839       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
840   client->connect(true /* write before connect completes */);
841   EventBaseAborter eba(&eventBase, 3000);
842   eventBase.loop();
843
844   usleep(100000);
845   // This is checking that the connectError callback precedes any queued
846   // writeError callbacks.  This matches AsyncSocket's behavior
847   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
848   EXPECT_EQ(client->getErrors(), 1);
849   EXPECT_EQ(client->getMiss(), 0);
850   EXPECT_EQ(client->getHit(), 0);
851
852   cerr << "SSLClientTimeoutTest test completed" << endl;
853 }
854
855 // The next 3 tests need an FB-only extension, and will fail without it
856 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
857 /**
858  * Test SSL server async cache
859  */
860 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
861   // Start listening on a local port
862   WriteCallbackBase writeCallback;
863   ReadCallback readCallback(&writeCallback);
864   HandshakeCallback handshakeCallback(&readCallback);
865   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
866   TestSSLAsyncCacheServer server(&acceptCallback);
867
868   // Set up SSL client
869   EventBase eventBase;
870   auto client =
871       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
872
873   client->connect();
874   EventBaseAborter eba(&eventBase, 3000);
875   eventBase.loop();
876
877   EXPECT_EQ(server.getAsyncCallbacks(), 18);
878   EXPECT_EQ(server.getAsyncLookups(), 9);
879   EXPECT_EQ(client->getMiss(), 10);
880   EXPECT_EQ(client->getHit(), 0);
881
882   cerr << "SSLServerAsyncCacheTest test completed" << endl;
883 }
884
885 /**
886  * Test SSL server accept timeout with cache path
887  */
888 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
889   // Start listening on a local port
890   WriteCallbackBase writeCallback;
891   ReadCallback readCallback(&writeCallback);
892   HandshakeCallback handshakeCallback(&readCallback);
893   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
894   TestSSLAsyncCacheServer server(&acceptCallback);
895
896   // Set up SSL client
897   EventBase eventBase;
898   // only do a TCP connect
899   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
900   sock->connect(nullptr, server.getAddress());
901
902   EmptyReadCallback clientReadCallback;
903   clientReadCallback.tcpSocket_ = sock;
904   sock->setReadCB(&clientReadCallback);
905
906   EventBaseAborter eba(&eventBase, 3000);
907   eventBase.loop();
908
909   EXPECT_EQ(readCallback.state, STATE_WAITING);
910
911   cerr << "SSLServerTimeoutTest test completed" << endl;
912 }
913
914 /**
915  * Test SSL server accept timeout with cache path
916  */
917 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
918   // Start listening on a local port
919   WriteCallbackBase writeCallback;
920   ReadCallback readCallback(&writeCallback);
921   HandshakeCallback handshakeCallback(&readCallback);
922   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
923   TestSSLAsyncCacheServer server(&acceptCallback);
924
925   // Set up SSL client
926   EventBase eventBase;
927   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
928
929   client->connect();
930   EventBaseAborter eba(&eventBase, 3000);
931   eventBase.loop();
932
933   EXPECT_EQ(server.getAsyncCallbacks(), 1);
934   EXPECT_EQ(server.getAsyncLookups(), 1);
935   EXPECT_EQ(client->getErrors(), 1);
936   EXPECT_EQ(client->getMiss(), 1);
937   EXPECT_EQ(client->getHit(), 0);
938
939   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
940 }
941
942 /**
943  * Test SSL server accept timeout with cache path
944  */
945 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
946   // Start listening on a local port
947   WriteCallbackBase writeCallback;
948   ReadCallback readCallback(&writeCallback);
949   HandshakeCallback handshakeCallback(&readCallback,
950                                       HandshakeCallback::EXPECT_ERROR);
951   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
952   TestSSLAsyncCacheServer server(&acceptCallback, 500);
953
954   // Set up SSL client
955   EventBase eventBase;
956   auto client =
957       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
958
959   client->connect();
960   EventBaseAborter eba(&eventBase, 3000);
961   eventBase.loop();
962
963   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
964       handshakeCallback.closeSocket();});
965   // give time for the cache lookup to come back and find it closed
966   handshakeCallback.waitForHandshake();
967
968   EXPECT_EQ(server.getAsyncCallbacks(), 1);
969   EXPECT_EQ(server.getAsyncLookups(), 1);
970   EXPECT_EQ(client->getErrors(), 1);
971   EXPECT_EQ(client->getMiss(), 1);
972   EXPECT_EQ(client->getHit(), 0);
973
974   cerr << "SSLServerCacheCloseTest test completed" << endl;
975 }
976 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
977
978 /**
979  * Verify Client Ciphers obtained using SSL MSG Callback.
980  */
981 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
982   EventBase eventBase;
983   auto clientCtx = std::make_shared<SSLContext>();
984   auto serverCtx = std::make_shared<SSLContext>();
985   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
986   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
987   serverCtx->loadPrivateKey(kTestKey);
988   serverCtx->loadCertificate(kTestCert);
989   serverCtx->loadTrustedCertificates(kTestCA);
990   serverCtx->loadClientCAList(kTestCA);
991
992   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
993   clientCtx->ciphers("AES256-SHA:AES128-SHA");
994   clientCtx->loadPrivateKey(kTestKey);
995   clientCtx->loadCertificate(kTestCert);
996   clientCtx->loadTrustedCertificates(kTestCA);
997
998   int fds[2];
999   getfds(fds);
1000
1001   AsyncSSLSocket::UniquePtr clientSock(
1002       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1003   AsyncSSLSocket::UniquePtr serverSock(
1004       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1005
1006   SSLHandshakeClient client(std::move(clientSock), true, true);
1007   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1008
1009   eventBase.loop();
1010
1011 #if defined(OPENSSL_IS_BORINGSSL)
1012   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1013 #else
1014   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1015 #endif
1016   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1017   EXPECT_TRUE(client.handshakeVerify_);
1018   EXPECT_TRUE(client.handshakeSuccess_);
1019   EXPECT_TRUE(!client.handshakeError_);
1020   EXPECT_TRUE(server.handshakeVerify_);
1021   EXPECT_TRUE(server.handshakeSuccess_);
1022   EXPECT_TRUE(!server.handshakeError_);
1023 }
1024
1025 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1026   EventBase eventBase;
1027   auto ctx = std::make_shared<SSLContext>();
1028
1029   int fds[2];
1030   getfds(fds);
1031
1032   int bufLen = 42;
1033   uint8_t majorVersion = 18;
1034   uint8_t minorVersion = 25;
1035
1036   // Create callback buf
1037   auto buf = IOBuf::create(bufLen);
1038   buf->append(bufLen);
1039   folly::io::RWPrivateCursor cursor(buf.get());
1040   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1041   cursor.write<uint16_t>(0);
1042   cursor.write<uint8_t>(38);
1043   cursor.write<uint8_t>(majorVersion);
1044   cursor.write<uint8_t>(minorVersion);
1045   cursor.skip(32);
1046   cursor.write<uint32_t>(0);
1047
1048   SSL* ssl = ctx->createSSL();
1049   SCOPE_EXIT { SSL_free(ssl); };
1050   AsyncSSLSocket::UniquePtr sock(
1051       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1052   sock->enableClientHelloParsing();
1053
1054   // Test client hello parsing in one packet
1055   AsyncSSLSocket::clientHelloParsingCallback(
1056       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1057   buf.reset();
1058
1059   auto parsedClientHello = sock->getClientHelloInfo();
1060   EXPECT_TRUE(parsedClientHello != nullptr);
1061   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1062   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1063 }
1064
1065 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1066   EventBase eventBase;
1067   auto ctx = std::make_shared<SSLContext>();
1068
1069   int fds[2];
1070   getfds(fds);
1071
1072   int bufLen = 42;
1073   uint8_t majorVersion = 18;
1074   uint8_t minorVersion = 25;
1075
1076   // Create callback buf
1077   auto buf = IOBuf::create(bufLen);
1078   buf->append(bufLen);
1079   folly::io::RWPrivateCursor cursor(buf.get());
1080   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1081   cursor.write<uint16_t>(0);
1082   cursor.write<uint8_t>(38);
1083   cursor.write<uint8_t>(majorVersion);
1084   cursor.write<uint8_t>(minorVersion);
1085   cursor.skip(32);
1086   cursor.write<uint32_t>(0);
1087
1088   SSL* ssl = ctx->createSSL();
1089   SCOPE_EXIT { SSL_free(ssl); };
1090   AsyncSSLSocket::UniquePtr sock(
1091       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1092   sock->enableClientHelloParsing();
1093
1094   // Test parsing with two packets with first packet size < 3
1095   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1096   AsyncSSLSocket::clientHelloParsingCallback(
1097       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1098       ssl, sock.get());
1099   bufCopy.reset();
1100   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1101   AsyncSSLSocket::clientHelloParsingCallback(
1102       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1103       ssl, sock.get());
1104   bufCopy.reset();
1105
1106   auto parsedClientHello = sock->getClientHelloInfo();
1107   EXPECT_TRUE(parsedClientHello != nullptr);
1108   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1109   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1110 }
1111
1112 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1113   EventBase eventBase;
1114   auto ctx = std::make_shared<SSLContext>();
1115
1116   int fds[2];
1117   getfds(fds);
1118
1119   int bufLen = 42;
1120   uint8_t majorVersion = 18;
1121   uint8_t minorVersion = 25;
1122
1123   // Create callback buf
1124   auto buf = IOBuf::create(bufLen);
1125   buf->append(bufLen);
1126   folly::io::RWPrivateCursor cursor(buf.get());
1127   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1128   cursor.write<uint16_t>(0);
1129   cursor.write<uint8_t>(38);
1130   cursor.write<uint8_t>(majorVersion);
1131   cursor.write<uint8_t>(minorVersion);
1132   cursor.skip(32);
1133   cursor.write<uint32_t>(0);
1134
1135   SSL* ssl = ctx->createSSL();
1136   SCOPE_EXIT { SSL_free(ssl); };
1137   AsyncSSLSocket::UniquePtr sock(
1138       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1139   sock->enableClientHelloParsing();
1140
1141   // Test parsing with multiple small packets
1142   for (uint64_t i = 0; i < buf->length(); i += 3) {
1143     auto bufCopy = folly::IOBuf::copyBuffer(
1144         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1145     AsyncSSLSocket::clientHelloParsingCallback(
1146         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1147         ssl, sock.get());
1148     bufCopy.reset();
1149   }
1150
1151   auto parsedClientHello = sock->getClientHelloInfo();
1152   EXPECT_TRUE(parsedClientHello != nullptr);
1153   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1154   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1155 }
1156
1157 /**
1158  * Verify sucessful behavior of SSL certificate validation.
1159  */
1160 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1161   EventBase eventBase;
1162   auto clientCtx = std::make_shared<SSLContext>();
1163   auto dfServerCtx = std::make_shared<SSLContext>();
1164
1165   int fds[2];
1166   getfds(fds);
1167   getctx(clientCtx, dfServerCtx);
1168
1169   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1170   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1171
1172   AsyncSSLSocket::UniquePtr clientSock(
1173     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1174   AsyncSSLSocket::UniquePtr serverSock(
1175     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1176
1177   SSLHandshakeClient client(std::move(clientSock), true, true);
1178   clientCtx->loadTrustedCertificates(kTestCA);
1179
1180   SSLHandshakeServer server(std::move(serverSock), true, true);
1181
1182   eventBase.loop();
1183
1184   EXPECT_TRUE(client.handshakeVerify_);
1185   EXPECT_TRUE(client.handshakeSuccess_);
1186   EXPECT_TRUE(!client.handshakeError_);
1187   EXPECT_LE(0, client.handshakeTime.count());
1188   EXPECT_TRUE(!server.handshakeVerify_);
1189   EXPECT_TRUE(server.handshakeSuccess_);
1190   EXPECT_TRUE(!server.handshakeError_);
1191   EXPECT_LE(0, server.handshakeTime.count());
1192 }
1193
1194 /**
1195  * Verify that the client's verification callback is able to fail SSL
1196  * connection establishment.
1197  */
1198 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1199   EventBase eventBase;
1200   auto clientCtx = std::make_shared<SSLContext>();
1201   auto dfServerCtx = std::make_shared<SSLContext>();
1202
1203   int fds[2];
1204   getfds(fds);
1205   getctx(clientCtx, dfServerCtx);
1206
1207   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1208   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1209
1210   AsyncSSLSocket::UniquePtr clientSock(
1211     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1212   AsyncSSLSocket::UniquePtr serverSock(
1213     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1214
1215   SSLHandshakeClient client(std::move(clientSock), true, false);
1216   clientCtx->loadTrustedCertificates(kTestCA);
1217
1218   SSLHandshakeServer server(std::move(serverSock), true, true);
1219
1220   eventBase.loop();
1221
1222   EXPECT_TRUE(client.handshakeVerify_);
1223   EXPECT_TRUE(!client.handshakeSuccess_);
1224   EXPECT_TRUE(client.handshakeError_);
1225   EXPECT_LE(0, client.handshakeTime.count());
1226   EXPECT_TRUE(!server.handshakeVerify_);
1227   EXPECT_TRUE(!server.handshakeSuccess_);
1228   EXPECT_TRUE(server.handshakeError_);
1229   EXPECT_LE(0, server.handshakeTime.count());
1230 }
1231
1232 /**
1233  * Verify that the options in SSLContext can be overridden in
1234  * sslConnect/Accept.i.e specifying that no validation should be performed
1235  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1236  * the validation callback.
1237  */
1238 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1239   EventBase eventBase;
1240   auto clientCtx = std::make_shared<SSLContext>();
1241   auto dfServerCtx = std::make_shared<SSLContext>();
1242
1243   int fds[2];
1244   getfds(fds);
1245   getctx(clientCtx, dfServerCtx);
1246
1247   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1248   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1249
1250   AsyncSSLSocket::UniquePtr clientSock(
1251     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1252   AsyncSSLSocket::UniquePtr serverSock(
1253     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1254
1255   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1256   clientCtx->loadTrustedCertificates(kTestCA);
1257
1258   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1259
1260   eventBase.loop();
1261
1262   EXPECT_TRUE(!client.handshakeVerify_);
1263   EXPECT_TRUE(client.handshakeSuccess_);
1264   EXPECT_TRUE(!client.handshakeError_);
1265   EXPECT_LE(0, client.handshakeTime.count());
1266   EXPECT_TRUE(!server.handshakeVerify_);
1267   EXPECT_TRUE(server.handshakeSuccess_);
1268   EXPECT_TRUE(!server.handshakeError_);
1269   EXPECT_LE(0, server.handshakeTime.count());
1270 }
1271
1272 /**
1273  * Verify that the options in SSLContext can be overridden in
1274  * sslConnect/Accept. Enable verification even if context says otherwise.
1275  * Test requireClientCert with client cert
1276  */
1277 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1278   EventBase eventBase;
1279   auto clientCtx = std::make_shared<SSLContext>();
1280   auto serverCtx = std::make_shared<SSLContext>();
1281   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1282   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1283   serverCtx->loadPrivateKey(kTestKey);
1284   serverCtx->loadCertificate(kTestCert);
1285   serverCtx->loadTrustedCertificates(kTestCA);
1286   serverCtx->loadClientCAList(kTestCA);
1287
1288   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1289   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1290   clientCtx->loadPrivateKey(kTestKey);
1291   clientCtx->loadCertificate(kTestCert);
1292   clientCtx->loadTrustedCertificates(kTestCA);
1293
1294   int fds[2];
1295   getfds(fds);
1296
1297   AsyncSSLSocket::UniquePtr clientSock(
1298       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1299   AsyncSSLSocket::UniquePtr serverSock(
1300       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1301
1302   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1303   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1304
1305   eventBase.loop();
1306
1307   EXPECT_TRUE(client.handshakeVerify_);
1308   EXPECT_TRUE(client.handshakeSuccess_);
1309   EXPECT_FALSE(client.handshakeError_);
1310   EXPECT_LE(0, client.handshakeTime.count());
1311   EXPECT_TRUE(server.handshakeVerify_);
1312   EXPECT_TRUE(server.handshakeSuccess_);
1313   EXPECT_FALSE(server.handshakeError_);
1314   EXPECT_LE(0, server.handshakeTime.count());
1315 }
1316
1317 /**
1318  * Verify that the client's verification callback is able to override
1319  * the preverification failure and allow a successful connection.
1320  */
1321 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1322   EventBase eventBase;
1323   auto clientCtx = std::make_shared<SSLContext>();
1324   auto dfServerCtx = std::make_shared<SSLContext>();
1325
1326   int fds[2];
1327   getfds(fds);
1328   getctx(clientCtx, dfServerCtx);
1329
1330   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1331   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1332
1333   AsyncSSLSocket::UniquePtr clientSock(
1334     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1335   AsyncSSLSocket::UniquePtr serverSock(
1336     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1337
1338   SSLHandshakeClient client(std::move(clientSock), false, true);
1339   SSLHandshakeServer server(std::move(serverSock), true, true);
1340
1341   eventBase.loop();
1342
1343   EXPECT_TRUE(client.handshakeVerify_);
1344   EXPECT_TRUE(client.handshakeSuccess_);
1345   EXPECT_TRUE(!client.handshakeError_);
1346   EXPECT_LE(0, client.handshakeTime.count());
1347   EXPECT_TRUE(!server.handshakeVerify_);
1348   EXPECT_TRUE(server.handshakeSuccess_);
1349   EXPECT_TRUE(!server.handshakeError_);
1350   EXPECT_LE(0, server.handshakeTime.count());
1351 }
1352
1353 /**
1354  * Verify that specifying that no validation should be performed allows an
1355  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1356  * callback.
1357  */
1358 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1359   EventBase eventBase;
1360   auto clientCtx = std::make_shared<SSLContext>();
1361   auto dfServerCtx = std::make_shared<SSLContext>();
1362
1363   int fds[2];
1364   getfds(fds);
1365   getctx(clientCtx, dfServerCtx);
1366
1367   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1368   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1369
1370   AsyncSSLSocket::UniquePtr clientSock(
1371     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1372   AsyncSSLSocket::UniquePtr serverSock(
1373     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1374
1375   SSLHandshakeClient client(std::move(clientSock), false, false);
1376   SSLHandshakeServer server(std::move(serverSock), false, false);
1377
1378   eventBase.loop();
1379
1380   EXPECT_TRUE(!client.handshakeVerify_);
1381   EXPECT_TRUE(client.handshakeSuccess_);
1382   EXPECT_TRUE(!client.handshakeError_);
1383   EXPECT_LE(0, client.handshakeTime.count());
1384   EXPECT_TRUE(!server.handshakeVerify_);
1385   EXPECT_TRUE(server.handshakeSuccess_);
1386   EXPECT_TRUE(!server.handshakeError_);
1387   EXPECT_LE(0, server.handshakeTime.count());
1388 }
1389
1390 /**
1391  * Test requireClientCert with client cert
1392  */
1393 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1394   EventBase eventBase;
1395   auto clientCtx = std::make_shared<SSLContext>();
1396   auto serverCtx = std::make_shared<SSLContext>();
1397   serverCtx->setVerificationOption(
1398       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1399   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1400   serverCtx->loadPrivateKey(kTestKey);
1401   serverCtx->loadCertificate(kTestCert);
1402   serverCtx->loadTrustedCertificates(kTestCA);
1403   serverCtx->loadClientCAList(kTestCA);
1404
1405   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1406   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1407   clientCtx->loadPrivateKey(kTestKey);
1408   clientCtx->loadCertificate(kTestCert);
1409   clientCtx->loadTrustedCertificates(kTestCA);
1410
1411   int fds[2];
1412   getfds(fds);
1413
1414   AsyncSSLSocket::UniquePtr clientSock(
1415       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1416   AsyncSSLSocket::UniquePtr serverSock(
1417       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1418
1419   SSLHandshakeClient client(std::move(clientSock), true, true);
1420   SSLHandshakeServer server(std::move(serverSock), true, true);
1421
1422   eventBase.loop();
1423
1424   EXPECT_TRUE(client.handshakeVerify_);
1425   EXPECT_TRUE(client.handshakeSuccess_);
1426   EXPECT_FALSE(client.handshakeError_);
1427   EXPECT_LE(0, client.handshakeTime.count());
1428   EXPECT_TRUE(server.handshakeVerify_);
1429   EXPECT_TRUE(server.handshakeSuccess_);
1430   EXPECT_FALSE(server.handshakeError_);
1431   EXPECT_LE(0, server.handshakeTime.count());
1432 }
1433
1434
1435 /**
1436  * Test requireClientCert with no client cert
1437  */
1438 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1439   EventBase eventBase;
1440   auto clientCtx = std::make_shared<SSLContext>();
1441   auto serverCtx = std::make_shared<SSLContext>();
1442   serverCtx->setVerificationOption(
1443       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1444   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1445   serverCtx->loadPrivateKey(kTestKey);
1446   serverCtx->loadCertificate(kTestCert);
1447   serverCtx->loadTrustedCertificates(kTestCA);
1448   serverCtx->loadClientCAList(kTestCA);
1449   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1450   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1451
1452   int fds[2];
1453   getfds(fds);
1454
1455   AsyncSSLSocket::UniquePtr clientSock(
1456       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1457   AsyncSSLSocket::UniquePtr serverSock(
1458       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1459
1460   SSLHandshakeClient client(std::move(clientSock), false, false);
1461   SSLHandshakeServer server(std::move(serverSock), false, false);
1462
1463   eventBase.loop();
1464
1465   EXPECT_FALSE(server.handshakeVerify_);
1466   EXPECT_FALSE(server.handshakeSuccess_);
1467   EXPECT_TRUE(server.handshakeError_);
1468   EXPECT_LE(0, client.handshakeTime.count());
1469   EXPECT_LE(0, server.handshakeTime.count());
1470 }
1471
1472 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1473   auto cert = getFileAsBuf(kTestCert);
1474   auto key = getFileAsBuf(kTestKey);
1475
1476   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1477   BIO_write(certBio.get(), cert.data(), cert.size());
1478   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1479   BIO_write(keyBio.get(), key.data(), key.size());
1480
1481   // Create SSL structs from buffers to get properties
1482   ssl::X509UniquePtr certStruct(
1483       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1484   ssl::EvpPkeyUniquePtr keyStruct(
1485       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1486   certBio = nullptr;
1487   keyBio = nullptr;
1488
1489   auto origCommonName = getCommonName(certStruct.get());
1490   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1491   certStruct = nullptr;
1492   keyStruct = nullptr;
1493
1494   auto ctx = std::make_shared<SSLContext>();
1495   ctx->loadPrivateKeyFromBufferPEM(key);
1496   ctx->loadCertificateFromBufferPEM(cert);
1497   ctx->loadTrustedCertificates(kTestCA);
1498
1499   ssl::SSLUniquePtr ssl(ctx->createSSL());
1500
1501   auto newCert = SSL_get_certificate(ssl.get());
1502   auto newKey = SSL_get_privatekey(ssl.get());
1503
1504   // Get properties from SSL struct
1505   auto newCommonName = getCommonName(newCert);
1506   auto newKeySize = EVP_PKEY_bits(newKey);
1507
1508   // Check that the key and cert have the expected properties
1509   EXPECT_EQ(origCommonName, newCommonName);
1510   EXPECT_EQ(origKeySize, newKeySize);
1511 }
1512
1513 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1514   EventBase eb;
1515
1516   // Set up SSL context.
1517   auto sslContext = std::make_shared<SSLContext>();
1518   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1519
1520   // create SSL socket
1521   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1522
1523   EXPECT_EQ(1500, socket->getMinWriteSize());
1524
1525   socket->setMinWriteSize(0);
1526   EXPECT_EQ(0, socket->getMinWriteSize());
1527   socket->setMinWriteSize(50000);
1528   EXPECT_EQ(50000, socket->getMinWriteSize());
1529 }
1530
1531 class ReadCallbackTerminator : public ReadCallback {
1532  public:
1533   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1534       : ReadCallback(wcb)
1535       , base_(base) {}
1536
1537   // Do not write data back, terminate the loop.
1538   void readDataAvailable(size_t len) noexcept override {
1539     std::cerr << "readDataAvailable, len " << len << std::endl;
1540
1541     currentBuffer.length = len;
1542
1543     buffers.push_back(currentBuffer);
1544     currentBuffer.reset();
1545     state = STATE_SUCCEEDED;
1546
1547     socket_->setReadCB(nullptr);
1548     base_->terminateLoopSoon();
1549   }
1550  private:
1551   EventBase* base_;
1552 };
1553
1554
1555 /**
1556  * Test a full unencrypted codepath
1557  */
1558 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1559   EventBase base;
1560
1561   auto clientCtx = std::make_shared<folly::SSLContext>();
1562   auto serverCtx = std::make_shared<folly::SSLContext>();
1563   int fds[2];
1564   getfds(fds);
1565   getctx(clientCtx, serverCtx);
1566   auto client = AsyncSSLSocket::newSocket(
1567                   clientCtx, &base, fds[0], false, true);
1568   auto server = AsyncSSLSocket::newSocket(
1569                   serverCtx, &base, fds[1], true, true);
1570
1571   ReadCallbackTerminator readCallback(&base, nullptr);
1572   server->setReadCB(&readCallback);
1573   readCallback.setSocket(server);
1574
1575   uint8_t buf[128];
1576   memset(buf, 'a', sizeof(buf));
1577   client->write(nullptr, buf, sizeof(buf));
1578
1579   // Check that bytes are unencrypted
1580   char c;
1581   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1582   EXPECT_EQ('a', c);
1583
1584   EventBaseAborter eba(&base, 3000);
1585   base.loop();
1586
1587   EXPECT_EQ(1, readCallback.buffers.size());
1588   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1589
1590   server->setReadCB(&readCallback);
1591
1592   // Unencrypted
1593   server->sslAccept(nullptr);
1594   client->sslConn(nullptr);
1595
1596   // Do NOT wait for handshake, writing should be queued and happen after
1597
1598   client->write(nullptr, buf, sizeof(buf));
1599
1600   // Check that bytes are *not* unencrypted
1601   char c2;
1602   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1603   EXPECT_NE('a', c2);
1604
1605
1606   base.loop();
1607
1608   EXPECT_EQ(2, readCallback.buffers.size());
1609   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1610 }
1611
1612 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1613   // Start listening on a local port
1614   WriteCallbackBase writeCallback;
1615   WriteErrorCallback readCallback(&writeCallback);
1616   HandshakeCallback handshakeCallback(&readCallback,
1617                                       HandshakeCallback::EXPECT_ERROR);
1618   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1619   TestSSLServer server(&acceptCallback);
1620
1621   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1622   socket->open();
1623   uint8_t buf[3] = {0x16, 0x03, 0x01};
1624   socket->write(buf, sizeof(buf));
1625   socket->closeWithReset();
1626
1627   handshakeCallback.waitForHandshake();
1628   EXPECT_NE(
1629       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1630   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1631 }
1632
1633 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1634   // Start listening on a local port
1635   WriteCallbackBase writeCallback;
1636   WriteErrorCallback readCallback(&writeCallback);
1637   HandshakeCallback handshakeCallback(&readCallback,
1638                                       HandshakeCallback::EXPECT_ERROR);
1639   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1640   TestSSLServer server(&acceptCallback);
1641
1642   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1643   socket->open();
1644   uint8_t buf[3] = {0x16, 0x03, 0x01};
1645   socket->write(buf, sizeof(buf));
1646   socket->close();
1647
1648   handshakeCallback.waitForHandshake();
1649   EXPECT_NE(
1650       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1651   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1652 }
1653
1654 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1655   // Start listening on a local port
1656   WriteCallbackBase writeCallback;
1657   WriteErrorCallback readCallback(&writeCallback);
1658   HandshakeCallback handshakeCallback(&readCallback,
1659                                       HandshakeCallback::EXPECT_ERROR);
1660   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1661   TestSSLServer server(&acceptCallback);
1662
1663   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1664   socket->open();
1665   uint8_t buf[256] = {0x16, 0x03};
1666   memset(buf + 2, 'a', sizeof(buf) - 2);
1667   socket->write(buf, sizeof(buf));
1668   socket->close();
1669
1670   handshakeCallback.waitForHandshake();
1671   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1672             std::string::npos);
1673 #if defined(OPENSSL_IS_BORINGSSL)
1674   EXPECT_NE(
1675       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1676       std::string::npos);
1677 #else
1678   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1679             std::string::npos);
1680 #endif
1681 }
1682
1683 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1684   using folly::ssl::OpenSSLUtils;
1685   EXPECT_EQ(
1686       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1687   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1688   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1689   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1690   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1691 }
1692
1693 #if FOLLY_ALLOW_TFO
1694
1695 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1696  public:
1697   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1698
1699   explicit MockAsyncTFOSSLSocket(
1700       std::shared_ptr<folly::SSLContext> sslCtx,
1701       EventBase* evb)
1702       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1703
1704   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1705 };
1706
1707 /**
1708  * Test connecting to, writing to, reading from, and closing the
1709  * connection to the SSL server with TFO.
1710  */
1711 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1712   // Start listening on a local port
1713   WriteCallbackBase writeCallback;
1714   ReadCallback readCallback(&writeCallback);
1715   HandshakeCallback handshakeCallback(&readCallback);
1716   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1717   TestSSLServer server(&acceptCallback, true);
1718
1719   // Set up SSL context.
1720   auto sslContext = std::make_shared<SSLContext>();
1721
1722   // connect
1723   auto socket =
1724       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1725   socket->enableTFO();
1726   socket->open();
1727
1728   // write()
1729   std::array<uint8_t, 128> buf;
1730   memset(buf.data(), 'a', buf.size());
1731   socket->write(buf.data(), buf.size());
1732
1733   // read()
1734   std::array<uint8_t, 128> readbuf;
1735   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1736   EXPECT_EQ(bytesRead, 128);
1737   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1738
1739   // close()
1740   socket->close();
1741 }
1742
1743 /**
1744  * Test connecting to, writing to, reading from, and closing the
1745  * connection to the SSL server with TFO.
1746  */
1747 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1748   // Start listening on a local port
1749   WriteCallbackBase writeCallback;
1750   ReadCallback readCallback(&writeCallback);
1751   HandshakeCallback handshakeCallback(&readCallback);
1752   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1753   TestSSLServer server(&acceptCallback, false);
1754
1755   // Set up SSL context.
1756   auto sslContext = std::make_shared<SSLContext>();
1757
1758   // connect
1759   auto socket =
1760       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1761   socket->enableTFO();
1762   socket->open();
1763
1764   // write()
1765   std::array<uint8_t, 128> buf;
1766   memset(buf.data(), 'a', buf.size());
1767   socket->write(buf.data(), buf.size());
1768
1769   // read()
1770   std::array<uint8_t, 128> readbuf;
1771   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1772   EXPECT_EQ(bytesRead, 128);
1773   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1774
1775   // close()
1776   socket->close();
1777 }
1778
1779 class ConnCallback : public AsyncSocket::ConnectCallback {
1780  public:
1781   virtual void connectSuccess() noexcept override {
1782     state = State::SUCCESS;
1783   }
1784
1785   virtual void connectErr(const AsyncSocketException& ex) noexcept override {
1786     state = State::ERROR;
1787     error = ex.what();
1788   }
1789
1790   enum class State { WAITING, SUCCESS, ERROR };
1791
1792   State state{State::WAITING};
1793   std::string error;
1794 };
1795
1796 template <class Cardinality>
1797 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1798     EventBase* evb,
1799     const SocketAddress& address,
1800     Cardinality cardinality) {
1801   // Set up SSL context.
1802   auto sslContext = std::make_shared<SSLContext>();
1803
1804   // connect
1805   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1806       new MockAsyncTFOSSLSocket(sslContext, evb));
1807   socket->enableTFO();
1808
1809   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1810       .Times(cardinality)
1811       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1812         sockaddr_storage addr;
1813         auto len = address.getAddress(&addr);
1814         return connect(fd, (const struct sockaddr*)&addr, len);
1815       }));
1816   return socket;
1817 }
1818
1819 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1820   // Start listening on a local port
1821   WriteCallbackBase writeCallback;
1822   ReadCallback readCallback(&writeCallback);
1823   HandshakeCallback handshakeCallback(&readCallback);
1824   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1825   TestSSLServer server(&acceptCallback, true);
1826
1827   EventBase evb;
1828
1829   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1830   ConnCallback ccb;
1831   socket->connect(&ccb, server.getAddress(), 30);
1832
1833   evb.loop();
1834   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1835
1836   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1837   evb.loop();
1838
1839   BlockingSocket sock(std::move(socket));
1840   // write()
1841   std::array<uint8_t, 128> buf;
1842   memset(buf.data(), 'a', buf.size());
1843   sock.write(buf.data(), buf.size());
1844
1845   // read()
1846   std::array<uint8_t, 128> readbuf;
1847   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1848   EXPECT_EQ(bytesRead, 128);
1849   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1850
1851   // close()
1852   sock.close();
1853 }
1854
1855 #if !defined(OPENSSL_IS_BORINGSSL)
1856 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1857   // Start listening on a local port
1858   ConnectTimeoutCallback acceptCallback;
1859   TestSSLServer server(&acceptCallback, true);
1860
1861   // Set up SSL context.
1862   auto sslContext = std::make_shared<SSLContext>();
1863
1864   // connect
1865   auto socket =
1866       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1867   socket->enableTFO();
1868   EXPECT_THROW(
1869       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1870 }
1871 #endif
1872
1873 #if !defined(OPENSSL_IS_BORINGSSL)
1874 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1875   // Start listening on a local port
1876   ConnectTimeoutCallback acceptCallback;
1877   TestSSLServer server(&acceptCallback, true);
1878
1879   EventBase evb;
1880
1881   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1882   ConnCallback ccb;
1883   // Set a short timeout
1884   socket->connect(&ccb, server.getAddress(), 1);
1885
1886   evb.loop();
1887   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1888 }
1889 #endif
1890
1891 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1892   // Start listening on a local port
1893   EmptyReadCallback readCallback;
1894   HandshakeCallback handshakeCallback(
1895       &readCallback, HandshakeCallback::EXPECT_ERROR);
1896   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1897   TestSSLServer server(&acceptCallback, true);
1898
1899   EventBase evb;
1900
1901   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1902   ConnCallback ccb;
1903   socket->connect(&ccb, server.getAddress(), 100);
1904
1905   evb.loop();
1906   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1907   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1908 }
1909
1910 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1911   // Start listening on a local port
1912   EventBase evb;
1913
1914   // Hopefully nothing is listening on this address
1915   SocketAddress addr("127.0.0.1", 65535);
1916   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1917   ConnCallback ccb;
1918   socket->connect(&ccb, addr, 100);
1919
1920   evb.loop();
1921   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1922   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1923 }
1924
1925 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
1926   EventBase clientEventBase;
1927   EventBase serverEventBase;
1928   auto clientCtx = std::make_shared<SSLContext>();
1929   auto dfServerCtx = std::make_shared<SSLContext>();
1930   std::array<int, 2> fds;
1931   getfds(fds.data());
1932   getctx(clientCtx, dfServerCtx);
1933
1934   AsyncSSLSocket::UniquePtr clientSockPtr(
1935       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
1936   AsyncSSLSocket::UniquePtr serverSockPtr(
1937       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
1938   auto clientSock = clientSockPtr.get();
1939   auto serverSock = serverSockPtr.get();
1940   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
1941
1942   // Steal some data from the server.
1943   clientEventBase.loopOnce();
1944   std::array<uint8_t, 10> buf;
1945   recv(fds[1], buf.data(), buf.size(), 0);
1946
1947   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
1948   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
1949   while (!client.handshakeSuccess_ && !client.handshakeError_) {
1950     serverEventBase.loopOnce();
1951     clientEventBase.loopOnce();
1952   }
1953
1954   EXPECT_TRUE(client.handshakeSuccess_);
1955   EXPECT_TRUE(server.handshakeSuccess_);
1956   EXPECT_EQ(
1957       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
1958 }
1959
1960 #endif
1961
1962 } // namespace
1963
1964 #ifdef SIGPIPE
1965 ///////////////////////////////////////////////////////////////////////////
1966 // init_unit_test_suite
1967 ///////////////////////////////////////////////////////////////////////////
1968 namespace {
1969 struct Initializer {
1970   Initializer() {
1971     signal(SIGPIPE, SIG_IGN);
1972   }
1973 };
1974 Initializer initializer;
1975 } // anonymous
1976 #endif