0920122930bbf15557ac6850bef01a41a95419fb
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/GMock.h>
25 #include <folly/portability/GTest.h>
26 #include <folly/portability/Sockets.h>
27 #include <folly/portability/Unistd.h>
28
29 #include <folly/io/async/test/BlockingSocket.h>
30
31 #include <fcntl.h>
32 #include <folly/io/Cursor.h>
33 #include <openssl/bio.h>
34 #include <sys/types.h>
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
56 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
57 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
58
59 constexpr size_t SSLClient::kMaxReadBufferSz;
60 constexpr size_t SSLClient::kMaxReadsPerEvent;
61
62 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
63     : ctx_(new folly::SSLContext),
64       acb_(acb),
65       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
66   // Set up the SSL context
67   ctx_->loadCertificate(testCert);
68   ctx_->loadPrivateKey(testKey);
69   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
70
71   acb_->ctx_ = ctx_;
72   acb_->base_ = &evb_;
73
74   // Enable TFO
75   if (enableTFO) {
76     LOG(INFO) << "server TFO enabled";
77     socket_->setTFOEnabled(true, 1000);
78   }
79
80   // set up the listening socket
81   socket_->bind(0);
82   socket_->getAddress(&address_);
83   socket_->listen(100);
84   socket_->addAcceptCallback(acb_, &evb_);
85   socket_->startAccepting();
86
87   int ret = pthread_create(&thread_, nullptr, Main, this);
88   assert(ret == 0);
89   (void)ret;
90
91   std::cerr << "Accepting connections on " << address_ << std::endl;
92 }
93
94 void getfds(int fds[2]) {
95   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
96     FAIL() << "failed to create socketpair: " << strerror(errno);
97   }
98   for (int idx = 0; idx < 2; ++idx) {
99     int flags = fcntl(fds[idx], F_GETFL, 0);
100     if (flags == -1) {
101       FAIL() << "failed to get flags for socket " << idx << ": "
102              << strerror(errno);
103     }
104     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
105       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
106              << strerror(errno);
107     }
108   }
109 }
110
111 void getctx(
112   std::shared_ptr<folly::SSLContext> clientCtx,
113   std::shared_ptr<folly::SSLContext> serverCtx) {
114   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
115
116   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
117   serverCtx->loadCertificate(
118       testCert);
119   serverCtx->loadPrivateKey(
120       testKey);
121 }
122
123 void sslsocketpair(
124   EventBase* eventBase,
125   AsyncSSLSocket::UniquePtr* clientSock,
126   AsyncSSLSocket::UniquePtr* serverSock) {
127   auto clientCtx = std::make_shared<folly::SSLContext>();
128   auto serverCtx = std::make_shared<folly::SSLContext>();
129   int fds[2];
130   getfds(fds);
131   getctx(clientCtx, serverCtx);
132   clientSock->reset(new AsyncSSLSocket(
133                       clientCtx, eventBase, fds[0], false));
134   serverSock->reset(new AsyncSSLSocket(
135                       serverCtx, eventBase, fds[1], true));
136
137   // (*clientSock)->setSendTimeout(100);
138   // (*serverSock)->setSendTimeout(100);
139 }
140
141 // client protocol filters
142 bool clientProtoFilterPickPony(unsigned char** client,
143   unsigned int* client_len, const unsigned char*, unsigned int ) {
144   //the protocol string in length prefixed byte string. the
145   //length byte is not included in the length
146   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
147   *client = p;
148   *client_len = 7;
149   return true;
150 }
151
152 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
153   const unsigned char*, unsigned int) {
154   return false;
155 }
156
157 std::string getFileAsBuf(const char* fileName) {
158   std::string buffer;
159   folly::readFile(fileName, buffer);
160   return buffer;
161 }
162
163 std::string getCommonName(X509* cert) {
164   X509_NAME* subject = X509_get_subject_name(cert);
165   std::string cn;
166   cn.resize(ub_common_name);
167   X509_NAME_get_text_by_NID(
168       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
169   return cn;
170 }
171
172 /**
173  * Test connecting to, writing to, reading from, and closing the
174  * connection to the SSL server.
175  */
176 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
177   // Start listening on a local port
178   WriteCallbackBase writeCallback;
179   ReadCallback readCallback(&writeCallback);
180   HandshakeCallback handshakeCallback(&readCallback);
181   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
182   TestSSLServer server(&acceptCallback);
183
184   // Set up SSL context.
185   std::shared_ptr<SSLContext> sslContext(new SSLContext());
186   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
187   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
188   //sslContext->authenticate(true, false);
189
190   // connect
191   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
192                                                  sslContext);
193   socket->open();
194
195   // write()
196   uint8_t buf[128];
197   memset(buf, 'a', sizeof(buf));
198   socket->write(buf, sizeof(buf));
199
200   // read()
201   uint8_t readbuf[128];
202   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
203   EXPECT_EQ(bytesRead, 128);
204   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
205
206   // close()
207   socket->close();
208
209   cerr << "ConnectWriteReadClose test completed" << endl;
210 }
211
212 /**
213  * Test reading after server close.
214  */
215 TEST(AsyncSSLSocketTest, ReadAfterClose) {
216   // Start listening on a local port
217   WriteCallbackBase writeCallback;
218   ReadEOFCallback readCallback(&writeCallback);
219   HandshakeCallback handshakeCallback(&readCallback);
220   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
221   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
222
223   // Set up SSL context.
224   auto sslContext = std::make_shared<SSLContext>();
225   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
226
227   auto socket =
228       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
229   socket->open();
230
231   // This should trigger an EOF on the client.
232   auto evb = handshakeCallback.getSocket()->getEventBase();
233   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
234   std::array<uint8_t, 128> readbuf;
235   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
236   EXPECT_EQ(0, bytesRead);
237 }
238
239 /**
240  * Test bad renegotiation
241  */
242 TEST(AsyncSSLSocketTest, Renegotiate) {
243   EventBase eventBase;
244   auto clientCtx = std::make_shared<SSLContext>();
245   auto dfServerCtx = std::make_shared<SSLContext>();
246   std::array<int, 2> fds;
247   getfds(fds.data());
248   getctx(clientCtx, dfServerCtx);
249
250   AsyncSSLSocket::UniquePtr clientSock(
251       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
252   AsyncSSLSocket::UniquePtr serverSock(
253       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
254   SSLHandshakeClient client(std::move(clientSock), true, true);
255   RenegotiatingServer server(std::move(serverSock));
256
257   while (!client.handshakeSuccess_ && !client.handshakeError_) {
258     eventBase.loopOnce();
259   }
260
261   ASSERT_TRUE(client.handshakeSuccess_);
262
263   auto sslSock = std::move(client).moveSocket();
264   sslSock->detachEventBase();
265   // This is nasty, however we don't want to add support for
266   // renegotiation in AsyncSSLSocket.
267   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
268
269   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
270
271   std::thread t([&]() { eventBase.loopForever(); });
272
273   // Trigger the renegotiation.
274   std::array<uint8_t, 128> buf;
275   memset(buf.data(), 'a', buf.size());
276   try {
277     socket->write(buf.data(), buf.size());
278   } catch (AsyncSocketException& e) {
279     LOG(INFO) << "client got error " << e.what();
280   }
281   eventBase.terminateLoopSoon();
282   t.join();
283
284   eventBase.loop();
285   ASSERT_TRUE(server.renegotiationError_);
286 }
287
288 /**
289  * Negative test for handshakeError().
290  */
291 TEST(AsyncSSLSocketTest, HandshakeError) {
292   // Start listening on a local port
293   WriteCallbackBase writeCallback;
294   WriteErrorCallback readCallback(&writeCallback);
295   HandshakeCallback handshakeCallback(&readCallback);
296   HandshakeErrorCallback acceptCallback(&handshakeCallback);
297   TestSSLServer server(&acceptCallback);
298
299   // Set up SSL context.
300   std::shared_ptr<SSLContext> sslContext(new SSLContext());
301   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
302
303   // connect
304   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
305                                                  sslContext);
306   // read()
307   bool ex = false;
308   try {
309     socket->open();
310
311     uint8_t readbuf[128];
312     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
313     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
314   } catch (AsyncSocketException &e) {
315     ex = true;
316   }
317   EXPECT_TRUE(ex);
318
319   // close()
320   socket->close();
321   cerr << "HandshakeError test completed" << endl;
322 }
323
324 /**
325  * Negative test for readError().
326  */
327 TEST(AsyncSSLSocketTest, ReadError) {
328   // Start listening on a local port
329   WriteCallbackBase writeCallback;
330   ReadErrorCallback readCallback(&writeCallback);
331   HandshakeCallback handshakeCallback(&readCallback);
332   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
333   TestSSLServer server(&acceptCallback);
334
335   // Set up SSL context.
336   std::shared_ptr<SSLContext> sslContext(new SSLContext());
337   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
338
339   // connect
340   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
341                                                  sslContext);
342   socket->open();
343
344   // write something to trigger ssl handshake
345   uint8_t buf[128];
346   memset(buf, 'a', sizeof(buf));
347   socket->write(buf, sizeof(buf));
348
349   socket->close();
350   cerr << "ReadError test completed" << endl;
351 }
352
353 /**
354  * Negative test for writeError().
355  */
356 TEST(AsyncSSLSocketTest, WriteError) {
357   // Start listening on a local port
358   WriteCallbackBase writeCallback;
359   WriteErrorCallback readCallback(&writeCallback);
360   HandshakeCallback handshakeCallback(&readCallback);
361   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
362   TestSSLServer server(&acceptCallback);
363
364   // Set up SSL context.
365   std::shared_ptr<SSLContext> sslContext(new SSLContext());
366   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
367
368   // connect
369   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
370                                                  sslContext);
371   socket->open();
372
373   // write something to trigger ssl handshake
374   uint8_t buf[128];
375   memset(buf, 'a', sizeof(buf));
376   socket->write(buf, sizeof(buf));
377
378   socket->close();
379   cerr << "WriteError test completed" << endl;
380 }
381
382 /**
383  * Test a socket with TCP_NODELAY unset.
384  */
385 TEST(AsyncSSLSocketTest, SocketWithDelay) {
386   // Start listening on a local port
387   WriteCallbackBase writeCallback;
388   ReadCallback readCallback(&writeCallback);
389   HandshakeCallback handshakeCallback(&readCallback);
390   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
391   TestSSLServer server(&acceptCallback);
392
393   // Set up SSL context.
394   std::shared_ptr<SSLContext> sslContext(new SSLContext());
395   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
396
397   // connect
398   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
399                                                  sslContext);
400   socket->open();
401
402   // write()
403   uint8_t buf[128];
404   memset(buf, 'a', sizeof(buf));
405   socket->write(buf, sizeof(buf));
406
407   // read()
408   uint8_t readbuf[128];
409   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
410   EXPECT_EQ(bytesRead, 128);
411   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
412
413   // close()
414   socket->close();
415
416   cerr << "SocketWithDelay test completed" << endl;
417 }
418
419 using NextProtocolTypePair =
420     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
421
422 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
423   // For matching protos
424  public:
425   void SetUp() override { getctx(clientCtx, serverCtx); }
426
427   void connect(bool unset = false) {
428     getfds(fds);
429
430     if (unset) {
431       // unsetting NPN for any of [client, server] is enough to make NPN not
432       // work
433       clientCtx->unsetNextProtocols();
434     }
435
436     AsyncSSLSocket::UniquePtr clientSock(
437       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
438     AsyncSSLSocket::UniquePtr serverSock(
439       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
440     client = folly::make_unique<NpnClient>(std::move(clientSock));
441     server = folly::make_unique<NpnServer>(std::move(serverSock));
442
443     eventBase.loop();
444   }
445
446   void expectProtocol(const std::string& proto) {
447     EXPECT_NE(client->nextProtoLength, 0);
448     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
449     EXPECT_EQ(
450         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
451         0);
452     string selected((const char*)client->nextProto, client->nextProtoLength);
453     EXPECT_EQ(proto, selected);
454   }
455
456   void expectNoProtocol() {
457     EXPECT_EQ(client->nextProtoLength, 0);
458     EXPECT_EQ(server->nextProtoLength, 0);
459     EXPECT_EQ(client->nextProto, nullptr);
460     EXPECT_EQ(server->nextProto, nullptr);
461   }
462
463   void expectProtocolType() {
464     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
465         GetParam().second == SSLContext::NextProtocolType::ANY) {
466       EXPECT_EQ(client->protocolType, server->protocolType);
467     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
468                GetParam().second == SSLContext::NextProtocolType::ANY) {
469       // Well not much we can say
470     } else {
471       expectProtocolType(GetParam());
472     }
473   }
474
475   void expectProtocolType(NextProtocolTypePair expected) {
476     EXPECT_EQ(client->protocolType, expected.first);
477     EXPECT_EQ(server->protocolType, expected.second);
478   }
479
480   EventBase eventBase;
481   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
482   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
483   int fds[2];
484   std::unique_ptr<NpnClient> client;
485   std::unique_ptr<NpnServer> server;
486 };
487
488 class NextProtocolTLSExtTest : public NextProtocolTest {
489   // For extended TLS protos
490 };
491
492 class NextProtocolNPNOnlyTest : public NextProtocolTest {
493   // For mismatching protos
494 };
495
496 class NextProtocolMismatchTest : public NextProtocolTest {
497   // For mismatching protos
498 };
499
500 TEST_P(NextProtocolTest, NpnTestOverlap) {
501   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
502   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
503                                         GetParam().second);
504
505   connect();
506
507   expectProtocol("baz");
508   expectProtocolType();
509 }
510
511 TEST_P(NextProtocolTest, NpnTestUnset) {
512   // Identical to above test, except that we want unset NPN before
513   // looping.
514   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
515   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
516                                         GetParam().second);
517
518   connect(true /* unset */);
519
520   // if alpn negotiation fails, type will appear as npn
521   expectNoProtocol();
522   EXPECT_EQ(client->protocolType, server->protocolType);
523 }
524
525 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
526   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
527   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
528                                         GetParam().second);
529
530   connect();
531
532   expectNoProtocol();
533   expectProtocolType(
534       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
535 }
536
537 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
538 // will fail on 1.0.2 before that.
539 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
540   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
541   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
542                                         GetParam().second);
543
544   connect();
545
546   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
547       GetParam().second == SSLContext::NextProtocolType::ALPN) {
548     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
549     // mismatch should result in a fatal alert, but this is OpenSSL's current
550     // behavior and we want to know if it changes.
551     expectNoProtocol();
552   } else {
553     expectProtocol("blub");
554     expectProtocolType(
555         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
556   }
557 }
558
559 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
560   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
561   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
562   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
563                                         GetParam().second);
564
565   connect();
566
567   expectProtocol("ponies");
568   expectProtocolType();
569 }
570
571 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
572   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
573   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
574   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
575                                         GetParam().second);
576
577   connect();
578
579   expectProtocol("blub");
580   expectProtocolType();
581 }
582
583 TEST_P(NextProtocolTest, RandomizedNpnTest) {
584   // Probability that this test will fail is 2^-64, which could be considered
585   // as negligible.
586   const int kTries = 64;
587
588   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
589                                         GetParam().first);
590   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
591                                                   GetParam().second);
592
593   std::set<string> selectedProtocols;
594   for (int i = 0; i < kTries; ++i) {
595     connect();
596
597     EXPECT_NE(client->nextProtoLength, 0);
598     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
599     EXPECT_EQ(
600         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
601         0);
602     string selected((const char*)client->nextProto, client->nextProtoLength);
603     selectedProtocols.insert(selected);
604     expectProtocolType();
605   }
606   EXPECT_EQ(selectedProtocols.size(), 2);
607 }
608
609 INSTANTIATE_TEST_CASE_P(
610     AsyncSSLSocketTest,
611     NextProtocolTest,
612     ::testing::Values(
613         NextProtocolTypePair(
614             SSLContext::NextProtocolType::NPN,
615             SSLContext::NextProtocolType::NPN),
616         NextProtocolTypePair(
617             SSLContext::NextProtocolType::NPN,
618             SSLContext::NextProtocolType::ANY),
619         NextProtocolTypePair(
620             SSLContext::NextProtocolType::ANY,
621             SSLContext::NextProtocolType::ANY)));
622
623 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
624 INSTANTIATE_TEST_CASE_P(
625     AsyncSSLSocketTest,
626     NextProtocolTLSExtTest,
627     ::testing::Values(
628         NextProtocolTypePair(
629             SSLContext::NextProtocolType::ALPN,
630             SSLContext::NextProtocolType::ALPN),
631         NextProtocolTypePair(
632             SSLContext::NextProtocolType::ALPN,
633             SSLContext::NextProtocolType::ANY),
634         NextProtocolTypePair(
635             SSLContext::NextProtocolType::ANY,
636             SSLContext::NextProtocolType::ALPN)));
637 #endif
638
639 INSTANTIATE_TEST_CASE_P(
640     AsyncSSLSocketTest,
641     NextProtocolNPNOnlyTest,
642     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
643                                            SSLContext::NextProtocolType::NPN)));
644
645 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
646 INSTANTIATE_TEST_CASE_P(
647     AsyncSSLSocketTest,
648     NextProtocolMismatchTest,
649     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
650                                            SSLContext::NextProtocolType::ALPN),
651                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
652                                            SSLContext::NextProtocolType::NPN)));
653 #endif
654
655 #ifndef OPENSSL_NO_TLSEXT
656 /**
657  * 1. Client sends TLSEXT_HOSTNAME in client hello.
658  * 2. Server found a match SSL_CTX and use this SSL_CTX to
659  *    continue the SSL handshake.
660  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
661  */
662 TEST(AsyncSSLSocketTest, SNITestMatch) {
663   EventBase eventBase;
664   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
665   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
666   // Use the same SSLContext to continue the handshake after
667   // tlsext_hostname match.
668   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
669   const std::string serverName("xyz.newdev.facebook.com");
670   int fds[2];
671   getfds(fds);
672   getctx(clientCtx, dfServerCtx);
673
674   AsyncSSLSocket::UniquePtr clientSock(
675     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
676   AsyncSSLSocket::UniquePtr serverSock(
677     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
678   SNIClient client(std::move(clientSock));
679   SNIServer server(std::move(serverSock),
680                    dfServerCtx,
681                    hskServerCtx,
682                    serverName);
683
684   eventBase.loop();
685
686   EXPECT_TRUE(client.serverNameMatch);
687   EXPECT_TRUE(server.serverNameMatch);
688 }
689
690 /**
691  * 1. Client sends TLSEXT_HOSTNAME in client hello.
692  * 2. Server cannot find a matching SSL_CTX and continue to use
693  *    the current SSL_CTX to do the handshake.
694  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
695  */
696 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
697   EventBase eventBase;
698   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
699   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
700   // Use the same SSLContext to continue the handshake after
701   // tlsext_hostname match.
702   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
703   const std::string clientRequestingServerName("foo.com");
704   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
705
706   int fds[2];
707   getfds(fds);
708   getctx(clientCtx, dfServerCtx);
709
710   AsyncSSLSocket::UniquePtr clientSock(
711     new AsyncSSLSocket(clientCtx,
712                         &eventBase,
713                         fds[0],
714                         clientRequestingServerName));
715   AsyncSSLSocket::UniquePtr serverSock(
716     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
717   SNIClient client(std::move(clientSock));
718   SNIServer server(std::move(serverSock),
719                    dfServerCtx,
720                    hskServerCtx,
721                    serverExpectedServerName);
722
723   eventBase.loop();
724
725   EXPECT_TRUE(!client.serverNameMatch);
726   EXPECT_TRUE(!server.serverNameMatch);
727 }
728 /**
729  * 1. Client sends TLSEXT_HOSTNAME in client hello.
730  * 2. We then change the serverName.
731  * 3. We expect that we get 'false' as the result for serNameMatch.
732  */
733
734 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
735    EventBase eventBase;
736   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
737   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
738   // Use the same SSLContext to continue the handshake after
739   // tlsext_hostname match.
740   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
741   const std::string serverName("xyz.newdev.facebook.com");
742   int fds[2];
743   getfds(fds);
744   getctx(clientCtx, dfServerCtx);
745
746   AsyncSSLSocket::UniquePtr clientSock(
747     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
748   //Change the server name
749   std::string newName("new.com");
750   clientSock->setServerName(newName);
751   AsyncSSLSocket::UniquePtr serverSock(
752     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
753   SNIClient client(std::move(clientSock));
754   SNIServer server(std::move(serverSock),
755                    dfServerCtx,
756                    hskServerCtx,
757                    serverName);
758
759   eventBase.loop();
760
761   EXPECT_TRUE(!client.serverNameMatch);
762 }
763
764 /**
765  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
766  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
767  */
768 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
769   EventBase eventBase;
770   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
771   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
772   // Use the same SSLContext to continue the handshake after
773   // tlsext_hostname match.
774   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
775   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
776
777   int fds[2];
778   getfds(fds);
779   getctx(clientCtx, dfServerCtx);
780
781   AsyncSSLSocket::UniquePtr clientSock(
782     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
783   AsyncSSLSocket::UniquePtr serverSock(
784     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
785   SNIClient client(std::move(clientSock));
786   SNIServer server(std::move(serverSock),
787                    dfServerCtx,
788                    hskServerCtx,
789                    serverExpectedServerName);
790
791   eventBase.loop();
792
793   EXPECT_TRUE(!client.serverNameMatch);
794   EXPECT_TRUE(!server.serverNameMatch);
795 }
796
797 #endif
798 /**
799  * Test SSL client socket
800  */
801 TEST(AsyncSSLSocketTest, SSLClientTest) {
802   // Start listening on a local port
803   WriteCallbackBase writeCallback;
804   ReadCallback readCallback(&writeCallback);
805   HandshakeCallback handshakeCallback(&readCallback);
806   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
807   TestSSLServer server(&acceptCallback);
808
809   // Set up SSL client
810   EventBase eventBase;
811   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
812
813   client->connect();
814   EventBaseAborter eba(&eventBase, 3000);
815   eventBase.loop();
816
817   EXPECT_EQ(client->getMiss(), 1);
818   EXPECT_EQ(client->getHit(), 0);
819
820   cerr << "SSLClientTest test completed" << endl;
821 }
822
823
824 /**
825  * Test SSL client socket session re-use
826  */
827 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
828   // Start listening on a local port
829   WriteCallbackBase writeCallback;
830   ReadCallback readCallback(&writeCallback);
831   HandshakeCallback handshakeCallback(&readCallback);
832   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
833   TestSSLServer server(&acceptCallback);
834
835   // Set up SSL client
836   EventBase eventBase;
837   auto client =
838       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
839
840   client->connect();
841   EventBaseAborter eba(&eventBase, 3000);
842   eventBase.loop();
843
844   EXPECT_EQ(client->getMiss(), 1);
845   EXPECT_EQ(client->getHit(), 9);
846
847   cerr << "SSLClientTestReuse test completed" << endl;
848 }
849
850 /**
851  * Test SSL client socket timeout
852  */
853 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
854   // Start listening on a local port
855   EmptyReadCallback readCallback;
856   HandshakeCallback handshakeCallback(&readCallback,
857                                       HandshakeCallback::EXPECT_ERROR);
858   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
859   TestSSLServer server(&acceptCallback);
860
861   // Set up SSL client
862   EventBase eventBase;
863   auto client =
864       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
865   client->connect(true /* write before connect completes */);
866   EventBaseAborter eba(&eventBase, 3000);
867   eventBase.loop();
868
869   usleep(100000);
870   // This is checking that the connectError callback precedes any queued
871   // writeError callbacks.  This matches AsyncSocket's behavior
872   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
873   EXPECT_EQ(client->getErrors(), 1);
874   EXPECT_EQ(client->getMiss(), 0);
875   EXPECT_EQ(client->getHit(), 0);
876
877   cerr << "SSLClientTimeoutTest test completed" << endl;
878 }
879
880 // This is a FB-only extension, and the tests will fail without it
881 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
882 /**
883  * Test SSL server async cache
884  */
885 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
886   // Start listening on a local port
887   WriteCallbackBase writeCallback;
888   ReadCallback readCallback(&writeCallback);
889   HandshakeCallback handshakeCallback(&readCallback);
890   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
891   TestSSLAsyncCacheServer server(&acceptCallback);
892
893   // Set up SSL client
894   EventBase eventBase;
895   auto client =
896       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
897
898   client->connect();
899   EventBaseAborter eba(&eventBase, 3000);
900   eventBase.loop();
901
902   EXPECT_EQ(server.getAsyncCallbacks(), 18);
903   EXPECT_EQ(server.getAsyncLookups(), 9);
904   EXPECT_EQ(client->getMiss(), 10);
905   EXPECT_EQ(client->getHit(), 0);
906
907   cerr << "SSLServerAsyncCacheTest test completed" << endl;
908 }
909
910
911 /**
912  * Test SSL server accept timeout with cache path
913  */
914 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
915   // Start listening on a local port
916   WriteCallbackBase writeCallback;
917   ReadCallback readCallback(&writeCallback);
918   EmptyReadCallback clientReadCallback;
919   HandshakeCallback handshakeCallback(&readCallback);
920   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
921   TestSSLAsyncCacheServer server(&acceptCallback);
922
923   // Set up SSL client
924   EventBase eventBase;
925   // only do a TCP connect
926   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
927   sock->connect(nullptr, server.getAddress());
928   clientReadCallback.tcpSocket_ = sock;
929   sock->setReadCB(&clientReadCallback);
930
931   EventBaseAborter eba(&eventBase, 3000);
932   eventBase.loop();
933
934   EXPECT_EQ(readCallback.state, STATE_WAITING);
935
936   cerr << "SSLServerTimeoutTest test completed" << endl;
937 }
938
939 /**
940  * Test SSL server accept timeout with cache path
941  */
942 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
943   // Start listening on a local port
944   WriteCallbackBase writeCallback;
945   ReadCallback readCallback(&writeCallback);
946   HandshakeCallback handshakeCallback(&readCallback);
947   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
948   TestSSLAsyncCacheServer server(&acceptCallback);
949
950   // Set up SSL client
951   EventBase eventBase;
952   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
953
954   client->connect();
955   EventBaseAborter eba(&eventBase, 3000);
956   eventBase.loop();
957
958   EXPECT_EQ(server.getAsyncCallbacks(), 1);
959   EXPECT_EQ(server.getAsyncLookups(), 1);
960   EXPECT_EQ(client->getErrors(), 1);
961   EXPECT_EQ(client->getMiss(), 1);
962   EXPECT_EQ(client->getHit(), 0);
963
964   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
965 }
966
967 /**
968  * Test SSL server accept timeout with cache path
969  */
970 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
971   // Start listening on a local port
972   WriteCallbackBase writeCallback;
973   ReadCallback readCallback(&writeCallback);
974   HandshakeCallback handshakeCallback(&readCallback,
975                                       HandshakeCallback::EXPECT_ERROR);
976   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
977   TestSSLAsyncCacheServer server(&acceptCallback, 500);
978
979   // Set up SSL client
980   EventBase eventBase;
981   auto client =
982       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
983
984   client->connect();
985   EventBaseAborter eba(&eventBase, 3000);
986   eventBase.loop();
987
988   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
989       handshakeCallback.closeSocket();});
990   // give time for the cache lookup to come back and find it closed
991   handshakeCallback.waitForHandshake();
992
993   EXPECT_EQ(server.getAsyncCallbacks(), 1);
994   EXPECT_EQ(server.getAsyncLookups(), 1);
995   EXPECT_EQ(client->getErrors(), 1);
996   EXPECT_EQ(client->getMiss(), 1);
997   EXPECT_EQ(client->getHit(), 0);
998
999   cerr << "SSLServerCacheCloseTest test completed" << endl;
1000 }
1001 #endif
1002
1003 /**
1004  * Verify Client Ciphers obtained using SSL MSG Callback.
1005  */
1006 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1007   EventBase eventBase;
1008   auto clientCtx = std::make_shared<SSLContext>();
1009   auto serverCtx = std::make_shared<SSLContext>();
1010   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1011   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
1012   serverCtx->loadPrivateKey(testKey);
1013   serverCtx->loadCertificate(testCert);
1014   serverCtx->loadTrustedCertificates(testCA);
1015   serverCtx->loadClientCAList(testCA);
1016
1017   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1018   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
1019   clientCtx->loadPrivateKey(testKey);
1020   clientCtx->loadCertificate(testCert);
1021   clientCtx->loadTrustedCertificates(testCA);
1022
1023   int fds[2];
1024   getfds(fds);
1025
1026   AsyncSSLSocket::UniquePtr clientSock(
1027       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1028   AsyncSSLSocket::UniquePtr serverSock(
1029       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1030
1031   SSLHandshakeClient client(std::move(clientSock), true, true);
1032   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1033
1034   eventBase.loop();
1035
1036   EXPECT_EQ(server.clientCiphers_,
1037             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1038   EXPECT_TRUE(client.handshakeVerify_);
1039   EXPECT_TRUE(client.handshakeSuccess_);
1040   EXPECT_TRUE(!client.handshakeError_);
1041   EXPECT_TRUE(server.handshakeVerify_);
1042   EXPECT_TRUE(server.handshakeSuccess_);
1043   EXPECT_TRUE(!server.handshakeError_);
1044 }
1045
1046 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1047   EventBase eventBase;
1048   auto ctx = std::make_shared<SSLContext>();
1049
1050   int fds[2];
1051   getfds(fds);
1052
1053   int bufLen = 42;
1054   uint8_t majorVersion = 18;
1055   uint8_t minorVersion = 25;
1056
1057   // Create callback buf
1058   auto buf = IOBuf::create(bufLen);
1059   buf->append(bufLen);
1060   folly::io::RWPrivateCursor cursor(buf.get());
1061   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1062   cursor.write<uint16_t>(0);
1063   cursor.write<uint8_t>(38);
1064   cursor.write<uint8_t>(majorVersion);
1065   cursor.write<uint8_t>(minorVersion);
1066   cursor.skip(32);
1067   cursor.write<uint32_t>(0);
1068
1069   SSL* ssl = ctx->createSSL();
1070   SCOPE_EXIT { SSL_free(ssl); };
1071   AsyncSSLSocket::UniquePtr sock(
1072       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1073   sock->enableClientHelloParsing();
1074
1075   // Test client hello parsing in one packet
1076   AsyncSSLSocket::clientHelloParsingCallback(
1077       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1078   buf.reset();
1079
1080   auto parsedClientHello = sock->getClientHelloInfo();
1081   EXPECT_TRUE(parsedClientHello != nullptr);
1082   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1083   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1084 }
1085
1086 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1087   EventBase eventBase;
1088   auto ctx = std::make_shared<SSLContext>();
1089
1090   int fds[2];
1091   getfds(fds);
1092
1093   int bufLen = 42;
1094   uint8_t majorVersion = 18;
1095   uint8_t minorVersion = 25;
1096
1097   // Create callback buf
1098   auto buf = IOBuf::create(bufLen);
1099   buf->append(bufLen);
1100   folly::io::RWPrivateCursor cursor(buf.get());
1101   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1102   cursor.write<uint16_t>(0);
1103   cursor.write<uint8_t>(38);
1104   cursor.write<uint8_t>(majorVersion);
1105   cursor.write<uint8_t>(minorVersion);
1106   cursor.skip(32);
1107   cursor.write<uint32_t>(0);
1108
1109   SSL* ssl = ctx->createSSL();
1110   SCOPE_EXIT { SSL_free(ssl); };
1111   AsyncSSLSocket::UniquePtr sock(
1112       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1113   sock->enableClientHelloParsing();
1114
1115   // Test parsing with two packets with first packet size < 3
1116   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1117   AsyncSSLSocket::clientHelloParsingCallback(
1118       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1119       ssl, sock.get());
1120   bufCopy.reset();
1121   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1122   AsyncSSLSocket::clientHelloParsingCallback(
1123       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1124       ssl, sock.get());
1125   bufCopy.reset();
1126
1127   auto parsedClientHello = sock->getClientHelloInfo();
1128   EXPECT_TRUE(parsedClientHello != nullptr);
1129   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1130   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1131 }
1132
1133 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1134   EventBase eventBase;
1135   auto ctx = std::make_shared<SSLContext>();
1136
1137   int fds[2];
1138   getfds(fds);
1139
1140   int bufLen = 42;
1141   uint8_t majorVersion = 18;
1142   uint8_t minorVersion = 25;
1143
1144   // Create callback buf
1145   auto buf = IOBuf::create(bufLen);
1146   buf->append(bufLen);
1147   folly::io::RWPrivateCursor cursor(buf.get());
1148   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1149   cursor.write<uint16_t>(0);
1150   cursor.write<uint8_t>(38);
1151   cursor.write<uint8_t>(majorVersion);
1152   cursor.write<uint8_t>(minorVersion);
1153   cursor.skip(32);
1154   cursor.write<uint32_t>(0);
1155
1156   SSL* ssl = ctx->createSSL();
1157   SCOPE_EXIT { SSL_free(ssl); };
1158   AsyncSSLSocket::UniquePtr sock(
1159       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1160   sock->enableClientHelloParsing();
1161
1162   // Test parsing with multiple small packets
1163   for (uint64_t i = 0; i < buf->length(); i += 3) {
1164     auto bufCopy = folly::IOBuf::copyBuffer(
1165         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1166     AsyncSSLSocket::clientHelloParsingCallback(
1167         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1168         ssl, sock.get());
1169     bufCopy.reset();
1170   }
1171
1172   auto parsedClientHello = sock->getClientHelloInfo();
1173   EXPECT_TRUE(parsedClientHello != nullptr);
1174   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1175   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1176 }
1177
1178 /**
1179  * Verify sucessful behavior of SSL certificate validation.
1180  */
1181 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1182   EventBase eventBase;
1183   auto clientCtx = std::make_shared<SSLContext>();
1184   auto dfServerCtx = std::make_shared<SSLContext>();
1185
1186   int fds[2];
1187   getfds(fds);
1188   getctx(clientCtx, dfServerCtx);
1189
1190   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1191   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1192
1193   AsyncSSLSocket::UniquePtr clientSock(
1194     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1195   AsyncSSLSocket::UniquePtr serverSock(
1196     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1197
1198   SSLHandshakeClient client(std::move(clientSock), true, true);
1199   clientCtx->loadTrustedCertificates(testCA);
1200
1201   SSLHandshakeServer server(std::move(serverSock), true, true);
1202
1203   eventBase.loop();
1204
1205   EXPECT_TRUE(client.handshakeVerify_);
1206   EXPECT_TRUE(client.handshakeSuccess_);
1207   EXPECT_TRUE(!client.handshakeError_);
1208   EXPECT_LE(0, client.handshakeTime.count());
1209   EXPECT_TRUE(!server.handshakeVerify_);
1210   EXPECT_TRUE(server.handshakeSuccess_);
1211   EXPECT_TRUE(!server.handshakeError_);
1212   EXPECT_LE(0, server.handshakeTime.count());
1213 }
1214
1215 /**
1216  * Verify that the client's verification callback is able to fail SSL
1217  * connection establishment.
1218  */
1219 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1220   EventBase eventBase;
1221   auto clientCtx = std::make_shared<SSLContext>();
1222   auto dfServerCtx = std::make_shared<SSLContext>();
1223
1224   int fds[2];
1225   getfds(fds);
1226   getctx(clientCtx, dfServerCtx);
1227
1228   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1229   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1230
1231   AsyncSSLSocket::UniquePtr clientSock(
1232     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1233   AsyncSSLSocket::UniquePtr serverSock(
1234     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1235
1236   SSLHandshakeClient client(std::move(clientSock), true, false);
1237   clientCtx->loadTrustedCertificates(testCA);
1238
1239   SSLHandshakeServer server(std::move(serverSock), true, true);
1240
1241   eventBase.loop();
1242
1243   EXPECT_TRUE(client.handshakeVerify_);
1244   EXPECT_TRUE(!client.handshakeSuccess_);
1245   EXPECT_TRUE(client.handshakeError_);
1246   EXPECT_LE(0, client.handshakeTime.count());
1247   EXPECT_TRUE(!server.handshakeVerify_);
1248   EXPECT_TRUE(!server.handshakeSuccess_);
1249   EXPECT_TRUE(server.handshakeError_);
1250   EXPECT_LE(0, server.handshakeTime.count());
1251 }
1252
1253 /**
1254  * Verify that the options in SSLContext can be overridden in
1255  * sslConnect/Accept.i.e specifying that no validation should be performed
1256  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1257  * the validation callback.
1258  */
1259 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1260   EventBase eventBase;
1261   auto clientCtx = std::make_shared<SSLContext>();
1262   auto dfServerCtx = std::make_shared<SSLContext>();
1263
1264   int fds[2];
1265   getfds(fds);
1266   getctx(clientCtx, dfServerCtx);
1267
1268   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1269   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1270
1271   AsyncSSLSocket::UniquePtr clientSock(
1272     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1273   AsyncSSLSocket::UniquePtr serverSock(
1274     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1275
1276   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1277   clientCtx->loadTrustedCertificates(testCA);
1278
1279   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1280
1281   eventBase.loop();
1282
1283   EXPECT_TRUE(!client.handshakeVerify_);
1284   EXPECT_TRUE(client.handshakeSuccess_);
1285   EXPECT_TRUE(!client.handshakeError_);
1286   EXPECT_LE(0, client.handshakeTime.count());
1287   EXPECT_TRUE(!server.handshakeVerify_);
1288   EXPECT_TRUE(server.handshakeSuccess_);
1289   EXPECT_TRUE(!server.handshakeError_);
1290   EXPECT_LE(0, server.handshakeTime.count());
1291 }
1292
1293 /**
1294  * Verify that the options in SSLContext can be overridden in
1295  * sslConnect/Accept. Enable verification even if context says otherwise.
1296  * Test requireClientCert with client cert
1297  */
1298 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1299   EventBase eventBase;
1300   auto clientCtx = std::make_shared<SSLContext>();
1301   auto serverCtx = std::make_shared<SSLContext>();
1302   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1303   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1304   serverCtx->loadPrivateKey(testKey);
1305   serverCtx->loadCertificate(testCert);
1306   serverCtx->loadTrustedCertificates(testCA);
1307   serverCtx->loadClientCAList(testCA);
1308
1309   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1310   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1311   clientCtx->loadPrivateKey(testKey);
1312   clientCtx->loadCertificate(testCert);
1313   clientCtx->loadTrustedCertificates(testCA);
1314
1315   int fds[2];
1316   getfds(fds);
1317
1318   AsyncSSLSocket::UniquePtr clientSock(
1319       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1320   AsyncSSLSocket::UniquePtr serverSock(
1321       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1322
1323   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1324   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1325
1326   eventBase.loop();
1327
1328   EXPECT_TRUE(client.handshakeVerify_);
1329   EXPECT_TRUE(client.handshakeSuccess_);
1330   EXPECT_FALSE(client.handshakeError_);
1331   EXPECT_LE(0, client.handshakeTime.count());
1332   EXPECT_TRUE(server.handshakeVerify_);
1333   EXPECT_TRUE(server.handshakeSuccess_);
1334   EXPECT_FALSE(server.handshakeError_);
1335   EXPECT_LE(0, server.handshakeTime.count());
1336 }
1337
1338 /**
1339  * Verify that the client's verification callback is able to override
1340  * the preverification failure and allow a successful connection.
1341  */
1342 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1343   EventBase eventBase;
1344   auto clientCtx = std::make_shared<SSLContext>();
1345   auto dfServerCtx = std::make_shared<SSLContext>();
1346
1347   int fds[2];
1348   getfds(fds);
1349   getctx(clientCtx, dfServerCtx);
1350
1351   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1352   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1353
1354   AsyncSSLSocket::UniquePtr clientSock(
1355     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1356   AsyncSSLSocket::UniquePtr serverSock(
1357     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1358
1359   SSLHandshakeClient client(std::move(clientSock), false, true);
1360   SSLHandshakeServer server(std::move(serverSock), true, true);
1361
1362   eventBase.loop();
1363
1364   EXPECT_TRUE(client.handshakeVerify_);
1365   EXPECT_TRUE(client.handshakeSuccess_);
1366   EXPECT_TRUE(!client.handshakeError_);
1367   EXPECT_LE(0, client.handshakeTime.count());
1368   EXPECT_TRUE(!server.handshakeVerify_);
1369   EXPECT_TRUE(server.handshakeSuccess_);
1370   EXPECT_TRUE(!server.handshakeError_);
1371   EXPECT_LE(0, server.handshakeTime.count());
1372 }
1373
1374 /**
1375  * Verify that specifying that no validation should be performed allows an
1376  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1377  * callback.
1378  */
1379 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1380   EventBase eventBase;
1381   auto clientCtx = std::make_shared<SSLContext>();
1382   auto dfServerCtx = std::make_shared<SSLContext>();
1383
1384   int fds[2];
1385   getfds(fds);
1386   getctx(clientCtx, dfServerCtx);
1387
1388   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1389   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1390
1391   AsyncSSLSocket::UniquePtr clientSock(
1392     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1393   AsyncSSLSocket::UniquePtr serverSock(
1394     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1395
1396   SSLHandshakeClient client(std::move(clientSock), false, false);
1397   SSLHandshakeServer server(std::move(serverSock), false, false);
1398
1399   eventBase.loop();
1400
1401   EXPECT_TRUE(!client.handshakeVerify_);
1402   EXPECT_TRUE(client.handshakeSuccess_);
1403   EXPECT_TRUE(!client.handshakeError_);
1404   EXPECT_LE(0, client.handshakeTime.count());
1405   EXPECT_TRUE(!server.handshakeVerify_);
1406   EXPECT_TRUE(server.handshakeSuccess_);
1407   EXPECT_TRUE(!server.handshakeError_);
1408   EXPECT_LE(0, server.handshakeTime.count());
1409 }
1410
1411 /**
1412  * Test requireClientCert with client cert
1413  */
1414 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1415   EventBase eventBase;
1416   auto clientCtx = std::make_shared<SSLContext>();
1417   auto serverCtx = std::make_shared<SSLContext>();
1418   serverCtx->setVerificationOption(
1419       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1420   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1421   serverCtx->loadPrivateKey(testKey);
1422   serverCtx->loadCertificate(testCert);
1423   serverCtx->loadTrustedCertificates(testCA);
1424   serverCtx->loadClientCAList(testCA);
1425
1426   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1427   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1428   clientCtx->loadPrivateKey(testKey);
1429   clientCtx->loadCertificate(testCert);
1430   clientCtx->loadTrustedCertificates(testCA);
1431
1432   int fds[2];
1433   getfds(fds);
1434
1435   AsyncSSLSocket::UniquePtr clientSock(
1436       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1437   AsyncSSLSocket::UniquePtr serverSock(
1438       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1439
1440   SSLHandshakeClient client(std::move(clientSock), true, true);
1441   SSLHandshakeServer server(std::move(serverSock), true, true);
1442
1443   eventBase.loop();
1444
1445   EXPECT_TRUE(client.handshakeVerify_);
1446   EXPECT_TRUE(client.handshakeSuccess_);
1447   EXPECT_FALSE(client.handshakeError_);
1448   EXPECT_LE(0, client.handshakeTime.count());
1449   EXPECT_TRUE(server.handshakeVerify_);
1450   EXPECT_TRUE(server.handshakeSuccess_);
1451   EXPECT_FALSE(server.handshakeError_);
1452   EXPECT_LE(0, server.handshakeTime.count());
1453 }
1454
1455
1456 /**
1457  * Test requireClientCert with no client cert
1458  */
1459 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1460   EventBase eventBase;
1461   auto clientCtx = std::make_shared<SSLContext>();
1462   auto serverCtx = std::make_shared<SSLContext>();
1463   serverCtx->setVerificationOption(
1464       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1465   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1466   serverCtx->loadPrivateKey(testKey);
1467   serverCtx->loadCertificate(testCert);
1468   serverCtx->loadTrustedCertificates(testCA);
1469   serverCtx->loadClientCAList(testCA);
1470   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1471   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1472
1473   int fds[2];
1474   getfds(fds);
1475
1476   AsyncSSLSocket::UniquePtr clientSock(
1477       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1478   AsyncSSLSocket::UniquePtr serverSock(
1479       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1480
1481   SSLHandshakeClient client(std::move(clientSock), false, false);
1482   SSLHandshakeServer server(std::move(serverSock), false, false);
1483
1484   eventBase.loop();
1485
1486   EXPECT_FALSE(server.handshakeVerify_);
1487   EXPECT_FALSE(server.handshakeSuccess_);
1488   EXPECT_TRUE(server.handshakeError_);
1489   EXPECT_LE(0, client.handshakeTime.count());
1490   EXPECT_LE(0, server.handshakeTime.count());
1491 }
1492
1493 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1494   auto cert = getFileAsBuf(testCert);
1495   auto key = getFileAsBuf(testKey);
1496
1497   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1498   BIO_write(certBio.get(), cert.data(), cert.size());
1499   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1500   BIO_write(keyBio.get(), key.data(), key.size());
1501
1502   // Create SSL structs from buffers to get properties
1503   ssl::X509UniquePtr certStruct(
1504       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1505   ssl::EvpPkeyUniquePtr keyStruct(
1506       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1507   certBio = nullptr;
1508   keyBio = nullptr;
1509
1510   auto origCommonName = getCommonName(certStruct.get());
1511   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1512   certStruct = nullptr;
1513   keyStruct = nullptr;
1514
1515   auto ctx = std::make_shared<SSLContext>();
1516   ctx->loadPrivateKeyFromBufferPEM(key);
1517   ctx->loadCertificateFromBufferPEM(cert);
1518   ctx->loadTrustedCertificates(testCA);
1519
1520   ssl::SSLUniquePtr ssl(ctx->createSSL());
1521
1522   auto newCert = SSL_get_certificate(ssl.get());
1523   auto newKey = SSL_get_privatekey(ssl.get());
1524
1525   // Get properties from SSL struct
1526   auto newCommonName = getCommonName(newCert);
1527   auto newKeySize = EVP_PKEY_bits(newKey);
1528
1529   // Check that the key and cert have the expected properties
1530   EXPECT_EQ(origCommonName, newCommonName);
1531   EXPECT_EQ(origKeySize, newKeySize);
1532 }
1533
1534 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1535   EventBase eb;
1536
1537   // Set up SSL context.
1538   auto sslContext = std::make_shared<SSLContext>();
1539   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1540
1541   // create SSL socket
1542   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1543
1544   EXPECT_EQ(1500, socket->getMinWriteSize());
1545
1546   socket->setMinWriteSize(0);
1547   EXPECT_EQ(0, socket->getMinWriteSize());
1548   socket->setMinWriteSize(50000);
1549   EXPECT_EQ(50000, socket->getMinWriteSize());
1550 }
1551
1552 class ReadCallbackTerminator : public ReadCallback {
1553  public:
1554   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1555       : ReadCallback(wcb)
1556       , base_(base) {}
1557
1558   // Do not write data back, terminate the loop.
1559   void readDataAvailable(size_t len) noexcept override {
1560     std::cerr << "readDataAvailable, len " << len << std::endl;
1561
1562     currentBuffer.length = len;
1563
1564     buffers.push_back(currentBuffer);
1565     currentBuffer.reset();
1566     state = STATE_SUCCEEDED;
1567
1568     socket_->setReadCB(nullptr);
1569     base_->terminateLoopSoon();
1570   }
1571  private:
1572   EventBase* base_;
1573 };
1574
1575
1576 /**
1577  * Test a full unencrypted codepath
1578  */
1579 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1580   EventBase base;
1581
1582   auto clientCtx = std::make_shared<folly::SSLContext>();
1583   auto serverCtx = std::make_shared<folly::SSLContext>();
1584   int fds[2];
1585   getfds(fds);
1586   getctx(clientCtx, serverCtx);
1587   auto client = AsyncSSLSocket::newSocket(
1588                   clientCtx, &base, fds[0], false, true);
1589   auto server = AsyncSSLSocket::newSocket(
1590                   serverCtx, &base, fds[1], true, true);
1591
1592   ReadCallbackTerminator readCallback(&base, nullptr);
1593   server->setReadCB(&readCallback);
1594   readCallback.setSocket(server);
1595
1596   uint8_t buf[128];
1597   memset(buf, 'a', sizeof(buf));
1598   client->write(nullptr, buf, sizeof(buf));
1599
1600   // Check that bytes are unencrypted
1601   char c;
1602   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1603   EXPECT_EQ('a', c);
1604
1605   EventBaseAborter eba(&base, 3000);
1606   base.loop();
1607
1608   EXPECT_EQ(1, readCallback.buffers.size());
1609   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1610
1611   server->setReadCB(&readCallback);
1612
1613   // Unencrypted
1614   server->sslAccept(nullptr);
1615   client->sslConn(nullptr);
1616
1617   // Do NOT wait for handshake, writing should be queued and happen after
1618
1619   client->write(nullptr, buf, sizeof(buf));
1620
1621   // Check that bytes are *not* unencrypted
1622   char c2;
1623   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1624   EXPECT_NE('a', c2);
1625
1626
1627   base.loop();
1628
1629   EXPECT_EQ(2, readCallback.buffers.size());
1630   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1631 }
1632
1633 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1634   // Start listening on a local port
1635   WriteCallbackBase writeCallback;
1636   WriteErrorCallback readCallback(&writeCallback);
1637   HandshakeCallback handshakeCallback(&readCallback,
1638                                       HandshakeCallback::EXPECT_ERROR);
1639   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1640   TestSSLServer server(&acceptCallback);
1641
1642   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1643   socket->open();
1644   uint8_t buf[3] = {0x16, 0x03, 0x01};
1645   socket->write(buf, sizeof(buf));
1646   socket->closeWithReset();
1647
1648   handshakeCallback.waitForHandshake();
1649   EXPECT_NE(
1650       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1651   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1652 }
1653
1654 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1655   // Start listening on a local port
1656   WriteCallbackBase writeCallback;
1657   WriteErrorCallback readCallback(&writeCallback);
1658   HandshakeCallback handshakeCallback(&readCallback,
1659                                       HandshakeCallback::EXPECT_ERROR);
1660   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1661   TestSSLServer server(&acceptCallback);
1662
1663   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1664   socket->open();
1665   uint8_t buf[3] = {0x16, 0x03, 0x01};
1666   socket->write(buf, sizeof(buf));
1667   socket->close();
1668
1669   handshakeCallback.waitForHandshake();
1670   EXPECT_NE(
1671       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1672   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1673 }
1674
1675 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1676   // Start listening on a local port
1677   WriteCallbackBase writeCallback;
1678   WriteErrorCallback readCallback(&writeCallback);
1679   HandshakeCallback handshakeCallback(&readCallback,
1680                                       HandshakeCallback::EXPECT_ERROR);
1681   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1682   TestSSLServer server(&acceptCallback);
1683
1684   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1685   socket->open();
1686   uint8_t buf[256] = {0x16, 0x03};
1687   memset(buf + 2, 'a', sizeof(buf) - 2);
1688   socket->write(buf, sizeof(buf));
1689   socket->close();
1690
1691   handshakeCallback.waitForHandshake();
1692   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1693             std::string::npos);
1694   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1695             std::string::npos);
1696 }
1697
1698 #if FOLLY_ALLOW_TFO
1699
1700 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1701  public:
1702   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1703
1704   explicit MockAsyncTFOSSLSocket(
1705       std::shared_ptr<folly::SSLContext> sslCtx,
1706       EventBase* evb)
1707       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1708
1709   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1710 };
1711
1712 /**
1713  * Test connecting to, writing to, reading from, and closing the
1714  * connection to the SSL server with TFO.
1715  */
1716 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1717   // Start listening on a local port
1718   WriteCallbackBase writeCallback;
1719   ReadCallback readCallback(&writeCallback);
1720   HandshakeCallback handshakeCallback(&readCallback);
1721   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1722   TestSSLServer server(&acceptCallback, true);
1723
1724   // Set up SSL context.
1725   auto sslContext = std::make_shared<SSLContext>();
1726
1727   // connect
1728   auto socket =
1729       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1730   socket->enableTFO();
1731   socket->open();
1732
1733   // write()
1734   std::array<uint8_t, 128> buf;
1735   memset(buf.data(), 'a', buf.size());
1736   socket->write(buf.data(), buf.size());
1737
1738   // read()
1739   std::array<uint8_t, 128> readbuf;
1740   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1741   EXPECT_EQ(bytesRead, 128);
1742   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1743
1744   // close()
1745   socket->close();
1746 }
1747
1748 /**
1749  * Test connecting to, writing to, reading from, and closing the
1750  * connection to the SSL server with TFO.
1751  */
1752 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1753   // Start listening on a local port
1754   WriteCallbackBase writeCallback;
1755   ReadCallback readCallback(&writeCallback);
1756   HandshakeCallback handshakeCallback(&readCallback);
1757   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1758   TestSSLServer server(&acceptCallback, false);
1759
1760   // Set up SSL context.
1761   auto sslContext = std::make_shared<SSLContext>();
1762
1763   // connect
1764   auto socket =
1765       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1766   socket->enableTFO();
1767   socket->open();
1768
1769   // write()
1770   std::array<uint8_t, 128> buf;
1771   memset(buf.data(), 'a', buf.size());
1772   socket->write(buf.data(), buf.size());
1773
1774   // read()
1775   std::array<uint8_t, 128> readbuf;
1776   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1777   EXPECT_EQ(bytesRead, 128);
1778   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1779
1780   // close()
1781   socket->close();
1782 }
1783
1784 class ConnCallback : public AsyncSocket::ConnectCallback {
1785  public:
1786   virtual void connectSuccess() noexcept override {
1787     state = State::SUCCESS;
1788   }
1789
1790   virtual void connectErr(const AsyncSocketException& ex) noexcept override {
1791     state = State::ERROR;
1792     error = ex.what();
1793   }
1794
1795   enum class State { WAITING, SUCCESS, ERROR };
1796
1797   State state{State::WAITING};
1798   std::string error;
1799 };
1800
1801 template <class Cardinality>
1802 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1803     EventBase* evb,
1804     const SocketAddress& address,
1805     Cardinality cardinality) {
1806   // Set up SSL context.
1807   auto sslContext = std::make_shared<SSLContext>();
1808
1809   // connect
1810   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1811       new MockAsyncTFOSSLSocket(sslContext, evb));
1812   socket->enableTFO();
1813
1814   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1815       .Times(cardinality)
1816       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1817         sockaddr_storage addr;
1818         auto len = address.getAddress(&addr);
1819         return connect(fd, (const struct sockaddr*)&addr, len);
1820       }));
1821   return socket;
1822 }
1823
1824 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1825   // Start listening on a local port
1826   WriteCallbackBase writeCallback;
1827   ReadCallback readCallback(&writeCallback);
1828   HandshakeCallback handshakeCallback(&readCallback);
1829   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1830   TestSSLServer server(&acceptCallback, true);
1831
1832   EventBase evb;
1833
1834   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1835   ConnCallback ccb;
1836   socket->connect(&ccb, server.getAddress(), 30);
1837
1838   evb.loop();
1839   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1840
1841   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1842   evb.loop();
1843
1844   BlockingSocket sock(std::move(socket));
1845   // write()
1846   std::array<uint8_t, 128> buf;
1847   memset(buf.data(), 'a', buf.size());
1848   sock.write(buf.data(), buf.size());
1849
1850   // read()
1851   std::array<uint8_t, 128> readbuf;
1852   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1853   EXPECT_EQ(bytesRead, 128);
1854   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1855
1856   // close()
1857   sock.close();
1858 }
1859
1860 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1861   // Start listening on a local port
1862   ConnectTimeoutCallback acceptCallback;
1863   TestSSLServer server(&acceptCallback, true);
1864
1865   // Set up SSL context.
1866   auto sslContext = std::make_shared<SSLContext>();
1867
1868   // connect
1869   auto socket =
1870       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1871   socket->enableTFO();
1872   EXPECT_THROW(
1873       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1874 }
1875
1876 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1877   // Start listening on a local port
1878   ConnectTimeoutCallback acceptCallback;
1879   TestSSLServer server(&acceptCallback, true);
1880
1881   EventBase evb;
1882
1883   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1884   ConnCallback ccb;
1885   // Set a short timeout
1886   socket->connect(&ccb, server.getAddress(), 1);
1887
1888   evb.loop();
1889   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1890 }
1891
1892 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1893   // Start listening on a local port
1894   EmptyReadCallback readCallback;
1895   HandshakeCallback handshakeCallback(
1896       &readCallback, HandshakeCallback::EXPECT_ERROR);
1897   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1898   TestSSLServer server(&acceptCallback, true);
1899
1900   EventBase evb;
1901
1902   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1903   ConnCallback ccb;
1904   socket->connect(&ccb, server.getAddress(), 100);
1905
1906   evb.loop();
1907   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1908   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1909 }
1910
1911 #endif
1912
1913 } // namespace
1914
1915 #ifdef SIGPIPE
1916 ///////////////////////////////////////////////////////////////////////////
1917 // init_unit_test_suite
1918 ///////////////////////////////////////////////////////////////////////////
1919 namespace {
1920 struct Initializer {
1921   Initializer() {
1922     signal(SIGPIPE, SIG_IGN);
1923   }
1924 };
1925 Initializer initializer;
1926 } // anonymous
1927 #endif