Disable SSL socket cache tests if cache isn't available
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/Unistd.h>
26
27 #include <folly/io/async/test/BlockingSocket.h>
28
29 #include <fcntl.h>
30 #include <folly/io/Cursor.h>
31 #include <gtest/gtest.h>
32 #include <openssl/bio.h>
33 #include <sys/types.h>
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38 #include <thread>
39
40 #include <gmock/gmock.h>
41
42 using std::string;
43 using std::vector;
44 using std::min;
45 using std::cerr;
46 using std::endl;
47 using std::list;
48
49 using namespace testing;
50
51 namespace folly {
52 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
53 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
54 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
55
56 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
57 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
58 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
59
60 constexpr size_t SSLClient::kMaxReadBufferSz;
61 constexpr size_t SSLClient::kMaxReadsPerEvent;
62
63 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
64     : ctx_(new folly::SSLContext),
65       acb_(acb),
66       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
67   // Set up the SSL context
68   ctx_->loadCertificate(testCert);
69   ctx_->loadPrivateKey(testKey);
70   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
71
72   acb_->ctx_ = ctx_;
73   acb_->base_ = &evb_;
74
75   // Enable TFO
76   if (enableTFO) {
77     LOG(INFO) << "server TFO enabled";
78     socket_->setTFOEnabled(true, 1000);
79   }
80
81   // set up the listening socket
82   socket_->bind(0);
83   socket_->getAddress(&address_);
84   socket_->listen(100);
85   socket_->addAcceptCallback(acb_, &evb_);
86   socket_->startAccepting();
87
88   int ret = pthread_create(&thread_, nullptr, Main, this);
89   assert(ret == 0);
90   (void)ret;
91
92   std::cerr << "Accepting connections on " << address_ << std::endl;
93 }
94
95 void getfds(int fds[2]) {
96   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
97     FAIL() << "failed to create socketpair: " << strerror(errno);
98   }
99   for (int idx = 0; idx < 2; ++idx) {
100     int flags = fcntl(fds[idx], F_GETFL, 0);
101     if (flags == -1) {
102       FAIL() << "failed to get flags for socket " << idx << ": "
103              << strerror(errno);
104     }
105     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
106       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
107              << strerror(errno);
108     }
109   }
110 }
111
112 void getctx(
113   std::shared_ptr<folly::SSLContext> clientCtx,
114   std::shared_ptr<folly::SSLContext> serverCtx) {
115   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
116
117   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
118   serverCtx->loadCertificate(
119       testCert);
120   serverCtx->loadPrivateKey(
121       testKey);
122 }
123
124 void sslsocketpair(
125   EventBase* eventBase,
126   AsyncSSLSocket::UniquePtr* clientSock,
127   AsyncSSLSocket::UniquePtr* serverSock) {
128   auto clientCtx = std::make_shared<folly::SSLContext>();
129   auto serverCtx = std::make_shared<folly::SSLContext>();
130   int fds[2];
131   getfds(fds);
132   getctx(clientCtx, serverCtx);
133   clientSock->reset(new AsyncSSLSocket(
134                       clientCtx, eventBase, fds[0], false));
135   serverSock->reset(new AsyncSSLSocket(
136                       serverCtx, eventBase, fds[1], true));
137
138   // (*clientSock)->setSendTimeout(100);
139   // (*serverSock)->setSendTimeout(100);
140 }
141
142 // client protocol filters
143 bool clientProtoFilterPickPony(unsigned char** client,
144   unsigned int* client_len, const unsigned char*, unsigned int ) {
145   //the protocol string in length prefixed byte string. the
146   //length byte is not included in the length
147   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
148   *client = p;
149   *client_len = 7;
150   return true;
151 }
152
153 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
154   const unsigned char*, unsigned int) {
155   return false;
156 }
157
158 std::string getFileAsBuf(const char* fileName) {
159   std::string buffer;
160   folly::readFile(fileName, buffer);
161   return buffer;
162 }
163
164 std::string getCommonName(X509* cert) {
165   X509_NAME* subject = X509_get_subject_name(cert);
166   std::string cn;
167   cn.resize(ub_common_name);
168   X509_NAME_get_text_by_NID(
169       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
170   return cn;
171 }
172
173 /**
174  * Test connecting to, writing to, reading from, and closing the
175  * connection to the SSL server.
176  */
177 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
178   // Start listening on a local port
179   WriteCallbackBase writeCallback;
180   ReadCallback readCallback(&writeCallback);
181   HandshakeCallback handshakeCallback(&readCallback);
182   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
183   TestSSLServer server(&acceptCallback);
184
185   // Set up SSL context.
186   std::shared_ptr<SSLContext> sslContext(new SSLContext());
187   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
188   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
189   //sslContext->authenticate(true, false);
190
191   // connect
192   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
193                                                  sslContext);
194   socket->open();
195
196   // write()
197   uint8_t buf[128];
198   memset(buf, 'a', sizeof(buf));
199   socket->write(buf, sizeof(buf));
200
201   // read()
202   uint8_t readbuf[128];
203   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
204   EXPECT_EQ(bytesRead, 128);
205   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
206
207   // close()
208   socket->close();
209
210   cerr << "ConnectWriteReadClose test completed" << endl;
211 }
212
213 /**
214  * Test reading after server close.
215  */
216 TEST(AsyncSSLSocketTest, ReadAfterClose) {
217   // Start listening on a local port
218   WriteCallbackBase writeCallback;
219   ReadEOFCallback readCallback(&writeCallback);
220   HandshakeCallback handshakeCallback(&readCallback);
221   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
222   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
223
224   // Set up SSL context.
225   auto sslContext = std::make_shared<SSLContext>();
226   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
227
228   auto socket =
229       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
230   socket->open();
231
232   // This should trigger an EOF on the client.
233   auto evb = handshakeCallback.getSocket()->getEventBase();
234   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
235   std::array<uint8_t, 128> readbuf;
236   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
237   EXPECT_EQ(0, bytesRead);
238 }
239
240 /**
241  * Test bad renegotiation
242  */
243 TEST(AsyncSSLSocketTest, Renegotiate) {
244   EventBase eventBase;
245   auto clientCtx = std::make_shared<SSLContext>();
246   auto dfServerCtx = std::make_shared<SSLContext>();
247   std::array<int, 2> fds;
248   getfds(fds.data());
249   getctx(clientCtx, dfServerCtx);
250
251   AsyncSSLSocket::UniquePtr clientSock(
252       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
253   AsyncSSLSocket::UniquePtr serverSock(
254       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
255   SSLHandshakeClient client(std::move(clientSock), true, true);
256   RenegotiatingServer server(std::move(serverSock));
257
258   while (!client.handshakeSuccess_ && !client.handshakeError_) {
259     eventBase.loopOnce();
260   }
261
262   ASSERT_TRUE(client.handshakeSuccess_);
263
264   auto sslSock = std::move(client).moveSocket();
265   sslSock->detachEventBase();
266   // This is nasty, however we don't want to add support for
267   // renegotiation in AsyncSSLSocket.
268   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
269
270   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
271
272   std::thread t([&]() { eventBase.loopForever(); });
273
274   // Trigger the renegotiation.
275   std::array<uint8_t, 128> buf;
276   memset(buf.data(), 'a', buf.size());
277   try {
278     socket->write(buf.data(), buf.size());
279   } catch (AsyncSocketException& e) {
280     LOG(INFO) << "client got error " << e.what();
281   }
282   eventBase.terminateLoopSoon();
283   t.join();
284
285   eventBase.loop();
286   ASSERT_TRUE(server.renegotiationError_);
287 }
288
289 /**
290  * Negative test for handshakeError().
291  */
292 TEST(AsyncSSLSocketTest, HandshakeError) {
293   // Start listening on a local port
294   WriteCallbackBase writeCallback;
295   WriteErrorCallback readCallback(&writeCallback);
296   HandshakeCallback handshakeCallback(&readCallback);
297   HandshakeErrorCallback acceptCallback(&handshakeCallback);
298   TestSSLServer server(&acceptCallback);
299
300   // Set up SSL context.
301   std::shared_ptr<SSLContext> sslContext(new SSLContext());
302   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
303
304   // connect
305   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
306                                                  sslContext);
307   // read()
308   bool ex = false;
309   try {
310     socket->open();
311
312     uint8_t readbuf[128];
313     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
314     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
315   } catch (AsyncSocketException &e) {
316     ex = true;
317   }
318   EXPECT_TRUE(ex);
319
320   // close()
321   socket->close();
322   cerr << "HandshakeError test completed" << endl;
323 }
324
325 /**
326  * Negative test for readError().
327  */
328 TEST(AsyncSSLSocketTest, ReadError) {
329   // Start listening on a local port
330   WriteCallbackBase writeCallback;
331   ReadErrorCallback readCallback(&writeCallback);
332   HandshakeCallback handshakeCallback(&readCallback);
333   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
334   TestSSLServer server(&acceptCallback);
335
336   // Set up SSL context.
337   std::shared_ptr<SSLContext> sslContext(new SSLContext());
338   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
339
340   // connect
341   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
342                                                  sslContext);
343   socket->open();
344
345   // write something to trigger ssl handshake
346   uint8_t buf[128];
347   memset(buf, 'a', sizeof(buf));
348   socket->write(buf, sizeof(buf));
349
350   socket->close();
351   cerr << "ReadError test completed" << endl;
352 }
353
354 /**
355  * Negative test for writeError().
356  */
357 TEST(AsyncSSLSocketTest, WriteError) {
358   // Start listening on a local port
359   WriteCallbackBase writeCallback;
360   WriteErrorCallback readCallback(&writeCallback);
361   HandshakeCallback handshakeCallback(&readCallback);
362   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
363   TestSSLServer server(&acceptCallback);
364
365   // Set up SSL context.
366   std::shared_ptr<SSLContext> sslContext(new SSLContext());
367   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
368
369   // connect
370   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
371                                                  sslContext);
372   socket->open();
373
374   // write something to trigger ssl handshake
375   uint8_t buf[128];
376   memset(buf, 'a', sizeof(buf));
377   socket->write(buf, sizeof(buf));
378
379   socket->close();
380   cerr << "WriteError test completed" << endl;
381 }
382
383 /**
384  * Test a socket with TCP_NODELAY unset.
385  */
386 TEST(AsyncSSLSocketTest, SocketWithDelay) {
387   // Start listening on a local port
388   WriteCallbackBase writeCallback;
389   ReadCallback readCallback(&writeCallback);
390   HandshakeCallback handshakeCallback(&readCallback);
391   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
392   TestSSLServer server(&acceptCallback);
393
394   // Set up SSL context.
395   std::shared_ptr<SSLContext> sslContext(new SSLContext());
396   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
397
398   // connect
399   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
400                                                  sslContext);
401   socket->open();
402
403   // write()
404   uint8_t buf[128];
405   memset(buf, 'a', sizeof(buf));
406   socket->write(buf, sizeof(buf));
407
408   // read()
409   uint8_t readbuf[128];
410   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
411   EXPECT_EQ(bytesRead, 128);
412   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
413
414   // close()
415   socket->close();
416
417   cerr << "SocketWithDelay test completed" << endl;
418 }
419
420 using NextProtocolTypePair =
421     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
422
423 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
424   // For matching protos
425  public:
426   void SetUp() override { getctx(clientCtx, serverCtx); }
427
428   void connect(bool unset = false) {
429     getfds(fds);
430
431     if (unset) {
432       // unsetting NPN for any of [client, server] is enough to make NPN not
433       // work
434       clientCtx->unsetNextProtocols();
435     }
436
437     AsyncSSLSocket::UniquePtr clientSock(
438       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
439     AsyncSSLSocket::UniquePtr serverSock(
440       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
441     client = folly::make_unique<NpnClient>(std::move(clientSock));
442     server = folly::make_unique<NpnServer>(std::move(serverSock));
443
444     eventBase.loop();
445   }
446
447   void expectProtocol(const std::string& proto) {
448     EXPECT_NE(client->nextProtoLength, 0);
449     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
450     EXPECT_EQ(
451         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
452         0);
453     string selected((const char*)client->nextProto, client->nextProtoLength);
454     EXPECT_EQ(proto, selected);
455   }
456
457   void expectNoProtocol() {
458     EXPECT_EQ(client->nextProtoLength, 0);
459     EXPECT_EQ(server->nextProtoLength, 0);
460     EXPECT_EQ(client->nextProto, nullptr);
461     EXPECT_EQ(server->nextProto, nullptr);
462   }
463
464   void expectProtocolType() {
465     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
466         GetParam().second == SSLContext::NextProtocolType::ANY) {
467       EXPECT_EQ(client->protocolType, server->protocolType);
468     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
469                GetParam().second == SSLContext::NextProtocolType::ANY) {
470       // Well not much we can say
471     } else {
472       expectProtocolType(GetParam());
473     }
474   }
475
476   void expectProtocolType(NextProtocolTypePair expected) {
477     EXPECT_EQ(client->protocolType, expected.first);
478     EXPECT_EQ(server->protocolType, expected.second);
479   }
480
481   EventBase eventBase;
482   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
483   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
484   int fds[2];
485   std::unique_ptr<NpnClient> client;
486   std::unique_ptr<NpnServer> server;
487 };
488
489 class NextProtocolTLSExtTest : public NextProtocolTest {
490   // For extended TLS protos
491 };
492
493 class NextProtocolNPNOnlyTest : public NextProtocolTest {
494   // For mismatching protos
495 };
496
497 class NextProtocolMismatchTest : public NextProtocolTest {
498   // For mismatching protos
499 };
500
501 TEST_P(NextProtocolTest, NpnTestOverlap) {
502   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
503   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
504                                         GetParam().second);
505
506   connect();
507
508   expectProtocol("baz");
509   expectProtocolType();
510 }
511
512 TEST_P(NextProtocolTest, NpnTestUnset) {
513   // Identical to above test, except that we want unset NPN before
514   // looping.
515   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
516   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
517                                         GetParam().second);
518
519   connect(true /* unset */);
520
521   // if alpn negotiation fails, type will appear as npn
522   expectNoProtocol();
523   EXPECT_EQ(client->protocolType, server->protocolType);
524 }
525
526 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
527   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
528   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
529                                         GetParam().second);
530
531   connect();
532
533   expectNoProtocol();
534   expectProtocolType(
535       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
536 }
537
538 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
539 // will fail on 1.0.2 before that.
540 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
541   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
542   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
543                                         GetParam().second);
544
545   connect();
546
547   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
548       GetParam().second == SSLContext::NextProtocolType::ALPN) {
549     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
550     // mismatch should result in a fatal alert, but this is OpenSSL's current
551     // behavior and we want to know if it changes.
552     expectNoProtocol();
553   } else {
554     expectProtocol("blub");
555     expectProtocolType(
556         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
557   }
558 }
559
560 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
561   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
562   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
563   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
564                                         GetParam().second);
565
566   connect();
567
568   expectProtocol("ponies");
569   expectProtocolType();
570 }
571
572 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
573   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
574   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
575   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
576                                         GetParam().second);
577
578   connect();
579
580   expectProtocol("blub");
581   expectProtocolType();
582 }
583
584 TEST_P(NextProtocolTest, RandomizedNpnTest) {
585   // Probability that this test will fail is 2^-64, which could be considered
586   // as negligible.
587   const int kTries = 64;
588
589   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
590                                         GetParam().first);
591   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
592                                                   GetParam().second);
593
594   std::set<string> selectedProtocols;
595   for (int i = 0; i < kTries; ++i) {
596     connect();
597
598     EXPECT_NE(client->nextProtoLength, 0);
599     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
600     EXPECT_EQ(
601         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
602         0);
603     string selected((const char*)client->nextProto, client->nextProtoLength);
604     selectedProtocols.insert(selected);
605     expectProtocolType();
606   }
607   EXPECT_EQ(selectedProtocols.size(), 2);
608 }
609
610 INSTANTIATE_TEST_CASE_P(
611     AsyncSSLSocketTest,
612     NextProtocolTest,
613     ::testing::Values(
614         NextProtocolTypePair(
615             SSLContext::NextProtocolType::NPN,
616             SSLContext::NextProtocolType::NPN),
617         NextProtocolTypePair(
618             SSLContext::NextProtocolType::NPN,
619             SSLContext::NextProtocolType::ANY),
620         NextProtocolTypePair(
621             SSLContext::NextProtocolType::ANY,
622             SSLContext::NextProtocolType::ANY)));
623
624 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
625 INSTANTIATE_TEST_CASE_P(
626     AsyncSSLSocketTest,
627     NextProtocolTLSExtTest,
628     ::testing::Values(
629         NextProtocolTypePair(
630             SSLContext::NextProtocolType::ALPN,
631             SSLContext::NextProtocolType::ALPN),
632         NextProtocolTypePair(
633             SSLContext::NextProtocolType::ALPN,
634             SSLContext::NextProtocolType::ANY),
635         NextProtocolTypePair(
636             SSLContext::NextProtocolType::ANY,
637             SSLContext::NextProtocolType::ALPN)));
638 #endif
639
640 INSTANTIATE_TEST_CASE_P(
641     AsyncSSLSocketTest,
642     NextProtocolNPNOnlyTest,
643     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
644                                            SSLContext::NextProtocolType::NPN)));
645
646 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
647 INSTANTIATE_TEST_CASE_P(
648     AsyncSSLSocketTest,
649     NextProtocolMismatchTest,
650     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
651                                            SSLContext::NextProtocolType::ALPN),
652                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
653                                            SSLContext::NextProtocolType::NPN)));
654 #endif
655
656 #ifndef OPENSSL_NO_TLSEXT
657 /**
658  * 1. Client sends TLSEXT_HOSTNAME in client hello.
659  * 2. Server found a match SSL_CTX and use this SSL_CTX to
660  *    continue the SSL handshake.
661  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
662  */
663 TEST(AsyncSSLSocketTest, SNITestMatch) {
664   EventBase eventBase;
665   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
666   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
667   // Use the same SSLContext to continue the handshake after
668   // tlsext_hostname match.
669   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
670   const std::string serverName("xyz.newdev.facebook.com");
671   int fds[2];
672   getfds(fds);
673   getctx(clientCtx, dfServerCtx);
674
675   AsyncSSLSocket::UniquePtr clientSock(
676     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
677   AsyncSSLSocket::UniquePtr serverSock(
678     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
679   SNIClient client(std::move(clientSock));
680   SNIServer server(std::move(serverSock),
681                    dfServerCtx,
682                    hskServerCtx,
683                    serverName);
684
685   eventBase.loop();
686
687   EXPECT_TRUE(client.serverNameMatch);
688   EXPECT_TRUE(server.serverNameMatch);
689 }
690
691 /**
692  * 1. Client sends TLSEXT_HOSTNAME in client hello.
693  * 2. Server cannot find a matching SSL_CTX and continue to use
694  *    the current SSL_CTX to do the handshake.
695  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
696  */
697 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
698   EventBase eventBase;
699   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
700   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
701   // Use the same SSLContext to continue the handshake after
702   // tlsext_hostname match.
703   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
704   const std::string clientRequestingServerName("foo.com");
705   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
706
707   int fds[2];
708   getfds(fds);
709   getctx(clientCtx, dfServerCtx);
710
711   AsyncSSLSocket::UniquePtr clientSock(
712     new AsyncSSLSocket(clientCtx,
713                         &eventBase,
714                         fds[0],
715                         clientRequestingServerName));
716   AsyncSSLSocket::UniquePtr serverSock(
717     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
718   SNIClient client(std::move(clientSock));
719   SNIServer server(std::move(serverSock),
720                    dfServerCtx,
721                    hskServerCtx,
722                    serverExpectedServerName);
723
724   eventBase.loop();
725
726   EXPECT_TRUE(!client.serverNameMatch);
727   EXPECT_TRUE(!server.serverNameMatch);
728 }
729 /**
730  * 1. Client sends TLSEXT_HOSTNAME in client hello.
731  * 2. We then change the serverName.
732  * 3. We expect that we get 'false' as the result for serNameMatch.
733  */
734
735 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
736    EventBase eventBase;
737   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
738   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
739   // Use the same SSLContext to continue the handshake after
740   // tlsext_hostname match.
741   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
742   const std::string serverName("xyz.newdev.facebook.com");
743   int fds[2];
744   getfds(fds);
745   getctx(clientCtx, dfServerCtx);
746
747   AsyncSSLSocket::UniquePtr clientSock(
748     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
749   //Change the server name
750   std::string newName("new.com");
751   clientSock->setServerName(newName);
752   AsyncSSLSocket::UniquePtr serverSock(
753     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
754   SNIClient client(std::move(clientSock));
755   SNIServer server(std::move(serverSock),
756                    dfServerCtx,
757                    hskServerCtx,
758                    serverName);
759
760   eventBase.loop();
761
762   EXPECT_TRUE(!client.serverNameMatch);
763 }
764
765 /**
766  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
767  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
768  */
769 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
770   EventBase eventBase;
771   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
772   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
773   // Use the same SSLContext to continue the handshake after
774   // tlsext_hostname match.
775   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
776   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
777
778   int fds[2];
779   getfds(fds);
780   getctx(clientCtx, dfServerCtx);
781
782   AsyncSSLSocket::UniquePtr clientSock(
783     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
784   AsyncSSLSocket::UniquePtr serverSock(
785     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
786   SNIClient client(std::move(clientSock));
787   SNIServer server(std::move(serverSock),
788                    dfServerCtx,
789                    hskServerCtx,
790                    serverExpectedServerName);
791
792   eventBase.loop();
793
794   EXPECT_TRUE(!client.serverNameMatch);
795   EXPECT_TRUE(!server.serverNameMatch);
796 }
797
798 #endif
799 /**
800  * Test SSL client socket
801  */
802 TEST(AsyncSSLSocketTest, SSLClientTest) {
803   // Start listening on a local port
804   WriteCallbackBase writeCallback;
805   ReadCallback readCallback(&writeCallback);
806   HandshakeCallback handshakeCallback(&readCallback);
807   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
808   TestSSLServer server(&acceptCallback);
809
810   // Set up SSL client
811   EventBase eventBase;
812   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
813
814   client->connect();
815   EventBaseAborter eba(&eventBase, 3000);
816   eventBase.loop();
817
818   EXPECT_EQ(client->getMiss(), 1);
819   EXPECT_EQ(client->getHit(), 0);
820
821   cerr << "SSLClientTest test completed" << endl;
822 }
823
824
825 /**
826  * Test SSL client socket session re-use
827  */
828 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
829   // Start listening on a local port
830   WriteCallbackBase writeCallback;
831   ReadCallback readCallback(&writeCallback);
832   HandshakeCallback handshakeCallback(&readCallback);
833   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
834   TestSSLServer server(&acceptCallback);
835
836   // Set up SSL client
837   EventBase eventBase;
838   auto client =
839       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
840
841   client->connect();
842   EventBaseAborter eba(&eventBase, 3000);
843   eventBase.loop();
844
845   EXPECT_EQ(client->getMiss(), 1);
846   EXPECT_EQ(client->getHit(), 9);
847
848   cerr << "SSLClientTestReuse test completed" << endl;
849 }
850
851 /**
852  * Test SSL client socket timeout
853  */
854 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
855   // Start listening on a local port
856   EmptyReadCallback readCallback;
857   HandshakeCallback handshakeCallback(&readCallback,
858                                       HandshakeCallback::EXPECT_ERROR);
859   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
860   TestSSLServer server(&acceptCallback);
861
862   // Set up SSL client
863   EventBase eventBase;
864   auto client =
865       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
866   client->connect(true /* write before connect completes */);
867   EventBaseAborter eba(&eventBase, 3000);
868   eventBase.loop();
869
870   usleep(100000);
871   // This is checking that the connectError callback precedes any queued
872   // writeError callbacks.  This matches AsyncSocket's behavior
873   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
874   EXPECT_EQ(client->getErrors(), 1);
875   EXPECT_EQ(client->getMiss(), 0);
876   EXPECT_EQ(client->getHit(), 0);
877
878   cerr << "SSLClientTimeoutTest test completed" << endl;
879 }
880
881 // This is a FB-only extension, and the tests will fail without it
882 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
883 /**
884  * Test SSL server async cache
885  */
886 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
887   // Start listening on a local port
888   WriteCallbackBase writeCallback;
889   ReadCallback readCallback(&writeCallback);
890   HandshakeCallback handshakeCallback(&readCallback);
891   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
892   TestSSLAsyncCacheServer server(&acceptCallback);
893
894   // Set up SSL client
895   EventBase eventBase;
896   auto client =
897       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
898
899   client->connect();
900   EventBaseAborter eba(&eventBase, 3000);
901   eventBase.loop();
902
903   EXPECT_EQ(server.getAsyncCallbacks(), 18);
904   EXPECT_EQ(server.getAsyncLookups(), 9);
905   EXPECT_EQ(client->getMiss(), 10);
906   EXPECT_EQ(client->getHit(), 0);
907
908   cerr << "SSLServerAsyncCacheTest test completed" << endl;
909 }
910
911
912 /**
913  * Test SSL server accept timeout with cache path
914  */
915 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
916   // Start listening on a local port
917   WriteCallbackBase writeCallback;
918   ReadCallback readCallback(&writeCallback);
919   EmptyReadCallback clientReadCallback;
920   HandshakeCallback handshakeCallback(&readCallback);
921   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
922   TestSSLAsyncCacheServer server(&acceptCallback);
923
924   // Set up SSL client
925   EventBase eventBase;
926   // only do a TCP connect
927   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
928   sock->connect(nullptr, server.getAddress());
929   clientReadCallback.tcpSocket_ = sock;
930   sock->setReadCB(&clientReadCallback);
931
932   EventBaseAborter eba(&eventBase, 3000);
933   eventBase.loop();
934
935   EXPECT_EQ(readCallback.state, STATE_WAITING);
936
937   cerr << "SSLServerTimeoutTest test completed" << endl;
938 }
939
940 /**
941  * Test SSL server accept timeout with cache path
942  */
943 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
944   // Start listening on a local port
945   WriteCallbackBase writeCallback;
946   ReadCallback readCallback(&writeCallback);
947   HandshakeCallback handshakeCallback(&readCallback);
948   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
949   TestSSLAsyncCacheServer server(&acceptCallback);
950
951   // Set up SSL client
952   EventBase eventBase;
953   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
954
955   client->connect();
956   EventBaseAborter eba(&eventBase, 3000);
957   eventBase.loop();
958
959   EXPECT_EQ(server.getAsyncCallbacks(), 1);
960   EXPECT_EQ(server.getAsyncLookups(), 1);
961   EXPECT_EQ(client->getErrors(), 1);
962   EXPECT_EQ(client->getMiss(), 1);
963   EXPECT_EQ(client->getHit(), 0);
964
965   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
966 }
967
968 /**
969  * Test SSL server accept timeout with cache path
970  */
971 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
972   // Start listening on a local port
973   WriteCallbackBase writeCallback;
974   ReadCallback readCallback(&writeCallback);
975   HandshakeCallback handshakeCallback(&readCallback,
976                                       HandshakeCallback::EXPECT_ERROR);
977   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
978   TestSSLAsyncCacheServer server(&acceptCallback, 500);
979
980   // Set up SSL client
981   EventBase eventBase;
982   auto client =
983       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
984
985   client->connect();
986   EventBaseAborter eba(&eventBase, 3000);
987   eventBase.loop();
988
989   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
990       handshakeCallback.closeSocket();});
991   // give time for the cache lookup to come back and find it closed
992   handshakeCallback.waitForHandshake();
993
994   EXPECT_EQ(server.getAsyncCallbacks(), 1);
995   EXPECT_EQ(server.getAsyncLookups(), 1);
996   EXPECT_EQ(client->getErrors(), 1);
997   EXPECT_EQ(client->getMiss(), 1);
998   EXPECT_EQ(client->getHit(), 0);
999
1000   cerr << "SSLServerCacheCloseTest test completed" << endl;
1001 }
1002 #endif
1003
1004 /**
1005  * Verify Client Ciphers obtained using SSL MSG Callback.
1006  */
1007 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1008   EventBase eventBase;
1009   auto clientCtx = std::make_shared<SSLContext>();
1010   auto serverCtx = std::make_shared<SSLContext>();
1011   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1012   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
1013   serverCtx->loadPrivateKey(testKey);
1014   serverCtx->loadCertificate(testCert);
1015   serverCtx->loadTrustedCertificates(testCA);
1016   serverCtx->loadClientCAList(testCA);
1017
1018   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1019   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
1020   clientCtx->loadPrivateKey(testKey);
1021   clientCtx->loadCertificate(testCert);
1022   clientCtx->loadTrustedCertificates(testCA);
1023
1024   int fds[2];
1025   getfds(fds);
1026
1027   AsyncSSLSocket::UniquePtr clientSock(
1028       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1029   AsyncSSLSocket::UniquePtr serverSock(
1030       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1031
1032   SSLHandshakeClient client(std::move(clientSock), true, true);
1033   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1034
1035   eventBase.loop();
1036
1037   EXPECT_EQ(server.clientCiphers_,
1038             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1039   EXPECT_TRUE(client.handshakeVerify_);
1040   EXPECT_TRUE(client.handshakeSuccess_);
1041   EXPECT_TRUE(!client.handshakeError_);
1042   EXPECT_TRUE(server.handshakeVerify_);
1043   EXPECT_TRUE(server.handshakeSuccess_);
1044   EXPECT_TRUE(!server.handshakeError_);
1045 }
1046
1047 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1048   EventBase eventBase;
1049   auto ctx = std::make_shared<SSLContext>();
1050
1051   int fds[2];
1052   getfds(fds);
1053
1054   int bufLen = 42;
1055   uint8_t majorVersion = 18;
1056   uint8_t minorVersion = 25;
1057
1058   // Create callback buf
1059   auto buf = IOBuf::create(bufLen);
1060   buf->append(bufLen);
1061   folly::io::RWPrivateCursor cursor(buf.get());
1062   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1063   cursor.write<uint16_t>(0);
1064   cursor.write<uint8_t>(38);
1065   cursor.write<uint8_t>(majorVersion);
1066   cursor.write<uint8_t>(minorVersion);
1067   cursor.skip(32);
1068   cursor.write<uint32_t>(0);
1069
1070   SSL* ssl = ctx->createSSL();
1071   SCOPE_EXIT { SSL_free(ssl); };
1072   AsyncSSLSocket::UniquePtr sock(
1073       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1074   sock->enableClientHelloParsing();
1075
1076   // Test client hello parsing in one packet
1077   AsyncSSLSocket::clientHelloParsingCallback(
1078       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1079   buf.reset();
1080
1081   auto parsedClientHello = sock->getClientHelloInfo();
1082   EXPECT_TRUE(parsedClientHello != nullptr);
1083   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1084   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1085 }
1086
1087 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1088   EventBase eventBase;
1089   auto ctx = std::make_shared<SSLContext>();
1090
1091   int fds[2];
1092   getfds(fds);
1093
1094   int bufLen = 42;
1095   uint8_t majorVersion = 18;
1096   uint8_t minorVersion = 25;
1097
1098   // Create callback buf
1099   auto buf = IOBuf::create(bufLen);
1100   buf->append(bufLen);
1101   folly::io::RWPrivateCursor cursor(buf.get());
1102   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1103   cursor.write<uint16_t>(0);
1104   cursor.write<uint8_t>(38);
1105   cursor.write<uint8_t>(majorVersion);
1106   cursor.write<uint8_t>(minorVersion);
1107   cursor.skip(32);
1108   cursor.write<uint32_t>(0);
1109
1110   SSL* ssl = ctx->createSSL();
1111   SCOPE_EXIT { SSL_free(ssl); };
1112   AsyncSSLSocket::UniquePtr sock(
1113       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1114   sock->enableClientHelloParsing();
1115
1116   // Test parsing with two packets with first packet size < 3
1117   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1118   AsyncSSLSocket::clientHelloParsingCallback(
1119       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1120       ssl, sock.get());
1121   bufCopy.reset();
1122   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1123   AsyncSSLSocket::clientHelloParsingCallback(
1124       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1125       ssl, sock.get());
1126   bufCopy.reset();
1127
1128   auto parsedClientHello = sock->getClientHelloInfo();
1129   EXPECT_TRUE(parsedClientHello != nullptr);
1130   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1131   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1132 }
1133
1134 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1135   EventBase eventBase;
1136   auto ctx = std::make_shared<SSLContext>();
1137
1138   int fds[2];
1139   getfds(fds);
1140
1141   int bufLen = 42;
1142   uint8_t majorVersion = 18;
1143   uint8_t minorVersion = 25;
1144
1145   // Create callback buf
1146   auto buf = IOBuf::create(bufLen);
1147   buf->append(bufLen);
1148   folly::io::RWPrivateCursor cursor(buf.get());
1149   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1150   cursor.write<uint16_t>(0);
1151   cursor.write<uint8_t>(38);
1152   cursor.write<uint8_t>(majorVersion);
1153   cursor.write<uint8_t>(minorVersion);
1154   cursor.skip(32);
1155   cursor.write<uint32_t>(0);
1156
1157   SSL* ssl = ctx->createSSL();
1158   SCOPE_EXIT { SSL_free(ssl); };
1159   AsyncSSLSocket::UniquePtr sock(
1160       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1161   sock->enableClientHelloParsing();
1162
1163   // Test parsing with multiple small packets
1164   for (uint64_t i = 0; i < buf->length(); i += 3) {
1165     auto bufCopy = folly::IOBuf::copyBuffer(
1166         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1167     AsyncSSLSocket::clientHelloParsingCallback(
1168         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1169         ssl, sock.get());
1170     bufCopy.reset();
1171   }
1172
1173   auto parsedClientHello = sock->getClientHelloInfo();
1174   EXPECT_TRUE(parsedClientHello != nullptr);
1175   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1176   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1177 }
1178
1179 /**
1180  * Verify sucessful behavior of SSL certificate validation.
1181  */
1182 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1183   EventBase eventBase;
1184   auto clientCtx = std::make_shared<SSLContext>();
1185   auto dfServerCtx = std::make_shared<SSLContext>();
1186
1187   int fds[2];
1188   getfds(fds);
1189   getctx(clientCtx, dfServerCtx);
1190
1191   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1192   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1193
1194   AsyncSSLSocket::UniquePtr clientSock(
1195     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1196   AsyncSSLSocket::UniquePtr serverSock(
1197     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1198
1199   SSLHandshakeClient client(std::move(clientSock), true, true);
1200   clientCtx->loadTrustedCertificates(testCA);
1201
1202   SSLHandshakeServer server(std::move(serverSock), true, true);
1203
1204   eventBase.loop();
1205
1206   EXPECT_TRUE(client.handshakeVerify_);
1207   EXPECT_TRUE(client.handshakeSuccess_);
1208   EXPECT_TRUE(!client.handshakeError_);
1209   EXPECT_LE(0, client.handshakeTime.count());
1210   EXPECT_TRUE(!server.handshakeVerify_);
1211   EXPECT_TRUE(server.handshakeSuccess_);
1212   EXPECT_TRUE(!server.handshakeError_);
1213   EXPECT_LE(0, server.handshakeTime.count());
1214 }
1215
1216 /**
1217  * Verify that the client's verification callback is able to fail SSL
1218  * connection establishment.
1219  */
1220 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1221   EventBase eventBase;
1222   auto clientCtx = std::make_shared<SSLContext>();
1223   auto dfServerCtx = std::make_shared<SSLContext>();
1224
1225   int fds[2];
1226   getfds(fds);
1227   getctx(clientCtx, dfServerCtx);
1228
1229   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1230   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1231
1232   AsyncSSLSocket::UniquePtr clientSock(
1233     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1234   AsyncSSLSocket::UniquePtr serverSock(
1235     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1236
1237   SSLHandshakeClient client(std::move(clientSock), true, false);
1238   clientCtx->loadTrustedCertificates(testCA);
1239
1240   SSLHandshakeServer server(std::move(serverSock), true, true);
1241
1242   eventBase.loop();
1243
1244   EXPECT_TRUE(client.handshakeVerify_);
1245   EXPECT_TRUE(!client.handshakeSuccess_);
1246   EXPECT_TRUE(client.handshakeError_);
1247   EXPECT_LE(0, client.handshakeTime.count());
1248   EXPECT_TRUE(!server.handshakeVerify_);
1249   EXPECT_TRUE(!server.handshakeSuccess_);
1250   EXPECT_TRUE(server.handshakeError_);
1251   EXPECT_LE(0, server.handshakeTime.count());
1252 }
1253
1254 /**
1255  * Verify that the options in SSLContext can be overridden in
1256  * sslConnect/Accept.i.e specifying that no validation should be performed
1257  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1258  * the validation callback.
1259  */
1260 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1261   EventBase eventBase;
1262   auto clientCtx = std::make_shared<SSLContext>();
1263   auto dfServerCtx = std::make_shared<SSLContext>();
1264
1265   int fds[2];
1266   getfds(fds);
1267   getctx(clientCtx, dfServerCtx);
1268
1269   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1270   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1271
1272   AsyncSSLSocket::UniquePtr clientSock(
1273     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1274   AsyncSSLSocket::UniquePtr serverSock(
1275     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1276
1277   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1278   clientCtx->loadTrustedCertificates(testCA);
1279
1280   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1281
1282   eventBase.loop();
1283
1284   EXPECT_TRUE(!client.handshakeVerify_);
1285   EXPECT_TRUE(client.handshakeSuccess_);
1286   EXPECT_TRUE(!client.handshakeError_);
1287   EXPECT_LE(0, client.handshakeTime.count());
1288   EXPECT_TRUE(!server.handshakeVerify_);
1289   EXPECT_TRUE(server.handshakeSuccess_);
1290   EXPECT_TRUE(!server.handshakeError_);
1291   EXPECT_LE(0, server.handshakeTime.count());
1292 }
1293
1294 /**
1295  * Verify that the options in SSLContext can be overridden in
1296  * sslConnect/Accept. Enable verification even if context says otherwise.
1297  * Test requireClientCert with client cert
1298  */
1299 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1300   EventBase eventBase;
1301   auto clientCtx = std::make_shared<SSLContext>();
1302   auto serverCtx = std::make_shared<SSLContext>();
1303   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1304   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1305   serverCtx->loadPrivateKey(testKey);
1306   serverCtx->loadCertificate(testCert);
1307   serverCtx->loadTrustedCertificates(testCA);
1308   serverCtx->loadClientCAList(testCA);
1309
1310   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1311   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1312   clientCtx->loadPrivateKey(testKey);
1313   clientCtx->loadCertificate(testCert);
1314   clientCtx->loadTrustedCertificates(testCA);
1315
1316   int fds[2];
1317   getfds(fds);
1318
1319   AsyncSSLSocket::UniquePtr clientSock(
1320       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1321   AsyncSSLSocket::UniquePtr serverSock(
1322       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1323
1324   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1325   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1326
1327   eventBase.loop();
1328
1329   EXPECT_TRUE(client.handshakeVerify_);
1330   EXPECT_TRUE(client.handshakeSuccess_);
1331   EXPECT_FALSE(client.handshakeError_);
1332   EXPECT_LE(0, client.handshakeTime.count());
1333   EXPECT_TRUE(server.handshakeVerify_);
1334   EXPECT_TRUE(server.handshakeSuccess_);
1335   EXPECT_FALSE(server.handshakeError_);
1336   EXPECT_LE(0, server.handshakeTime.count());
1337 }
1338
1339 /**
1340  * Verify that the client's verification callback is able to override
1341  * the preverification failure and allow a successful connection.
1342  */
1343 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1344   EventBase eventBase;
1345   auto clientCtx = std::make_shared<SSLContext>();
1346   auto dfServerCtx = std::make_shared<SSLContext>();
1347
1348   int fds[2];
1349   getfds(fds);
1350   getctx(clientCtx, dfServerCtx);
1351
1352   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1353   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1354
1355   AsyncSSLSocket::UniquePtr clientSock(
1356     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1357   AsyncSSLSocket::UniquePtr serverSock(
1358     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1359
1360   SSLHandshakeClient client(std::move(clientSock), false, true);
1361   SSLHandshakeServer server(std::move(serverSock), true, true);
1362
1363   eventBase.loop();
1364
1365   EXPECT_TRUE(client.handshakeVerify_);
1366   EXPECT_TRUE(client.handshakeSuccess_);
1367   EXPECT_TRUE(!client.handshakeError_);
1368   EXPECT_LE(0, client.handshakeTime.count());
1369   EXPECT_TRUE(!server.handshakeVerify_);
1370   EXPECT_TRUE(server.handshakeSuccess_);
1371   EXPECT_TRUE(!server.handshakeError_);
1372   EXPECT_LE(0, server.handshakeTime.count());
1373 }
1374
1375 /**
1376  * Verify that specifying that no validation should be performed allows an
1377  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1378  * callback.
1379  */
1380 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1381   EventBase eventBase;
1382   auto clientCtx = std::make_shared<SSLContext>();
1383   auto dfServerCtx = std::make_shared<SSLContext>();
1384
1385   int fds[2];
1386   getfds(fds);
1387   getctx(clientCtx, dfServerCtx);
1388
1389   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1390   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1391
1392   AsyncSSLSocket::UniquePtr clientSock(
1393     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1394   AsyncSSLSocket::UniquePtr serverSock(
1395     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1396
1397   SSLHandshakeClient client(std::move(clientSock), false, false);
1398   SSLHandshakeServer server(std::move(serverSock), false, false);
1399
1400   eventBase.loop();
1401
1402   EXPECT_TRUE(!client.handshakeVerify_);
1403   EXPECT_TRUE(client.handshakeSuccess_);
1404   EXPECT_TRUE(!client.handshakeError_);
1405   EXPECT_LE(0, client.handshakeTime.count());
1406   EXPECT_TRUE(!server.handshakeVerify_);
1407   EXPECT_TRUE(server.handshakeSuccess_);
1408   EXPECT_TRUE(!server.handshakeError_);
1409   EXPECT_LE(0, server.handshakeTime.count());
1410 }
1411
1412 /**
1413  * Test requireClientCert with client cert
1414  */
1415 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1416   EventBase eventBase;
1417   auto clientCtx = std::make_shared<SSLContext>();
1418   auto serverCtx = std::make_shared<SSLContext>();
1419   serverCtx->setVerificationOption(
1420       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1421   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1422   serverCtx->loadPrivateKey(testKey);
1423   serverCtx->loadCertificate(testCert);
1424   serverCtx->loadTrustedCertificates(testCA);
1425   serverCtx->loadClientCAList(testCA);
1426
1427   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1428   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1429   clientCtx->loadPrivateKey(testKey);
1430   clientCtx->loadCertificate(testCert);
1431   clientCtx->loadTrustedCertificates(testCA);
1432
1433   int fds[2];
1434   getfds(fds);
1435
1436   AsyncSSLSocket::UniquePtr clientSock(
1437       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1438   AsyncSSLSocket::UniquePtr serverSock(
1439       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1440
1441   SSLHandshakeClient client(std::move(clientSock), true, true);
1442   SSLHandshakeServer server(std::move(serverSock), true, true);
1443
1444   eventBase.loop();
1445
1446   EXPECT_TRUE(client.handshakeVerify_);
1447   EXPECT_TRUE(client.handshakeSuccess_);
1448   EXPECT_FALSE(client.handshakeError_);
1449   EXPECT_LE(0, client.handshakeTime.count());
1450   EXPECT_TRUE(server.handshakeVerify_);
1451   EXPECT_TRUE(server.handshakeSuccess_);
1452   EXPECT_FALSE(server.handshakeError_);
1453   EXPECT_LE(0, server.handshakeTime.count());
1454 }
1455
1456
1457 /**
1458  * Test requireClientCert with no client cert
1459  */
1460 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1461   EventBase eventBase;
1462   auto clientCtx = std::make_shared<SSLContext>();
1463   auto serverCtx = std::make_shared<SSLContext>();
1464   serverCtx->setVerificationOption(
1465       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1466   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1467   serverCtx->loadPrivateKey(testKey);
1468   serverCtx->loadCertificate(testCert);
1469   serverCtx->loadTrustedCertificates(testCA);
1470   serverCtx->loadClientCAList(testCA);
1471   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1472   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1473
1474   int fds[2];
1475   getfds(fds);
1476
1477   AsyncSSLSocket::UniquePtr clientSock(
1478       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1479   AsyncSSLSocket::UniquePtr serverSock(
1480       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1481
1482   SSLHandshakeClient client(std::move(clientSock), false, false);
1483   SSLHandshakeServer server(std::move(serverSock), false, false);
1484
1485   eventBase.loop();
1486
1487   EXPECT_FALSE(server.handshakeVerify_);
1488   EXPECT_FALSE(server.handshakeSuccess_);
1489   EXPECT_TRUE(server.handshakeError_);
1490   EXPECT_LE(0, client.handshakeTime.count());
1491   EXPECT_LE(0, server.handshakeTime.count());
1492 }
1493
1494 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1495   auto cert = getFileAsBuf(testCert);
1496   auto key = getFileAsBuf(testKey);
1497
1498   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1499   BIO_write(certBio.get(), cert.data(), cert.size());
1500   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1501   BIO_write(keyBio.get(), key.data(), key.size());
1502
1503   // Create SSL structs from buffers to get properties
1504   ssl::X509UniquePtr certStruct(
1505       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1506   ssl::EvpPkeyUniquePtr keyStruct(
1507       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1508   certBio = nullptr;
1509   keyBio = nullptr;
1510
1511   auto origCommonName = getCommonName(certStruct.get());
1512   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1513   certStruct = nullptr;
1514   keyStruct = nullptr;
1515
1516   auto ctx = std::make_shared<SSLContext>();
1517   ctx->loadPrivateKeyFromBufferPEM(key);
1518   ctx->loadCertificateFromBufferPEM(cert);
1519   ctx->loadTrustedCertificates(testCA);
1520
1521   ssl::SSLUniquePtr ssl(ctx->createSSL());
1522
1523   auto newCert = SSL_get_certificate(ssl.get());
1524   auto newKey = SSL_get_privatekey(ssl.get());
1525
1526   // Get properties from SSL struct
1527   auto newCommonName = getCommonName(newCert);
1528   auto newKeySize = EVP_PKEY_bits(newKey);
1529
1530   // Check that the key and cert have the expected properties
1531   EXPECT_EQ(origCommonName, newCommonName);
1532   EXPECT_EQ(origKeySize, newKeySize);
1533 }
1534
1535 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1536   EventBase eb;
1537
1538   // Set up SSL context.
1539   auto sslContext = std::make_shared<SSLContext>();
1540   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1541
1542   // create SSL socket
1543   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1544
1545   EXPECT_EQ(1500, socket->getMinWriteSize());
1546
1547   socket->setMinWriteSize(0);
1548   EXPECT_EQ(0, socket->getMinWriteSize());
1549   socket->setMinWriteSize(50000);
1550   EXPECT_EQ(50000, socket->getMinWriteSize());
1551 }
1552
1553 class ReadCallbackTerminator : public ReadCallback {
1554  public:
1555   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1556       : ReadCallback(wcb)
1557       , base_(base) {}
1558
1559   // Do not write data back, terminate the loop.
1560   void readDataAvailable(size_t len) noexcept override {
1561     std::cerr << "readDataAvailable, len " << len << std::endl;
1562
1563     currentBuffer.length = len;
1564
1565     buffers.push_back(currentBuffer);
1566     currentBuffer.reset();
1567     state = STATE_SUCCEEDED;
1568
1569     socket_->setReadCB(nullptr);
1570     base_->terminateLoopSoon();
1571   }
1572  private:
1573   EventBase* base_;
1574 };
1575
1576
1577 /**
1578  * Test a full unencrypted codepath
1579  */
1580 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1581   EventBase base;
1582
1583   auto clientCtx = std::make_shared<folly::SSLContext>();
1584   auto serverCtx = std::make_shared<folly::SSLContext>();
1585   int fds[2];
1586   getfds(fds);
1587   getctx(clientCtx, serverCtx);
1588   auto client = AsyncSSLSocket::newSocket(
1589                   clientCtx, &base, fds[0], false, true);
1590   auto server = AsyncSSLSocket::newSocket(
1591                   serverCtx, &base, fds[1], true, true);
1592
1593   ReadCallbackTerminator readCallback(&base, nullptr);
1594   server->setReadCB(&readCallback);
1595   readCallback.setSocket(server);
1596
1597   uint8_t buf[128];
1598   memset(buf, 'a', sizeof(buf));
1599   client->write(nullptr, buf, sizeof(buf));
1600
1601   // Check that bytes are unencrypted
1602   char c;
1603   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1604   EXPECT_EQ('a', c);
1605
1606   EventBaseAborter eba(&base, 3000);
1607   base.loop();
1608
1609   EXPECT_EQ(1, readCallback.buffers.size());
1610   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1611
1612   server->setReadCB(&readCallback);
1613
1614   // Unencrypted
1615   server->sslAccept(nullptr);
1616   client->sslConn(nullptr);
1617
1618   // Do NOT wait for handshake, writing should be queued and happen after
1619
1620   client->write(nullptr, buf, sizeof(buf));
1621
1622   // Check that bytes are *not* unencrypted
1623   char c2;
1624   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1625   EXPECT_NE('a', c2);
1626
1627
1628   base.loop();
1629
1630   EXPECT_EQ(2, readCallback.buffers.size());
1631   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1632 }
1633
1634 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1635   // Start listening on a local port
1636   WriteCallbackBase writeCallback;
1637   WriteErrorCallback readCallback(&writeCallback);
1638   HandshakeCallback handshakeCallback(&readCallback,
1639                                       HandshakeCallback::EXPECT_ERROR);
1640   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1641   TestSSLServer server(&acceptCallback);
1642
1643   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1644   socket->open();
1645   uint8_t buf[3] = {0x16, 0x03, 0x01};
1646   socket->write(buf, sizeof(buf));
1647   socket->closeWithReset();
1648
1649   handshakeCallback.waitForHandshake();
1650   EXPECT_NE(
1651       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1652   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1653 }
1654
1655 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1656   // Start listening on a local port
1657   WriteCallbackBase writeCallback;
1658   WriteErrorCallback readCallback(&writeCallback);
1659   HandshakeCallback handshakeCallback(&readCallback,
1660                                       HandshakeCallback::EXPECT_ERROR);
1661   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1662   TestSSLServer server(&acceptCallback);
1663
1664   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1665   socket->open();
1666   uint8_t buf[3] = {0x16, 0x03, 0x01};
1667   socket->write(buf, sizeof(buf));
1668   socket->close();
1669
1670   handshakeCallback.waitForHandshake();
1671   EXPECT_NE(
1672       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1673   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1674 }
1675
1676 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1677   // Start listening on a local port
1678   WriteCallbackBase writeCallback;
1679   WriteErrorCallback readCallback(&writeCallback);
1680   HandshakeCallback handshakeCallback(&readCallback,
1681                                       HandshakeCallback::EXPECT_ERROR);
1682   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1683   TestSSLServer server(&acceptCallback);
1684
1685   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1686   socket->open();
1687   uint8_t buf[256] = {0x16, 0x03};
1688   memset(buf + 2, 'a', sizeof(buf) - 2);
1689   socket->write(buf, sizeof(buf));
1690   socket->close();
1691
1692   handshakeCallback.waitForHandshake();
1693   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1694             std::string::npos);
1695   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1696             std::string::npos);
1697 }
1698
1699 #if FOLLY_ALLOW_TFO
1700
1701 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1702  public:
1703   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1704
1705   explicit MockAsyncTFOSSLSocket(
1706       std::shared_ptr<folly::SSLContext> sslCtx,
1707       EventBase* evb)
1708       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1709
1710   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1711 };
1712
1713 /**
1714  * Test connecting to, writing to, reading from, and closing the
1715  * connection to the SSL server with TFO.
1716  */
1717 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1718   // Start listening on a local port
1719   WriteCallbackBase writeCallback;
1720   ReadCallback readCallback(&writeCallback);
1721   HandshakeCallback handshakeCallback(&readCallback);
1722   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1723   TestSSLServer server(&acceptCallback, true);
1724
1725   // Set up SSL context.
1726   auto sslContext = std::make_shared<SSLContext>();
1727
1728   // connect
1729   auto socket =
1730       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1731   socket->enableTFO();
1732   socket->open();
1733
1734   // write()
1735   std::array<uint8_t, 128> buf;
1736   memset(buf.data(), 'a', buf.size());
1737   socket->write(buf.data(), buf.size());
1738
1739   // read()
1740   std::array<uint8_t, 128> readbuf;
1741   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1742   EXPECT_EQ(bytesRead, 128);
1743   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1744
1745   // close()
1746   socket->close();
1747 }
1748
1749 /**
1750  * Test connecting to, writing to, reading from, and closing the
1751  * connection to the SSL server with TFO.
1752  */
1753 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1754   // Start listening on a local port
1755   WriteCallbackBase writeCallback;
1756   ReadCallback readCallback(&writeCallback);
1757   HandshakeCallback handshakeCallback(&readCallback);
1758   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1759   TestSSLServer server(&acceptCallback, false);
1760
1761   // Set up SSL context.
1762   auto sslContext = std::make_shared<SSLContext>();
1763
1764   // connect
1765   auto socket =
1766       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1767   socket->enableTFO();
1768   socket->open();
1769
1770   // write()
1771   std::array<uint8_t, 128> buf;
1772   memset(buf.data(), 'a', buf.size());
1773   socket->write(buf.data(), buf.size());
1774
1775   // read()
1776   std::array<uint8_t, 128> readbuf;
1777   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1778   EXPECT_EQ(bytesRead, 128);
1779   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1780
1781   // close()
1782   socket->close();
1783 }
1784
1785 class ConnCallback : public AsyncSocket::ConnectCallback {
1786  public:
1787   virtual void connectSuccess() noexcept override {
1788     state = State::SUCCESS;
1789   }
1790
1791   virtual void connectErr(const AsyncSocketException&) noexcept override {
1792     state = State::ERROR;
1793   }
1794
1795   enum class State { WAITING, SUCCESS, ERROR };
1796
1797   State state{State::WAITING};
1798 };
1799
1800 template <class Cardinality>
1801 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1802     EventBase* evb,
1803     const SocketAddress& address,
1804     Cardinality cardinality) {
1805   // Set up SSL context.
1806   auto sslContext = std::make_shared<SSLContext>();
1807
1808   // connect
1809   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1810       new MockAsyncTFOSSLSocket(sslContext, evb));
1811   socket->enableTFO();
1812
1813   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1814       .Times(cardinality)
1815       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1816         sockaddr_storage addr;
1817         auto len = address.getAddress(&addr);
1818         return connect(fd, (const struct sockaddr*)&addr, len);
1819       }));
1820   return socket;
1821 }
1822
1823 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1824   // Start listening on a local port
1825   WriteCallbackBase writeCallback;
1826   ReadCallback readCallback(&writeCallback);
1827   HandshakeCallback handshakeCallback(&readCallback);
1828   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1829   TestSSLServer server(&acceptCallback, true);
1830
1831   EventBase evb;
1832
1833   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1834   ConnCallback ccb;
1835   socket->connect(&ccb, server.getAddress(), 30);
1836
1837   evb.loop();
1838   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1839
1840   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1841   evb.loop();
1842
1843   BlockingSocket sock(std::move(socket));
1844   // write()
1845   std::array<uint8_t, 128> buf;
1846   memset(buf.data(), 'a', buf.size());
1847   sock.write(buf.data(), buf.size());
1848
1849   // read()
1850   std::array<uint8_t, 128> readbuf;
1851   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1852   EXPECT_EQ(bytesRead, 128);
1853   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1854
1855   // close()
1856   sock.close();
1857 }
1858
1859 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1860   // Start listening on a local port
1861   ConnectTimeoutCallback acceptCallback;
1862   TestSSLServer server(&acceptCallback, true);
1863
1864   // Set up SSL context.
1865   auto sslContext = std::make_shared<SSLContext>();
1866
1867   // connect
1868   auto socket =
1869       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1870   socket->enableTFO();
1871   EXPECT_THROW(
1872       socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
1873 }
1874
1875 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1876   // Start listening on a local port
1877   ConnectTimeoutCallback acceptCallback;
1878   TestSSLServer server(&acceptCallback, true);
1879
1880   EventBase evb;
1881
1882   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1883   ConnCallback ccb;
1884   // Set a short timeout
1885   socket->connect(&ccb, server.getAddress(), 1);
1886
1887   evb.loop();
1888   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1889 }
1890
1891 #endif
1892
1893 } // namespace
1894
1895 #ifdef SIGPIPE
1896 ///////////////////////////////////////////////////////////////////////////
1897 // init_unit_test_suite
1898 ///////////////////////////////////////////////////////////////////////////
1899 namespace {
1900 struct Initializer {
1901   Initializer() {
1902     signal(SIGPIPE, SIG_IGN);
1903   }
1904 };
1905 Initializer initializer;
1906 } // anonymous
1907 #endif