27805c7b3cfb8d8c717916a4fa11d11bad81fcdb
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/Unistd.h>
26
27 #include <folly/io/async/test/BlockingSocket.h>
28
29 #include <fcntl.h>
30 #include <folly/io/Cursor.h>
31 #include <gtest/gtest.h>
32 #include <openssl/bio.h>
33 #include <sys/types.h>
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38
39 #include <gmock/gmock.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
56 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
57 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
58
59 constexpr size_t SSLClient::kMaxReadBufferSz;
60 constexpr size_t SSLClient::kMaxReadsPerEvent;
61
62 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
63     : ctx_(new folly::SSLContext),
64       acb_(acb),
65       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
66   // Set up the SSL context
67   ctx_->loadCertificate(testCert);
68   ctx_->loadPrivateKey(testKey);
69   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
70
71   acb_->ctx_ = ctx_;
72   acb_->base_ = &evb_;
73
74   // Enable TFO
75   if (enableTFO) {
76     LOG(INFO) << "server TFO enabled";
77     socket_->setTFOEnabled(true, 1000);
78   }
79
80   // set up the listening socket
81   socket_->bind(0);
82   socket_->getAddress(&address_);
83   socket_->listen(100);
84   socket_->addAcceptCallback(acb_, &evb_);
85   socket_->startAccepting();
86
87   int ret = pthread_create(&thread_, nullptr, Main, this);
88   assert(ret == 0);
89   (void)ret;
90
91   std::cerr << "Accepting connections on " << address_ << std::endl;
92 }
93
94 void getfds(int fds[2]) {
95   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
96     FAIL() << "failed to create socketpair: " << strerror(errno);
97   }
98   for (int idx = 0; idx < 2; ++idx) {
99     int flags = fcntl(fds[idx], F_GETFL, 0);
100     if (flags == -1) {
101       FAIL() << "failed to get flags for socket " << idx << ": "
102              << strerror(errno);
103     }
104     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
105       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
106              << strerror(errno);
107     }
108   }
109 }
110
111 void getctx(
112   std::shared_ptr<folly::SSLContext> clientCtx,
113   std::shared_ptr<folly::SSLContext> serverCtx) {
114   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
115
116   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
117   serverCtx->loadCertificate(
118       testCert);
119   serverCtx->loadPrivateKey(
120       testKey);
121 }
122
123 void sslsocketpair(
124   EventBase* eventBase,
125   AsyncSSLSocket::UniquePtr* clientSock,
126   AsyncSSLSocket::UniquePtr* serverSock) {
127   auto clientCtx = std::make_shared<folly::SSLContext>();
128   auto serverCtx = std::make_shared<folly::SSLContext>();
129   int fds[2];
130   getfds(fds);
131   getctx(clientCtx, serverCtx);
132   clientSock->reset(new AsyncSSLSocket(
133                       clientCtx, eventBase, fds[0], false));
134   serverSock->reset(new AsyncSSLSocket(
135                       serverCtx, eventBase, fds[1], true));
136
137   // (*clientSock)->setSendTimeout(100);
138   // (*serverSock)->setSendTimeout(100);
139 }
140
141 // client protocol filters
142 bool clientProtoFilterPickPony(unsigned char** client,
143   unsigned int* client_len, const unsigned char*, unsigned int ) {
144   //the protocol string in length prefixed byte string. the
145   //length byte is not included in the length
146   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
147   *client = p;
148   *client_len = 7;
149   return true;
150 }
151
152 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
153   const unsigned char*, unsigned int) {
154   return false;
155 }
156
157 std::string getFileAsBuf(const char* fileName) {
158   std::string buffer;
159   folly::readFile(fileName, buffer);
160   return buffer;
161 }
162
163 std::string getCommonName(X509* cert) {
164   X509_NAME* subject = X509_get_subject_name(cert);
165   std::string cn;
166   cn.resize(ub_common_name);
167   X509_NAME_get_text_by_NID(
168       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
169   return cn;
170 }
171
172 /**
173  * Test connecting to, writing to, reading from, and closing the
174  * connection to the SSL server.
175  */
176 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
177   // Start listening on a local port
178   WriteCallbackBase writeCallback;
179   ReadCallback readCallback(&writeCallback);
180   HandshakeCallback handshakeCallback(&readCallback);
181   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
182   TestSSLServer server(&acceptCallback);
183
184   // Set up SSL context.
185   std::shared_ptr<SSLContext> sslContext(new SSLContext());
186   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
187   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
188   //sslContext->authenticate(true, false);
189
190   // connect
191   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
192                                                  sslContext);
193   socket->open();
194
195   // write()
196   uint8_t buf[128];
197   memset(buf, 'a', sizeof(buf));
198   socket->write(buf, sizeof(buf));
199
200   // read()
201   uint8_t readbuf[128];
202   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
203   EXPECT_EQ(bytesRead, 128);
204   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
205
206   // close()
207   socket->close();
208
209   cerr << "ConnectWriteReadClose test completed" << endl;
210 }
211
212 /**
213  * Test reading after server close.
214  */
215 TEST(AsyncSSLSocketTest, ReadAfterClose) {
216   // Start listening on a local port
217   WriteCallbackBase writeCallback;
218   ReadEOFCallback readCallback(&writeCallback);
219   HandshakeCallback handshakeCallback(&readCallback);
220   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
221   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
222
223   // Set up SSL context.
224   auto sslContext = std::make_shared<SSLContext>();
225   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
226
227   auto socket =
228       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
229   socket->open();
230
231   // This should trigger an EOF on the client.
232   auto evb = handshakeCallback.getSocket()->getEventBase();
233   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
234   std::array<uint8_t, 128> readbuf;
235   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
236   EXPECT_EQ(0, bytesRead);
237 }
238
239 /**
240  * Test bad renegotiation
241  */
242 TEST(AsyncSSLSocketTest, Renegotiate) {
243   EventBase eventBase;
244   auto clientCtx = std::make_shared<SSLContext>();
245   auto dfServerCtx = std::make_shared<SSLContext>();
246   std::array<int, 2> fds;
247   getfds(fds.data());
248   getctx(clientCtx, dfServerCtx);
249
250   AsyncSSLSocket::UniquePtr clientSock(
251       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
252   AsyncSSLSocket::UniquePtr serverSock(
253       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
254   SSLHandshakeClient client(std::move(clientSock), true, true);
255   RenegotiatingServer server(std::move(serverSock));
256
257   while (!client.handshakeSuccess_ && !client.handshakeError_) {
258     eventBase.loopOnce();
259   }
260
261   ASSERT_TRUE(client.handshakeSuccess_);
262
263   auto sslSock = std::move(client).moveSocket();
264   sslSock->detachEventBase();
265   // This is nasty, however we don't want to add support for
266   // renegotiation in AsyncSSLSocket.
267   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
268
269   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
270
271   std::thread t([&]() { eventBase.loopForever(); });
272
273   // Trigger the renegotiation.
274   std::array<uint8_t, 128> buf;
275   memset(buf.data(), 'a', buf.size());
276   try {
277     socket->write(buf.data(), buf.size());
278   } catch (AsyncSocketException& e) {
279     LOG(INFO) << "client got error " << e.what();
280   }
281   eventBase.terminateLoopSoon();
282   t.join();
283
284   eventBase.loop();
285   ASSERT_TRUE(server.renegotiationError_);
286 }
287
288 /**
289  * Negative test for handshakeError().
290  */
291 TEST(AsyncSSLSocketTest, HandshakeError) {
292   // Start listening on a local port
293   WriteCallbackBase writeCallback;
294   WriteErrorCallback readCallback(&writeCallback);
295   HandshakeCallback handshakeCallback(&readCallback);
296   HandshakeErrorCallback acceptCallback(&handshakeCallback);
297   TestSSLServer server(&acceptCallback);
298
299   // Set up SSL context.
300   std::shared_ptr<SSLContext> sslContext(new SSLContext());
301   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
302
303   // connect
304   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
305                                                  sslContext);
306   // read()
307   bool ex = false;
308   try {
309     socket->open();
310
311     uint8_t readbuf[128];
312     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
313     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
314   } catch (AsyncSocketException &e) {
315     ex = true;
316   }
317   EXPECT_TRUE(ex);
318
319   // close()
320   socket->close();
321   cerr << "HandshakeError test completed" << endl;
322 }
323
324 /**
325  * Negative test for readError().
326  */
327 TEST(AsyncSSLSocketTest, ReadError) {
328   // Start listening on a local port
329   WriteCallbackBase writeCallback;
330   ReadErrorCallback readCallback(&writeCallback);
331   HandshakeCallback handshakeCallback(&readCallback);
332   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
333   TestSSLServer server(&acceptCallback);
334
335   // Set up SSL context.
336   std::shared_ptr<SSLContext> sslContext(new SSLContext());
337   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
338
339   // connect
340   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
341                                                  sslContext);
342   socket->open();
343
344   // write something to trigger ssl handshake
345   uint8_t buf[128];
346   memset(buf, 'a', sizeof(buf));
347   socket->write(buf, sizeof(buf));
348
349   socket->close();
350   cerr << "ReadError test completed" << endl;
351 }
352
353 /**
354  * Negative test for writeError().
355  */
356 TEST(AsyncSSLSocketTest, WriteError) {
357   // Start listening on a local port
358   WriteCallbackBase writeCallback;
359   WriteErrorCallback readCallback(&writeCallback);
360   HandshakeCallback handshakeCallback(&readCallback);
361   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
362   TestSSLServer server(&acceptCallback);
363
364   // Set up SSL context.
365   std::shared_ptr<SSLContext> sslContext(new SSLContext());
366   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
367
368   // connect
369   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
370                                                  sslContext);
371   socket->open();
372
373   // write something to trigger ssl handshake
374   uint8_t buf[128];
375   memset(buf, 'a', sizeof(buf));
376   socket->write(buf, sizeof(buf));
377
378   socket->close();
379   cerr << "WriteError test completed" << endl;
380 }
381
382 /**
383  * Test a socket with TCP_NODELAY unset.
384  */
385 TEST(AsyncSSLSocketTest, SocketWithDelay) {
386   // Start listening on a local port
387   WriteCallbackBase writeCallback;
388   ReadCallback readCallback(&writeCallback);
389   HandshakeCallback handshakeCallback(&readCallback);
390   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
391   TestSSLServer server(&acceptCallback);
392
393   // Set up SSL context.
394   std::shared_ptr<SSLContext> sslContext(new SSLContext());
395   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
396
397   // connect
398   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
399                                                  sslContext);
400   socket->open();
401
402   // write()
403   uint8_t buf[128];
404   memset(buf, 'a', sizeof(buf));
405   socket->write(buf, sizeof(buf));
406
407   // read()
408   uint8_t readbuf[128];
409   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
410   EXPECT_EQ(bytesRead, 128);
411   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
412
413   // close()
414   socket->close();
415
416   cerr << "SocketWithDelay test completed" << endl;
417 }
418
419 using NextProtocolTypePair =
420     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
421
422 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
423   // For matching protos
424  public:
425   void SetUp() override { getctx(clientCtx, serverCtx); }
426
427   void connect(bool unset = false) {
428     getfds(fds);
429
430     if (unset) {
431       // unsetting NPN for any of [client, server] is enough to make NPN not
432       // work
433       clientCtx->unsetNextProtocols();
434     }
435
436     AsyncSSLSocket::UniquePtr clientSock(
437       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
438     AsyncSSLSocket::UniquePtr serverSock(
439       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
440     client = folly::make_unique<NpnClient>(std::move(clientSock));
441     server = folly::make_unique<NpnServer>(std::move(serverSock));
442
443     eventBase.loop();
444   }
445
446   void expectProtocol(const std::string& proto) {
447     EXPECT_NE(client->nextProtoLength, 0);
448     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
449     EXPECT_EQ(
450         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
451         0);
452     string selected((const char*)client->nextProto, client->nextProtoLength);
453     EXPECT_EQ(proto, selected);
454   }
455
456   void expectNoProtocol() {
457     EXPECT_EQ(client->nextProtoLength, 0);
458     EXPECT_EQ(server->nextProtoLength, 0);
459     EXPECT_EQ(client->nextProto, nullptr);
460     EXPECT_EQ(server->nextProto, nullptr);
461   }
462
463   void expectProtocolType() {
464     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
465         GetParam().second == SSLContext::NextProtocolType::ANY) {
466       EXPECT_EQ(client->protocolType, server->protocolType);
467     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
468                GetParam().second == SSLContext::NextProtocolType::ANY) {
469       // Well not much we can say
470     } else {
471       expectProtocolType(GetParam());
472     }
473   }
474
475   void expectProtocolType(NextProtocolTypePair expected) {
476     EXPECT_EQ(client->protocolType, expected.first);
477     EXPECT_EQ(server->protocolType, expected.second);
478   }
479
480   EventBase eventBase;
481   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
482   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
483   int fds[2];
484   std::unique_ptr<NpnClient> client;
485   std::unique_ptr<NpnServer> server;
486 };
487
488 class NextProtocolTLSExtTest : public NextProtocolTest {
489   // For extended TLS protos
490 };
491
492 class NextProtocolNPNOnlyTest : public NextProtocolTest {
493   // For mismatching protos
494 };
495
496 class NextProtocolMismatchTest : public NextProtocolTest {
497   // For mismatching protos
498 };
499
500 TEST_P(NextProtocolTest, NpnTestOverlap) {
501   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
502   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
503                                         GetParam().second);
504
505   connect();
506
507   expectProtocol("baz");
508   expectProtocolType();
509 }
510
511 TEST_P(NextProtocolTest, NpnTestUnset) {
512   // Identical to above test, except that we want unset NPN before
513   // looping.
514   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
515   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
516                                         GetParam().second);
517
518   connect(true /* unset */);
519
520   // if alpn negotiation fails, type will appear as npn
521   expectNoProtocol();
522   EXPECT_EQ(client->protocolType, server->protocolType);
523 }
524
525 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
526   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
527   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
528                                         GetParam().second);
529
530   connect();
531
532   expectNoProtocol();
533   expectProtocolType(
534       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
535 }
536
537 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
538 // will fail on 1.0.2 before that.
539 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
540   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
541   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
542                                         GetParam().second);
543
544   connect();
545
546   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
547       GetParam().second == SSLContext::NextProtocolType::ALPN) {
548     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
549     // mismatch should result in a fatal alert, but this is OpenSSL's current
550     // behavior and we want to know if it changes.
551     expectNoProtocol();
552   } else {
553     expectProtocol("blub");
554     expectProtocolType(
555         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
556   }
557 }
558
559 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
560   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
561   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
562   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
563                                         GetParam().second);
564
565   connect();
566
567   expectProtocol("ponies");
568   expectProtocolType();
569 }
570
571 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
572   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
573   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
574   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
575                                         GetParam().second);
576
577   connect();
578
579   expectProtocol("blub");
580   expectProtocolType();
581 }
582
583 TEST_P(NextProtocolTest, RandomizedNpnTest) {
584   // Probability that this test will fail is 2^-64, which could be considered
585   // as negligible.
586   const int kTries = 64;
587
588   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
589                                         GetParam().first);
590   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
591                                                   GetParam().second);
592
593   std::set<string> selectedProtocols;
594   for (int i = 0; i < kTries; ++i) {
595     connect();
596
597     EXPECT_NE(client->nextProtoLength, 0);
598     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
599     EXPECT_EQ(
600         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
601         0);
602     string selected((const char*)client->nextProto, client->nextProtoLength);
603     selectedProtocols.insert(selected);
604     expectProtocolType();
605   }
606   EXPECT_EQ(selectedProtocols.size(), 2);
607 }
608
609 INSTANTIATE_TEST_CASE_P(
610     AsyncSSLSocketTest,
611     NextProtocolTest,
612     ::testing::Values(
613         NextProtocolTypePair(
614             SSLContext::NextProtocolType::NPN,
615             SSLContext::NextProtocolType::NPN),
616         NextProtocolTypePair(
617             SSLContext::NextProtocolType::NPN,
618             SSLContext::NextProtocolType::ANY),
619         NextProtocolTypePair(
620             SSLContext::NextProtocolType::ANY,
621             SSLContext::NextProtocolType::ANY)));
622
623 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
624 INSTANTIATE_TEST_CASE_P(
625     AsyncSSLSocketTest,
626     NextProtocolTLSExtTest,
627     ::testing::Values(
628         NextProtocolTypePair(
629             SSLContext::NextProtocolType::ALPN,
630             SSLContext::NextProtocolType::ALPN),
631         NextProtocolTypePair(
632             SSLContext::NextProtocolType::ALPN,
633             SSLContext::NextProtocolType::ANY),
634         NextProtocolTypePair(
635             SSLContext::NextProtocolType::ANY,
636             SSLContext::NextProtocolType::ALPN)));
637 #endif
638
639 INSTANTIATE_TEST_CASE_P(
640     AsyncSSLSocketTest,
641     NextProtocolNPNOnlyTest,
642     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
643                                            SSLContext::NextProtocolType::NPN)));
644
645 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
646 INSTANTIATE_TEST_CASE_P(
647     AsyncSSLSocketTest,
648     NextProtocolMismatchTest,
649     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
650                                            SSLContext::NextProtocolType::ALPN),
651                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
652                                            SSLContext::NextProtocolType::NPN)));
653 #endif
654
655 #ifndef OPENSSL_NO_TLSEXT
656 /**
657  * 1. Client sends TLSEXT_HOSTNAME in client hello.
658  * 2. Server found a match SSL_CTX and use this SSL_CTX to
659  *    continue the SSL handshake.
660  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
661  */
662 TEST(AsyncSSLSocketTest, SNITestMatch) {
663   EventBase eventBase;
664   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
665   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
666   // Use the same SSLContext to continue the handshake after
667   // tlsext_hostname match.
668   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
669   const std::string serverName("xyz.newdev.facebook.com");
670   int fds[2];
671   getfds(fds);
672   getctx(clientCtx, dfServerCtx);
673
674   AsyncSSLSocket::UniquePtr clientSock(
675     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
676   AsyncSSLSocket::UniquePtr serverSock(
677     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
678   SNIClient client(std::move(clientSock));
679   SNIServer server(std::move(serverSock),
680                    dfServerCtx,
681                    hskServerCtx,
682                    serverName);
683
684   eventBase.loop();
685
686   EXPECT_TRUE(client.serverNameMatch);
687   EXPECT_TRUE(server.serverNameMatch);
688 }
689
690 /**
691  * 1. Client sends TLSEXT_HOSTNAME in client hello.
692  * 2. Server cannot find a matching SSL_CTX and continue to use
693  *    the current SSL_CTX to do the handshake.
694  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
695  */
696 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
697   EventBase eventBase;
698   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
699   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
700   // Use the same SSLContext to continue the handshake after
701   // tlsext_hostname match.
702   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
703   const std::string clientRequestingServerName("foo.com");
704   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
705
706   int fds[2];
707   getfds(fds);
708   getctx(clientCtx, dfServerCtx);
709
710   AsyncSSLSocket::UniquePtr clientSock(
711     new AsyncSSLSocket(clientCtx,
712                         &eventBase,
713                         fds[0],
714                         clientRequestingServerName));
715   AsyncSSLSocket::UniquePtr serverSock(
716     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
717   SNIClient client(std::move(clientSock));
718   SNIServer server(std::move(serverSock),
719                    dfServerCtx,
720                    hskServerCtx,
721                    serverExpectedServerName);
722
723   eventBase.loop();
724
725   EXPECT_TRUE(!client.serverNameMatch);
726   EXPECT_TRUE(!server.serverNameMatch);
727 }
728 /**
729  * 1. Client sends TLSEXT_HOSTNAME in client hello.
730  * 2. We then change the serverName.
731  * 3. We expect that we get 'false' as the result for serNameMatch.
732  */
733
734 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
735    EventBase eventBase;
736   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
737   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
738   // Use the same SSLContext to continue the handshake after
739   // tlsext_hostname match.
740   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
741   const std::string serverName("xyz.newdev.facebook.com");
742   int fds[2];
743   getfds(fds);
744   getctx(clientCtx, dfServerCtx);
745
746   AsyncSSLSocket::UniquePtr clientSock(
747     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
748   //Change the server name
749   std::string newName("new.com");
750   clientSock->setServerName(newName);
751   AsyncSSLSocket::UniquePtr serverSock(
752     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
753   SNIClient client(std::move(clientSock));
754   SNIServer server(std::move(serverSock),
755                    dfServerCtx,
756                    hskServerCtx,
757                    serverName);
758
759   eventBase.loop();
760
761   EXPECT_TRUE(!client.serverNameMatch);
762 }
763
764 /**
765  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
766  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
767  */
768 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
769   EventBase eventBase;
770   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
771   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
772   // Use the same SSLContext to continue the handshake after
773   // tlsext_hostname match.
774   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
775   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
776
777   int fds[2];
778   getfds(fds);
779   getctx(clientCtx, dfServerCtx);
780
781   AsyncSSLSocket::UniquePtr clientSock(
782     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
783   AsyncSSLSocket::UniquePtr serverSock(
784     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
785   SNIClient client(std::move(clientSock));
786   SNIServer server(std::move(serverSock),
787                    dfServerCtx,
788                    hskServerCtx,
789                    serverExpectedServerName);
790
791   eventBase.loop();
792
793   EXPECT_TRUE(!client.serverNameMatch);
794   EXPECT_TRUE(!server.serverNameMatch);
795 }
796
797 #endif
798 /**
799  * Test SSL client socket
800  */
801 TEST(AsyncSSLSocketTest, SSLClientTest) {
802   // Start listening on a local port
803   WriteCallbackBase writeCallback;
804   ReadCallback readCallback(&writeCallback);
805   HandshakeCallback handshakeCallback(&readCallback);
806   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
807   TestSSLServer server(&acceptCallback);
808
809   // Set up SSL client
810   EventBase eventBase;
811   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
812
813   client->connect();
814   EventBaseAborter eba(&eventBase, 3000);
815   eventBase.loop();
816
817   EXPECT_EQ(client->getMiss(), 1);
818   EXPECT_EQ(client->getHit(), 0);
819
820   cerr << "SSLClientTest test completed" << endl;
821 }
822
823
824 /**
825  * Test SSL client socket session re-use
826  */
827 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
828   // Start listening on a local port
829   WriteCallbackBase writeCallback;
830   ReadCallback readCallback(&writeCallback);
831   HandshakeCallback handshakeCallback(&readCallback);
832   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
833   TestSSLServer server(&acceptCallback);
834
835   // Set up SSL client
836   EventBase eventBase;
837   auto client =
838       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
839
840   client->connect();
841   EventBaseAborter eba(&eventBase, 3000);
842   eventBase.loop();
843
844   EXPECT_EQ(client->getMiss(), 1);
845   EXPECT_EQ(client->getHit(), 9);
846
847   cerr << "SSLClientTestReuse test completed" << endl;
848 }
849
850 /**
851  * Test SSL client socket timeout
852  */
853 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
854   // Start listening on a local port
855   EmptyReadCallback readCallback;
856   HandshakeCallback handshakeCallback(&readCallback,
857                                       HandshakeCallback::EXPECT_ERROR);
858   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
859   TestSSLServer server(&acceptCallback);
860
861   // Set up SSL client
862   EventBase eventBase;
863   auto client =
864       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
865   client->connect(true /* write before connect completes */);
866   EventBaseAborter eba(&eventBase, 3000);
867   eventBase.loop();
868
869   usleep(100000);
870   // This is checking that the connectError callback precedes any queued
871   // writeError callbacks.  This matches AsyncSocket's behavior
872   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
873   EXPECT_EQ(client->getErrors(), 1);
874   EXPECT_EQ(client->getMiss(), 0);
875   EXPECT_EQ(client->getHit(), 0);
876
877   cerr << "SSLClientTimeoutTest test completed" << endl;
878 }
879
880
881 /**
882  * Test SSL server async cache
883  */
884 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
885   // Start listening on a local port
886   WriteCallbackBase writeCallback;
887   ReadCallback readCallback(&writeCallback);
888   HandshakeCallback handshakeCallback(&readCallback);
889   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
890   TestSSLAsyncCacheServer server(&acceptCallback);
891
892   // Set up SSL client
893   EventBase eventBase;
894   auto client =
895       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
896
897   client->connect();
898   EventBaseAborter eba(&eventBase, 3000);
899   eventBase.loop();
900
901   EXPECT_EQ(server.getAsyncCallbacks(), 18);
902   EXPECT_EQ(server.getAsyncLookups(), 9);
903   EXPECT_EQ(client->getMiss(), 10);
904   EXPECT_EQ(client->getHit(), 0);
905
906   cerr << "SSLServerAsyncCacheTest test completed" << endl;
907 }
908
909
910 /**
911  * Test SSL server accept timeout with cache path
912  */
913 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
914   // Start listening on a local port
915   WriteCallbackBase writeCallback;
916   ReadCallback readCallback(&writeCallback);
917   EmptyReadCallback clientReadCallback;
918   HandshakeCallback handshakeCallback(&readCallback);
919   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
920   TestSSLAsyncCacheServer server(&acceptCallback);
921
922   // Set up SSL client
923   EventBase eventBase;
924   // only do a TCP connect
925   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
926   sock->connect(nullptr, server.getAddress());
927   clientReadCallback.tcpSocket_ = sock;
928   sock->setReadCB(&clientReadCallback);
929
930   EventBaseAborter eba(&eventBase, 3000);
931   eventBase.loop();
932
933   EXPECT_EQ(readCallback.state, STATE_WAITING);
934
935   cerr << "SSLServerTimeoutTest test completed" << endl;
936 }
937
938 /**
939  * Test SSL server accept timeout with cache path
940  */
941 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
942   // Start listening on a local port
943   WriteCallbackBase writeCallback;
944   ReadCallback readCallback(&writeCallback);
945   HandshakeCallback handshakeCallback(&readCallback);
946   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
947   TestSSLAsyncCacheServer server(&acceptCallback);
948
949   // Set up SSL client
950   EventBase eventBase;
951   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
952
953   client->connect();
954   EventBaseAborter eba(&eventBase, 3000);
955   eventBase.loop();
956
957   EXPECT_EQ(server.getAsyncCallbacks(), 1);
958   EXPECT_EQ(server.getAsyncLookups(), 1);
959   EXPECT_EQ(client->getErrors(), 1);
960   EXPECT_EQ(client->getMiss(), 1);
961   EXPECT_EQ(client->getHit(), 0);
962
963   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
964 }
965
966 /**
967  * Test SSL server accept timeout with cache path
968  */
969 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
970   // Start listening on a local port
971   WriteCallbackBase writeCallback;
972   ReadCallback readCallback(&writeCallback);
973   HandshakeCallback handshakeCallback(&readCallback,
974                                       HandshakeCallback::EXPECT_ERROR);
975   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
976   TestSSLAsyncCacheServer server(&acceptCallback, 500);
977
978   // Set up SSL client
979   EventBase eventBase;
980   auto client =
981       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
982
983   client->connect();
984   EventBaseAborter eba(&eventBase, 3000);
985   eventBase.loop();
986
987   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
988       handshakeCallback.closeSocket();});
989   // give time for the cache lookup to come back and find it closed
990   handshakeCallback.waitForHandshake();
991
992   EXPECT_EQ(server.getAsyncCallbacks(), 1);
993   EXPECT_EQ(server.getAsyncLookups(), 1);
994   EXPECT_EQ(client->getErrors(), 1);
995   EXPECT_EQ(client->getMiss(), 1);
996   EXPECT_EQ(client->getHit(), 0);
997
998   cerr << "SSLServerCacheCloseTest test completed" << endl;
999 }
1000
1001 /**
1002  * Verify Client Ciphers obtained using SSL MSG Callback.
1003  */
1004 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1005   EventBase eventBase;
1006   auto clientCtx = std::make_shared<SSLContext>();
1007   auto serverCtx = std::make_shared<SSLContext>();
1008   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1009   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
1010   serverCtx->loadPrivateKey(testKey);
1011   serverCtx->loadCertificate(testCert);
1012   serverCtx->loadTrustedCertificates(testCA);
1013   serverCtx->loadClientCAList(testCA);
1014
1015   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1016   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
1017   clientCtx->loadPrivateKey(testKey);
1018   clientCtx->loadCertificate(testCert);
1019   clientCtx->loadTrustedCertificates(testCA);
1020
1021   int fds[2];
1022   getfds(fds);
1023
1024   AsyncSSLSocket::UniquePtr clientSock(
1025       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1026   AsyncSSLSocket::UniquePtr serverSock(
1027       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1028
1029   SSLHandshakeClient client(std::move(clientSock), true, true);
1030   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1031
1032   eventBase.loop();
1033
1034   EXPECT_EQ(server.clientCiphers_,
1035             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1036   EXPECT_TRUE(client.handshakeVerify_);
1037   EXPECT_TRUE(client.handshakeSuccess_);
1038   EXPECT_TRUE(!client.handshakeError_);
1039   EXPECT_TRUE(server.handshakeVerify_);
1040   EXPECT_TRUE(server.handshakeSuccess_);
1041   EXPECT_TRUE(!server.handshakeError_);
1042 }
1043
1044 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1045   EventBase eventBase;
1046   auto ctx = std::make_shared<SSLContext>();
1047
1048   int fds[2];
1049   getfds(fds);
1050
1051   int bufLen = 42;
1052   uint8_t majorVersion = 18;
1053   uint8_t minorVersion = 25;
1054
1055   // Create callback buf
1056   auto buf = IOBuf::create(bufLen);
1057   buf->append(bufLen);
1058   folly::io::RWPrivateCursor cursor(buf.get());
1059   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1060   cursor.write<uint16_t>(0);
1061   cursor.write<uint8_t>(38);
1062   cursor.write<uint8_t>(majorVersion);
1063   cursor.write<uint8_t>(minorVersion);
1064   cursor.skip(32);
1065   cursor.write<uint32_t>(0);
1066
1067   SSL* ssl = ctx->createSSL();
1068   SCOPE_EXIT { SSL_free(ssl); };
1069   AsyncSSLSocket::UniquePtr sock(
1070       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1071   sock->enableClientHelloParsing();
1072
1073   // Test client hello parsing in one packet
1074   AsyncSSLSocket::clientHelloParsingCallback(
1075       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1076   buf.reset();
1077
1078   auto parsedClientHello = sock->getClientHelloInfo();
1079   EXPECT_TRUE(parsedClientHello != nullptr);
1080   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1081   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1082 }
1083
1084 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1085   EventBase eventBase;
1086   auto ctx = std::make_shared<SSLContext>();
1087
1088   int fds[2];
1089   getfds(fds);
1090
1091   int bufLen = 42;
1092   uint8_t majorVersion = 18;
1093   uint8_t minorVersion = 25;
1094
1095   // Create callback buf
1096   auto buf = IOBuf::create(bufLen);
1097   buf->append(bufLen);
1098   folly::io::RWPrivateCursor cursor(buf.get());
1099   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1100   cursor.write<uint16_t>(0);
1101   cursor.write<uint8_t>(38);
1102   cursor.write<uint8_t>(majorVersion);
1103   cursor.write<uint8_t>(minorVersion);
1104   cursor.skip(32);
1105   cursor.write<uint32_t>(0);
1106
1107   SSL* ssl = ctx->createSSL();
1108   SCOPE_EXIT { SSL_free(ssl); };
1109   AsyncSSLSocket::UniquePtr sock(
1110       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1111   sock->enableClientHelloParsing();
1112
1113   // Test parsing with two packets with first packet size < 3
1114   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1115   AsyncSSLSocket::clientHelloParsingCallback(
1116       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1117       ssl, sock.get());
1118   bufCopy.reset();
1119   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1120   AsyncSSLSocket::clientHelloParsingCallback(
1121       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1122       ssl, sock.get());
1123   bufCopy.reset();
1124
1125   auto parsedClientHello = sock->getClientHelloInfo();
1126   EXPECT_TRUE(parsedClientHello != nullptr);
1127   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1128   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1129 }
1130
1131 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1132   EventBase eventBase;
1133   auto ctx = std::make_shared<SSLContext>();
1134
1135   int fds[2];
1136   getfds(fds);
1137
1138   int bufLen = 42;
1139   uint8_t majorVersion = 18;
1140   uint8_t minorVersion = 25;
1141
1142   // Create callback buf
1143   auto buf = IOBuf::create(bufLen);
1144   buf->append(bufLen);
1145   folly::io::RWPrivateCursor cursor(buf.get());
1146   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1147   cursor.write<uint16_t>(0);
1148   cursor.write<uint8_t>(38);
1149   cursor.write<uint8_t>(majorVersion);
1150   cursor.write<uint8_t>(minorVersion);
1151   cursor.skip(32);
1152   cursor.write<uint32_t>(0);
1153
1154   SSL* ssl = ctx->createSSL();
1155   SCOPE_EXIT { SSL_free(ssl); };
1156   AsyncSSLSocket::UniquePtr sock(
1157       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1158   sock->enableClientHelloParsing();
1159
1160   // Test parsing with multiple small packets
1161   for (uint64_t i = 0; i < buf->length(); i += 3) {
1162     auto bufCopy = folly::IOBuf::copyBuffer(
1163         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1164     AsyncSSLSocket::clientHelloParsingCallback(
1165         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1166         ssl, sock.get());
1167     bufCopy.reset();
1168   }
1169
1170   auto parsedClientHello = sock->getClientHelloInfo();
1171   EXPECT_TRUE(parsedClientHello != nullptr);
1172   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1173   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1174 }
1175
1176 /**
1177  * Verify sucessful behavior of SSL certificate validation.
1178  */
1179 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1180   EventBase eventBase;
1181   auto clientCtx = std::make_shared<SSLContext>();
1182   auto dfServerCtx = std::make_shared<SSLContext>();
1183
1184   int fds[2];
1185   getfds(fds);
1186   getctx(clientCtx, dfServerCtx);
1187
1188   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1189   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1190
1191   AsyncSSLSocket::UniquePtr clientSock(
1192     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1193   AsyncSSLSocket::UniquePtr serverSock(
1194     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1195
1196   SSLHandshakeClient client(std::move(clientSock), true, true);
1197   clientCtx->loadTrustedCertificates(testCA);
1198
1199   SSLHandshakeServer server(std::move(serverSock), true, true);
1200
1201   eventBase.loop();
1202
1203   EXPECT_TRUE(client.handshakeVerify_);
1204   EXPECT_TRUE(client.handshakeSuccess_);
1205   EXPECT_TRUE(!client.handshakeError_);
1206   EXPECT_LE(0, client.handshakeTime.count());
1207   EXPECT_TRUE(!server.handshakeVerify_);
1208   EXPECT_TRUE(server.handshakeSuccess_);
1209   EXPECT_TRUE(!server.handshakeError_);
1210   EXPECT_LE(0, server.handshakeTime.count());
1211 }
1212
1213 /**
1214  * Verify that the client's verification callback is able to fail SSL
1215  * connection establishment.
1216  */
1217 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1218   EventBase eventBase;
1219   auto clientCtx = std::make_shared<SSLContext>();
1220   auto dfServerCtx = std::make_shared<SSLContext>();
1221
1222   int fds[2];
1223   getfds(fds);
1224   getctx(clientCtx, dfServerCtx);
1225
1226   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1227   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1228
1229   AsyncSSLSocket::UniquePtr clientSock(
1230     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1231   AsyncSSLSocket::UniquePtr serverSock(
1232     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1233
1234   SSLHandshakeClient client(std::move(clientSock), true, false);
1235   clientCtx->loadTrustedCertificates(testCA);
1236
1237   SSLHandshakeServer server(std::move(serverSock), true, true);
1238
1239   eventBase.loop();
1240
1241   EXPECT_TRUE(client.handshakeVerify_);
1242   EXPECT_TRUE(!client.handshakeSuccess_);
1243   EXPECT_TRUE(client.handshakeError_);
1244   EXPECT_LE(0, client.handshakeTime.count());
1245   EXPECT_TRUE(!server.handshakeVerify_);
1246   EXPECT_TRUE(!server.handshakeSuccess_);
1247   EXPECT_TRUE(server.handshakeError_);
1248   EXPECT_LE(0, server.handshakeTime.count());
1249 }
1250
1251 /**
1252  * Verify that the options in SSLContext can be overridden in
1253  * sslConnect/Accept.i.e specifying that no validation should be performed
1254  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1255  * the validation callback.
1256  */
1257 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1258   EventBase eventBase;
1259   auto clientCtx = std::make_shared<SSLContext>();
1260   auto dfServerCtx = std::make_shared<SSLContext>();
1261
1262   int fds[2];
1263   getfds(fds);
1264   getctx(clientCtx, dfServerCtx);
1265
1266   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1267   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1268
1269   AsyncSSLSocket::UniquePtr clientSock(
1270     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1271   AsyncSSLSocket::UniquePtr serverSock(
1272     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1273
1274   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1275   clientCtx->loadTrustedCertificates(testCA);
1276
1277   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1278
1279   eventBase.loop();
1280
1281   EXPECT_TRUE(!client.handshakeVerify_);
1282   EXPECT_TRUE(client.handshakeSuccess_);
1283   EXPECT_TRUE(!client.handshakeError_);
1284   EXPECT_LE(0, client.handshakeTime.count());
1285   EXPECT_TRUE(!server.handshakeVerify_);
1286   EXPECT_TRUE(server.handshakeSuccess_);
1287   EXPECT_TRUE(!server.handshakeError_);
1288   EXPECT_LE(0, server.handshakeTime.count());
1289 }
1290
1291 /**
1292  * Verify that the options in SSLContext can be overridden in
1293  * sslConnect/Accept. Enable verification even if context says otherwise.
1294  * Test requireClientCert with client cert
1295  */
1296 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1297   EventBase eventBase;
1298   auto clientCtx = std::make_shared<SSLContext>();
1299   auto serverCtx = std::make_shared<SSLContext>();
1300   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1301   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1302   serverCtx->loadPrivateKey(testKey);
1303   serverCtx->loadCertificate(testCert);
1304   serverCtx->loadTrustedCertificates(testCA);
1305   serverCtx->loadClientCAList(testCA);
1306
1307   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1308   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1309   clientCtx->loadPrivateKey(testKey);
1310   clientCtx->loadCertificate(testCert);
1311   clientCtx->loadTrustedCertificates(testCA);
1312
1313   int fds[2];
1314   getfds(fds);
1315
1316   AsyncSSLSocket::UniquePtr clientSock(
1317       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1318   AsyncSSLSocket::UniquePtr serverSock(
1319       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1320
1321   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1322   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1323
1324   eventBase.loop();
1325
1326   EXPECT_TRUE(client.handshakeVerify_);
1327   EXPECT_TRUE(client.handshakeSuccess_);
1328   EXPECT_FALSE(client.handshakeError_);
1329   EXPECT_LE(0, client.handshakeTime.count());
1330   EXPECT_TRUE(server.handshakeVerify_);
1331   EXPECT_TRUE(server.handshakeSuccess_);
1332   EXPECT_FALSE(server.handshakeError_);
1333   EXPECT_LE(0, server.handshakeTime.count());
1334 }
1335
1336 /**
1337  * Verify that the client's verification callback is able to override
1338  * the preverification failure and allow a successful connection.
1339  */
1340 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1341   EventBase eventBase;
1342   auto clientCtx = std::make_shared<SSLContext>();
1343   auto dfServerCtx = std::make_shared<SSLContext>();
1344
1345   int fds[2];
1346   getfds(fds);
1347   getctx(clientCtx, dfServerCtx);
1348
1349   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1350   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1351
1352   AsyncSSLSocket::UniquePtr clientSock(
1353     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1354   AsyncSSLSocket::UniquePtr serverSock(
1355     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1356
1357   SSLHandshakeClient client(std::move(clientSock), false, true);
1358   SSLHandshakeServer server(std::move(serverSock), true, true);
1359
1360   eventBase.loop();
1361
1362   EXPECT_TRUE(client.handshakeVerify_);
1363   EXPECT_TRUE(client.handshakeSuccess_);
1364   EXPECT_TRUE(!client.handshakeError_);
1365   EXPECT_LE(0, client.handshakeTime.count());
1366   EXPECT_TRUE(!server.handshakeVerify_);
1367   EXPECT_TRUE(server.handshakeSuccess_);
1368   EXPECT_TRUE(!server.handshakeError_);
1369   EXPECT_LE(0, server.handshakeTime.count());
1370 }
1371
1372 /**
1373  * Verify that specifying that no validation should be performed allows an
1374  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1375  * callback.
1376  */
1377 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1378   EventBase eventBase;
1379   auto clientCtx = std::make_shared<SSLContext>();
1380   auto dfServerCtx = std::make_shared<SSLContext>();
1381
1382   int fds[2];
1383   getfds(fds);
1384   getctx(clientCtx, dfServerCtx);
1385
1386   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1387   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1388
1389   AsyncSSLSocket::UniquePtr clientSock(
1390     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1391   AsyncSSLSocket::UniquePtr serverSock(
1392     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1393
1394   SSLHandshakeClient client(std::move(clientSock), false, false);
1395   SSLHandshakeServer server(std::move(serverSock), false, false);
1396
1397   eventBase.loop();
1398
1399   EXPECT_TRUE(!client.handshakeVerify_);
1400   EXPECT_TRUE(client.handshakeSuccess_);
1401   EXPECT_TRUE(!client.handshakeError_);
1402   EXPECT_LE(0, client.handshakeTime.count());
1403   EXPECT_TRUE(!server.handshakeVerify_);
1404   EXPECT_TRUE(server.handshakeSuccess_);
1405   EXPECT_TRUE(!server.handshakeError_);
1406   EXPECT_LE(0, server.handshakeTime.count());
1407 }
1408
1409 /**
1410  * Test requireClientCert with client cert
1411  */
1412 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1413   EventBase eventBase;
1414   auto clientCtx = std::make_shared<SSLContext>();
1415   auto serverCtx = std::make_shared<SSLContext>();
1416   serverCtx->setVerificationOption(
1417       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1418   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1419   serverCtx->loadPrivateKey(testKey);
1420   serverCtx->loadCertificate(testCert);
1421   serverCtx->loadTrustedCertificates(testCA);
1422   serverCtx->loadClientCAList(testCA);
1423
1424   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1425   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1426   clientCtx->loadPrivateKey(testKey);
1427   clientCtx->loadCertificate(testCert);
1428   clientCtx->loadTrustedCertificates(testCA);
1429
1430   int fds[2];
1431   getfds(fds);
1432
1433   AsyncSSLSocket::UniquePtr clientSock(
1434       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1435   AsyncSSLSocket::UniquePtr serverSock(
1436       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1437
1438   SSLHandshakeClient client(std::move(clientSock), true, true);
1439   SSLHandshakeServer server(std::move(serverSock), true, true);
1440
1441   eventBase.loop();
1442
1443   EXPECT_TRUE(client.handshakeVerify_);
1444   EXPECT_TRUE(client.handshakeSuccess_);
1445   EXPECT_FALSE(client.handshakeError_);
1446   EXPECT_LE(0, client.handshakeTime.count());
1447   EXPECT_TRUE(server.handshakeVerify_);
1448   EXPECT_TRUE(server.handshakeSuccess_);
1449   EXPECT_FALSE(server.handshakeError_);
1450   EXPECT_LE(0, server.handshakeTime.count());
1451 }
1452
1453
1454 /**
1455  * Test requireClientCert with no client cert
1456  */
1457 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1458   EventBase eventBase;
1459   auto clientCtx = std::make_shared<SSLContext>();
1460   auto serverCtx = std::make_shared<SSLContext>();
1461   serverCtx->setVerificationOption(
1462       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1463   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1464   serverCtx->loadPrivateKey(testKey);
1465   serverCtx->loadCertificate(testCert);
1466   serverCtx->loadTrustedCertificates(testCA);
1467   serverCtx->loadClientCAList(testCA);
1468   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1469   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1470
1471   int fds[2];
1472   getfds(fds);
1473
1474   AsyncSSLSocket::UniquePtr clientSock(
1475       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1476   AsyncSSLSocket::UniquePtr serverSock(
1477       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1478
1479   SSLHandshakeClient client(std::move(clientSock), false, false);
1480   SSLHandshakeServer server(std::move(serverSock), false, false);
1481
1482   eventBase.loop();
1483
1484   EXPECT_FALSE(server.handshakeVerify_);
1485   EXPECT_FALSE(server.handshakeSuccess_);
1486   EXPECT_TRUE(server.handshakeError_);
1487   EXPECT_LE(0, client.handshakeTime.count());
1488   EXPECT_LE(0, server.handshakeTime.count());
1489 }
1490
1491 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1492   auto cert = getFileAsBuf(testCert);
1493   auto key = getFileAsBuf(testKey);
1494
1495   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1496   BIO_write(certBio.get(), cert.data(), cert.size());
1497   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1498   BIO_write(keyBio.get(), key.data(), key.size());
1499
1500   // Create SSL structs from buffers to get properties
1501   ssl::X509UniquePtr certStruct(
1502       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1503   ssl::EvpPkeyUniquePtr keyStruct(
1504       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1505   certBio = nullptr;
1506   keyBio = nullptr;
1507
1508   auto origCommonName = getCommonName(certStruct.get());
1509   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1510   certStruct = nullptr;
1511   keyStruct = nullptr;
1512
1513   auto ctx = std::make_shared<SSLContext>();
1514   ctx->loadPrivateKeyFromBufferPEM(key);
1515   ctx->loadCertificateFromBufferPEM(cert);
1516   ctx->loadTrustedCertificates(testCA);
1517
1518   ssl::SSLUniquePtr ssl(ctx->createSSL());
1519
1520   auto newCert = SSL_get_certificate(ssl.get());
1521   auto newKey = SSL_get_privatekey(ssl.get());
1522
1523   // Get properties from SSL struct
1524   auto newCommonName = getCommonName(newCert);
1525   auto newKeySize = EVP_PKEY_bits(newKey);
1526
1527   // Check that the key and cert have the expected properties
1528   EXPECT_EQ(origCommonName, newCommonName);
1529   EXPECT_EQ(origKeySize, newKeySize);
1530 }
1531
1532 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1533   EventBase eb;
1534
1535   // Set up SSL context.
1536   auto sslContext = std::make_shared<SSLContext>();
1537   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1538
1539   // create SSL socket
1540   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1541
1542   EXPECT_EQ(1500, socket->getMinWriteSize());
1543
1544   socket->setMinWriteSize(0);
1545   EXPECT_EQ(0, socket->getMinWriteSize());
1546   socket->setMinWriteSize(50000);
1547   EXPECT_EQ(50000, socket->getMinWriteSize());
1548 }
1549
1550 class ReadCallbackTerminator : public ReadCallback {
1551  public:
1552   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1553       : ReadCallback(wcb)
1554       , base_(base) {}
1555
1556   // Do not write data back, terminate the loop.
1557   void readDataAvailable(size_t len) noexcept override {
1558     std::cerr << "readDataAvailable, len " << len << std::endl;
1559
1560     currentBuffer.length = len;
1561
1562     buffers.push_back(currentBuffer);
1563     currentBuffer.reset();
1564     state = STATE_SUCCEEDED;
1565
1566     socket_->setReadCB(nullptr);
1567     base_->terminateLoopSoon();
1568   }
1569  private:
1570   EventBase* base_;
1571 };
1572
1573
1574 /**
1575  * Test a full unencrypted codepath
1576  */
1577 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1578   EventBase base;
1579
1580   auto clientCtx = std::make_shared<folly::SSLContext>();
1581   auto serverCtx = std::make_shared<folly::SSLContext>();
1582   int fds[2];
1583   getfds(fds);
1584   getctx(clientCtx, serverCtx);
1585   auto client = AsyncSSLSocket::newSocket(
1586                   clientCtx, &base, fds[0], false, true);
1587   auto server = AsyncSSLSocket::newSocket(
1588                   serverCtx, &base, fds[1], true, true);
1589
1590   ReadCallbackTerminator readCallback(&base, nullptr);
1591   server->setReadCB(&readCallback);
1592   readCallback.setSocket(server);
1593
1594   uint8_t buf[128];
1595   memset(buf, 'a', sizeof(buf));
1596   client->write(nullptr, buf, sizeof(buf));
1597
1598   // Check that bytes are unencrypted
1599   char c;
1600   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1601   EXPECT_EQ('a', c);
1602
1603   EventBaseAborter eba(&base, 3000);
1604   base.loop();
1605
1606   EXPECT_EQ(1, readCallback.buffers.size());
1607   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1608
1609   server->setReadCB(&readCallback);
1610
1611   // Unencrypted
1612   server->sslAccept(nullptr);
1613   client->sslConn(nullptr);
1614
1615   // Do NOT wait for handshake, writing should be queued and happen after
1616
1617   client->write(nullptr, buf, sizeof(buf));
1618
1619   // Check that bytes are *not* unencrypted
1620   char c2;
1621   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1622   EXPECT_NE('a', c2);
1623
1624
1625   base.loop();
1626
1627   EXPECT_EQ(2, readCallback.buffers.size());
1628   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1629 }
1630
1631 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1632   // Start listening on a local port
1633   WriteCallbackBase writeCallback;
1634   WriteErrorCallback readCallback(&writeCallback);
1635   HandshakeCallback handshakeCallback(&readCallback,
1636                                       HandshakeCallback::EXPECT_ERROR);
1637   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1638   TestSSLServer server(&acceptCallback);
1639
1640   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1641   socket->open();
1642   uint8_t buf[3] = {0x16, 0x03, 0x01};
1643   socket->write(buf, sizeof(buf));
1644   socket->closeWithReset();
1645
1646   handshakeCallback.waitForHandshake();
1647   EXPECT_NE(
1648       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1649   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1650 }
1651
1652 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1653   // Start listening on a local port
1654   WriteCallbackBase writeCallback;
1655   WriteErrorCallback readCallback(&writeCallback);
1656   HandshakeCallback handshakeCallback(&readCallback,
1657                                       HandshakeCallback::EXPECT_ERROR);
1658   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1659   TestSSLServer server(&acceptCallback);
1660
1661   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1662   socket->open();
1663   uint8_t buf[3] = {0x16, 0x03, 0x01};
1664   socket->write(buf, sizeof(buf));
1665   socket->close();
1666
1667   handshakeCallback.waitForHandshake();
1668   EXPECT_NE(
1669       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1670   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1671 }
1672
1673 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1674   // Start listening on a local port
1675   WriteCallbackBase writeCallback;
1676   WriteErrorCallback readCallback(&writeCallback);
1677   HandshakeCallback handshakeCallback(&readCallback,
1678                                       HandshakeCallback::EXPECT_ERROR);
1679   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1680   TestSSLServer server(&acceptCallback);
1681
1682   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1683   socket->open();
1684   uint8_t buf[256] = {0x16, 0x03};
1685   memset(buf + 2, 'a', sizeof(buf) - 2);
1686   socket->write(buf, sizeof(buf));
1687   socket->close();
1688
1689   handshakeCallback.waitForHandshake();
1690   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1691             std::string::npos);
1692   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1693             std::string::npos);
1694 }
1695
1696 #if FOLLY_ALLOW_TFO
1697
1698 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1699  public:
1700   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1701
1702   explicit MockAsyncTFOSSLSocket(
1703       std::shared_ptr<folly::SSLContext> sslCtx,
1704       EventBase* evb)
1705       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1706
1707   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1708 };
1709
1710 /**
1711  * Test connecting to, writing to, reading from, and closing the
1712  * connection to the SSL server with TFO.
1713  */
1714 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1715   // Start listening on a local port
1716   WriteCallbackBase writeCallback;
1717   ReadCallback readCallback(&writeCallback);
1718   HandshakeCallback handshakeCallback(&readCallback);
1719   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1720   TestSSLServer server(&acceptCallback, true);
1721
1722   // Set up SSL context.
1723   auto sslContext = std::make_shared<SSLContext>();
1724
1725   // connect
1726   auto socket =
1727       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1728   socket->enableTFO();
1729   socket->open();
1730
1731   // write()
1732   std::array<uint8_t, 128> buf;
1733   memset(buf.data(), 'a', buf.size());
1734   socket->write(buf.data(), buf.size());
1735
1736   // read()
1737   std::array<uint8_t, 128> readbuf;
1738   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1739   EXPECT_EQ(bytesRead, 128);
1740   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1741
1742   // close()
1743   socket->close();
1744 }
1745
1746 /**
1747  * Test connecting to, writing to, reading from, and closing the
1748  * connection to the SSL server with TFO.
1749  */
1750 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1751   // Start listening on a local port
1752   WriteCallbackBase writeCallback;
1753   ReadCallback readCallback(&writeCallback);
1754   HandshakeCallback handshakeCallback(&readCallback);
1755   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1756   TestSSLServer server(&acceptCallback, false);
1757
1758   // Set up SSL context.
1759   auto sslContext = std::make_shared<SSLContext>();
1760
1761   // connect
1762   auto socket =
1763       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1764   socket->enableTFO();
1765   socket->open();
1766
1767   // write()
1768   std::array<uint8_t, 128> buf;
1769   memset(buf.data(), 'a', buf.size());
1770   socket->write(buf.data(), buf.size());
1771
1772   // read()
1773   std::array<uint8_t, 128> readbuf;
1774   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1775   EXPECT_EQ(bytesRead, 128);
1776   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1777
1778   // close()
1779   socket->close();
1780 }
1781
1782 class ConnCallback : public AsyncSocket::ConnectCallback {
1783  public:
1784   virtual void connectSuccess() noexcept override {
1785     state = State::SUCCESS;
1786   }
1787
1788   virtual void connectErr(const AsyncSocketException&) noexcept override {
1789     state = State::ERROR;
1790   }
1791
1792   enum class State { WAITING, SUCCESS, ERROR };
1793
1794   State state{State::WAITING};
1795 };
1796
1797 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1798     EventBase* evb,
1799     const SocketAddress& address) {
1800   // Set up SSL context.
1801   auto sslContext = std::make_shared<SSLContext>();
1802
1803   // connect
1804   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1805       new MockAsyncTFOSSLSocket(sslContext, evb));
1806   socket->enableTFO();
1807
1808   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1809       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1810         sockaddr_storage addr;
1811         auto len = address.getAddress(&addr);
1812         return connect(fd, (const struct sockaddr*)&addr, len);
1813       }));
1814   return socket;
1815 }
1816
1817 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1818   // Start listening on a local port
1819   WriteCallbackBase writeCallback;
1820   ReadCallback readCallback(&writeCallback);
1821   HandshakeCallback handshakeCallback(&readCallback);
1822   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1823   TestSSLServer server(&acceptCallback, true);
1824
1825   EventBase evb;
1826
1827   auto socket = setupSocketWithFallback(&evb, server.getAddress());
1828   ConnCallback ccb;
1829   socket->connect(&ccb, server.getAddress(), 30);
1830
1831   evb.loop();
1832   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1833
1834   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1835   evb.loop();
1836
1837   BlockingSocket sock(std::move(socket));
1838   // write()
1839   std::array<uint8_t, 128> buf;
1840   memset(buf.data(), 'a', buf.size());
1841   sock.write(buf.data(), buf.size());
1842
1843   // read()
1844   std::array<uint8_t, 128> readbuf;
1845   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1846   EXPECT_EQ(bytesRead, 128);
1847   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1848
1849   // close()
1850   sock.close();
1851 }
1852
1853 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1854   // Start listening on a local port
1855   WriteCallbackBase writeCallback;
1856   ReadErrorCallback readCallback(&writeCallback);
1857   HandshakeCallback handshakeCallback(&readCallback);
1858   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1859   TestSSLServer server(&acceptCallback, true);
1860
1861   // Set up SSL context.
1862   auto sslContext = std::make_shared<SSLContext>();
1863
1864   // connect
1865   auto socket =
1866       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1867   socket->enableTFO();
1868   EXPECT_THROW(
1869       socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
1870 }
1871
1872 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1873   // Start listening on a local port
1874   WriteCallbackBase writeCallback;
1875   ReadErrorCallback readCallback(&writeCallback);
1876   HandshakeCallback handshakeCallback(&readCallback);
1877   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1878   TestSSLServer server(&acceptCallback, true);
1879
1880   EventBase evb;
1881
1882   auto socket = setupSocketWithFallback(&evb, server.getAddress());
1883   ConnCallback ccb;
1884   // Set a short timeout
1885   socket->connect(&ccb, server.getAddress(), 1);
1886
1887   evb.loop();
1888   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1889 }
1890
1891 #endif
1892
1893 } // namespace
1894
1895 ///////////////////////////////////////////////////////////////////////////
1896 // init_unit_test_suite
1897 ///////////////////////////////////////////////////////////////////////////
1898 namespace {
1899 struct Initializer {
1900   Initializer() {
1901     signal(SIGPIPE, SIG_IGN);
1902   }
1903 };
1904 Initializer initializer;
1905 } // anonymous