Move RequestContext definitions to source files
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/Unistd.h>
26
27 #include <folly/io/async/test/BlockingSocket.h>
28
29 #include <fcntl.h>
30 #include <folly/io/Cursor.h>
31 #include <gtest/gtest.h>
32 #include <openssl/bio.h>
33 #include <sys/types.h>
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38 #include <thread>
39
40 #include <gmock/gmock.h>
41
42 using std::string;
43 using std::vector;
44 using std::min;
45 using std::cerr;
46 using std::endl;
47 using std::list;
48
49 using namespace testing;
50
51 namespace folly {
52 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
53 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
54 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
55
56 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
57 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
58 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
59
60 constexpr size_t SSLClient::kMaxReadBufferSz;
61 constexpr size_t SSLClient::kMaxReadsPerEvent;
62
63 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
64     : ctx_(new folly::SSLContext),
65       acb_(acb),
66       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
67   // Set up the SSL context
68   ctx_->loadCertificate(testCert);
69   ctx_->loadPrivateKey(testKey);
70   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
71
72   acb_->ctx_ = ctx_;
73   acb_->base_ = &evb_;
74
75   // Enable TFO
76   if (enableTFO) {
77     LOG(INFO) << "server TFO enabled";
78     socket_->setTFOEnabled(true, 1000);
79   }
80
81   // set up the listening socket
82   socket_->bind(0);
83   socket_->getAddress(&address_);
84   socket_->listen(100);
85   socket_->addAcceptCallback(acb_, &evb_);
86   socket_->startAccepting();
87
88   int ret = pthread_create(&thread_, nullptr, Main, this);
89   assert(ret == 0);
90   (void)ret;
91
92   std::cerr << "Accepting connections on " << address_ << std::endl;
93 }
94
95 void getfds(int fds[2]) {
96   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
97     FAIL() << "failed to create socketpair: " << strerror(errno);
98   }
99   for (int idx = 0; idx < 2; ++idx) {
100     int flags = fcntl(fds[idx], F_GETFL, 0);
101     if (flags == -1) {
102       FAIL() << "failed to get flags for socket " << idx << ": "
103              << strerror(errno);
104     }
105     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
106       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
107              << strerror(errno);
108     }
109   }
110 }
111
112 void getctx(
113   std::shared_ptr<folly::SSLContext> clientCtx,
114   std::shared_ptr<folly::SSLContext> serverCtx) {
115   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
116
117   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
118   serverCtx->loadCertificate(
119       testCert);
120   serverCtx->loadPrivateKey(
121       testKey);
122 }
123
124 void sslsocketpair(
125   EventBase* eventBase,
126   AsyncSSLSocket::UniquePtr* clientSock,
127   AsyncSSLSocket::UniquePtr* serverSock) {
128   auto clientCtx = std::make_shared<folly::SSLContext>();
129   auto serverCtx = std::make_shared<folly::SSLContext>();
130   int fds[2];
131   getfds(fds);
132   getctx(clientCtx, serverCtx);
133   clientSock->reset(new AsyncSSLSocket(
134                       clientCtx, eventBase, fds[0], false));
135   serverSock->reset(new AsyncSSLSocket(
136                       serverCtx, eventBase, fds[1], true));
137
138   // (*clientSock)->setSendTimeout(100);
139   // (*serverSock)->setSendTimeout(100);
140 }
141
142 // client protocol filters
143 bool clientProtoFilterPickPony(unsigned char** client,
144   unsigned int* client_len, const unsigned char*, unsigned int ) {
145   //the protocol string in length prefixed byte string. the
146   //length byte is not included in the length
147   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
148   *client = p;
149   *client_len = 7;
150   return true;
151 }
152
153 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
154   const unsigned char*, unsigned int) {
155   return false;
156 }
157
158 std::string getFileAsBuf(const char* fileName) {
159   std::string buffer;
160   folly::readFile(fileName, buffer);
161   return buffer;
162 }
163
164 std::string getCommonName(X509* cert) {
165   X509_NAME* subject = X509_get_subject_name(cert);
166   std::string cn;
167   cn.resize(ub_common_name);
168   X509_NAME_get_text_by_NID(
169       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
170   return cn;
171 }
172
173 /**
174  * Test connecting to, writing to, reading from, and closing the
175  * connection to the SSL server.
176  */
177 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
178   // Start listening on a local port
179   WriteCallbackBase writeCallback;
180   ReadCallback readCallback(&writeCallback);
181   HandshakeCallback handshakeCallback(&readCallback);
182   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
183   TestSSLServer server(&acceptCallback);
184
185   // Set up SSL context.
186   std::shared_ptr<SSLContext> sslContext(new SSLContext());
187   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
188   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
189   //sslContext->authenticate(true, false);
190
191   // connect
192   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
193                                                  sslContext);
194   socket->open();
195
196   // write()
197   uint8_t buf[128];
198   memset(buf, 'a', sizeof(buf));
199   socket->write(buf, sizeof(buf));
200
201   // read()
202   uint8_t readbuf[128];
203   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
204   EXPECT_EQ(bytesRead, 128);
205   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
206
207   // close()
208   socket->close();
209
210   cerr << "ConnectWriteReadClose test completed" << endl;
211 }
212
213 /**
214  * Test reading after server close.
215  */
216 TEST(AsyncSSLSocketTest, ReadAfterClose) {
217   // Start listening on a local port
218   WriteCallbackBase writeCallback;
219   ReadEOFCallback readCallback(&writeCallback);
220   HandshakeCallback handshakeCallback(&readCallback);
221   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
222   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
223
224   // Set up SSL context.
225   auto sslContext = std::make_shared<SSLContext>();
226   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
227
228   auto socket =
229       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
230   socket->open();
231
232   // This should trigger an EOF on the client.
233   auto evb = handshakeCallback.getSocket()->getEventBase();
234   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
235   std::array<uint8_t, 128> readbuf;
236   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
237   EXPECT_EQ(0, bytesRead);
238 }
239
240 /**
241  * Test bad renegotiation
242  */
243 TEST(AsyncSSLSocketTest, Renegotiate) {
244   EventBase eventBase;
245   auto clientCtx = std::make_shared<SSLContext>();
246   auto dfServerCtx = std::make_shared<SSLContext>();
247   std::array<int, 2> fds;
248   getfds(fds.data());
249   getctx(clientCtx, dfServerCtx);
250
251   AsyncSSLSocket::UniquePtr clientSock(
252       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
253   AsyncSSLSocket::UniquePtr serverSock(
254       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
255   SSLHandshakeClient client(std::move(clientSock), true, true);
256   RenegotiatingServer server(std::move(serverSock));
257
258   while (!client.handshakeSuccess_ && !client.handshakeError_) {
259     eventBase.loopOnce();
260   }
261
262   ASSERT_TRUE(client.handshakeSuccess_);
263
264   auto sslSock = std::move(client).moveSocket();
265   sslSock->detachEventBase();
266   // This is nasty, however we don't want to add support for
267   // renegotiation in AsyncSSLSocket.
268   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
269
270   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
271
272   std::thread t([&]() { eventBase.loopForever(); });
273
274   // Trigger the renegotiation.
275   std::array<uint8_t, 128> buf;
276   memset(buf.data(), 'a', buf.size());
277   try {
278     socket->write(buf.data(), buf.size());
279   } catch (AsyncSocketException& e) {
280     LOG(INFO) << "client got error " << e.what();
281   }
282   eventBase.terminateLoopSoon();
283   t.join();
284
285   eventBase.loop();
286   ASSERT_TRUE(server.renegotiationError_);
287 }
288
289 /**
290  * Negative test for handshakeError().
291  */
292 TEST(AsyncSSLSocketTest, HandshakeError) {
293   // Start listening on a local port
294   WriteCallbackBase writeCallback;
295   WriteErrorCallback readCallback(&writeCallback);
296   HandshakeCallback handshakeCallback(&readCallback);
297   HandshakeErrorCallback acceptCallback(&handshakeCallback);
298   TestSSLServer server(&acceptCallback);
299
300   // Set up SSL context.
301   std::shared_ptr<SSLContext> sslContext(new SSLContext());
302   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
303
304   // connect
305   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
306                                                  sslContext);
307   // read()
308   bool ex = false;
309   try {
310     socket->open();
311
312     uint8_t readbuf[128];
313     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
314     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
315   } catch (AsyncSocketException &e) {
316     ex = true;
317   }
318   EXPECT_TRUE(ex);
319
320   // close()
321   socket->close();
322   cerr << "HandshakeError test completed" << endl;
323 }
324
325 /**
326  * Negative test for readError().
327  */
328 TEST(AsyncSSLSocketTest, ReadError) {
329   // Start listening on a local port
330   WriteCallbackBase writeCallback;
331   ReadErrorCallback readCallback(&writeCallback);
332   HandshakeCallback handshakeCallback(&readCallback);
333   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
334   TestSSLServer server(&acceptCallback);
335
336   // Set up SSL context.
337   std::shared_ptr<SSLContext> sslContext(new SSLContext());
338   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
339
340   // connect
341   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
342                                                  sslContext);
343   socket->open();
344
345   // write something to trigger ssl handshake
346   uint8_t buf[128];
347   memset(buf, 'a', sizeof(buf));
348   socket->write(buf, sizeof(buf));
349
350   socket->close();
351   cerr << "ReadError test completed" << endl;
352 }
353
354 /**
355  * Negative test for writeError().
356  */
357 TEST(AsyncSSLSocketTest, WriteError) {
358   // Start listening on a local port
359   WriteCallbackBase writeCallback;
360   WriteErrorCallback readCallback(&writeCallback);
361   HandshakeCallback handshakeCallback(&readCallback);
362   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
363   TestSSLServer server(&acceptCallback);
364
365   // Set up SSL context.
366   std::shared_ptr<SSLContext> sslContext(new SSLContext());
367   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
368
369   // connect
370   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
371                                                  sslContext);
372   socket->open();
373
374   // write something to trigger ssl handshake
375   uint8_t buf[128];
376   memset(buf, 'a', sizeof(buf));
377   socket->write(buf, sizeof(buf));
378
379   socket->close();
380   cerr << "WriteError test completed" << endl;
381 }
382
383 /**
384  * Test a socket with TCP_NODELAY unset.
385  */
386 TEST(AsyncSSLSocketTest, SocketWithDelay) {
387   // Start listening on a local port
388   WriteCallbackBase writeCallback;
389   ReadCallback readCallback(&writeCallback);
390   HandshakeCallback handshakeCallback(&readCallback);
391   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
392   TestSSLServer server(&acceptCallback);
393
394   // Set up SSL context.
395   std::shared_ptr<SSLContext> sslContext(new SSLContext());
396   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
397
398   // connect
399   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
400                                                  sslContext);
401   socket->open();
402
403   // write()
404   uint8_t buf[128];
405   memset(buf, 'a', sizeof(buf));
406   socket->write(buf, sizeof(buf));
407
408   // read()
409   uint8_t readbuf[128];
410   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
411   EXPECT_EQ(bytesRead, 128);
412   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
413
414   // close()
415   socket->close();
416
417   cerr << "SocketWithDelay test completed" << endl;
418 }
419
420 using NextProtocolTypePair =
421     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
422
423 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
424   // For matching protos
425  public:
426   void SetUp() override { getctx(clientCtx, serverCtx); }
427
428   void connect(bool unset = false) {
429     getfds(fds);
430
431     if (unset) {
432       // unsetting NPN for any of [client, server] is enough to make NPN not
433       // work
434       clientCtx->unsetNextProtocols();
435     }
436
437     AsyncSSLSocket::UniquePtr clientSock(
438       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
439     AsyncSSLSocket::UniquePtr serverSock(
440       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
441     client = folly::make_unique<NpnClient>(std::move(clientSock));
442     server = folly::make_unique<NpnServer>(std::move(serverSock));
443
444     eventBase.loop();
445   }
446
447   void expectProtocol(const std::string& proto) {
448     EXPECT_NE(client->nextProtoLength, 0);
449     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
450     EXPECT_EQ(
451         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
452         0);
453     string selected((const char*)client->nextProto, client->nextProtoLength);
454     EXPECT_EQ(proto, selected);
455   }
456
457   void expectNoProtocol() {
458     EXPECT_EQ(client->nextProtoLength, 0);
459     EXPECT_EQ(server->nextProtoLength, 0);
460     EXPECT_EQ(client->nextProto, nullptr);
461     EXPECT_EQ(server->nextProto, nullptr);
462   }
463
464   void expectProtocolType() {
465     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
466         GetParam().second == SSLContext::NextProtocolType::ANY) {
467       EXPECT_EQ(client->protocolType, server->protocolType);
468     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
469                GetParam().second == SSLContext::NextProtocolType::ANY) {
470       // Well not much we can say
471     } else {
472       expectProtocolType(GetParam());
473     }
474   }
475
476   void expectProtocolType(NextProtocolTypePair expected) {
477     EXPECT_EQ(client->protocolType, expected.first);
478     EXPECT_EQ(server->protocolType, expected.second);
479   }
480
481   EventBase eventBase;
482   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
483   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
484   int fds[2];
485   std::unique_ptr<NpnClient> client;
486   std::unique_ptr<NpnServer> server;
487 };
488
489 class NextProtocolTLSExtTest : public NextProtocolTest {
490   // For extended TLS protos
491 };
492
493 class NextProtocolNPNOnlyTest : public NextProtocolTest {
494   // For mismatching protos
495 };
496
497 class NextProtocolMismatchTest : public NextProtocolTest {
498   // For mismatching protos
499 };
500
501 TEST_P(NextProtocolTest, NpnTestOverlap) {
502   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
503   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
504                                         GetParam().second);
505
506   connect();
507
508   expectProtocol("baz");
509   expectProtocolType();
510 }
511
512 TEST_P(NextProtocolTest, NpnTestUnset) {
513   // Identical to above test, except that we want unset NPN before
514   // looping.
515   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
516   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
517                                         GetParam().second);
518
519   connect(true /* unset */);
520
521   // if alpn negotiation fails, type will appear as npn
522   expectNoProtocol();
523   EXPECT_EQ(client->protocolType, server->protocolType);
524 }
525
526 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
527   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
528   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
529                                         GetParam().second);
530
531   connect();
532
533   expectNoProtocol();
534   expectProtocolType(
535       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
536 }
537
538 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
539 // will fail on 1.0.2 before that.
540 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
541   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
542   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
543                                         GetParam().second);
544
545   connect();
546
547   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
548       GetParam().second == SSLContext::NextProtocolType::ALPN) {
549     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
550     // mismatch should result in a fatal alert, but this is OpenSSL's current
551     // behavior and we want to know if it changes.
552     expectNoProtocol();
553   } else {
554     expectProtocol("blub");
555     expectProtocolType(
556         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
557   }
558 }
559
560 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
561   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
562   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
563   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
564                                         GetParam().second);
565
566   connect();
567
568   expectProtocol("ponies");
569   expectProtocolType();
570 }
571
572 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
573   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
574   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
575   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
576                                         GetParam().second);
577
578   connect();
579
580   expectProtocol("blub");
581   expectProtocolType();
582 }
583
584 TEST_P(NextProtocolTest, RandomizedNpnTest) {
585   // Probability that this test will fail is 2^-64, which could be considered
586   // as negligible.
587   const int kTries = 64;
588
589   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
590                                         GetParam().first);
591   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
592                                                   GetParam().second);
593
594   std::set<string> selectedProtocols;
595   for (int i = 0; i < kTries; ++i) {
596     connect();
597
598     EXPECT_NE(client->nextProtoLength, 0);
599     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
600     EXPECT_EQ(
601         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
602         0);
603     string selected((const char*)client->nextProto, client->nextProtoLength);
604     selectedProtocols.insert(selected);
605     expectProtocolType();
606   }
607   EXPECT_EQ(selectedProtocols.size(), 2);
608 }
609
610 INSTANTIATE_TEST_CASE_P(
611     AsyncSSLSocketTest,
612     NextProtocolTest,
613     ::testing::Values(
614         NextProtocolTypePair(
615             SSLContext::NextProtocolType::NPN,
616             SSLContext::NextProtocolType::NPN),
617         NextProtocolTypePair(
618             SSLContext::NextProtocolType::NPN,
619             SSLContext::NextProtocolType::ANY),
620         NextProtocolTypePair(
621             SSLContext::NextProtocolType::ANY,
622             SSLContext::NextProtocolType::ANY)));
623
624 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
625 INSTANTIATE_TEST_CASE_P(
626     AsyncSSLSocketTest,
627     NextProtocolTLSExtTest,
628     ::testing::Values(
629         NextProtocolTypePair(
630             SSLContext::NextProtocolType::ALPN,
631             SSLContext::NextProtocolType::ALPN),
632         NextProtocolTypePair(
633             SSLContext::NextProtocolType::ALPN,
634             SSLContext::NextProtocolType::ANY),
635         NextProtocolTypePair(
636             SSLContext::NextProtocolType::ANY,
637             SSLContext::NextProtocolType::ALPN)));
638 #endif
639
640 INSTANTIATE_TEST_CASE_P(
641     AsyncSSLSocketTest,
642     NextProtocolNPNOnlyTest,
643     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
644                                            SSLContext::NextProtocolType::NPN)));
645
646 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
647 INSTANTIATE_TEST_CASE_P(
648     AsyncSSLSocketTest,
649     NextProtocolMismatchTest,
650     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
651                                            SSLContext::NextProtocolType::ALPN),
652                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
653                                            SSLContext::NextProtocolType::NPN)));
654 #endif
655
656 #ifndef OPENSSL_NO_TLSEXT
657 /**
658  * 1. Client sends TLSEXT_HOSTNAME in client hello.
659  * 2. Server found a match SSL_CTX and use this SSL_CTX to
660  *    continue the SSL handshake.
661  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
662  */
663 TEST(AsyncSSLSocketTest, SNITestMatch) {
664   EventBase eventBase;
665   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
666   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
667   // Use the same SSLContext to continue the handshake after
668   // tlsext_hostname match.
669   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
670   const std::string serverName("xyz.newdev.facebook.com");
671   int fds[2];
672   getfds(fds);
673   getctx(clientCtx, dfServerCtx);
674
675   AsyncSSLSocket::UniquePtr clientSock(
676     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
677   AsyncSSLSocket::UniquePtr serverSock(
678     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
679   SNIClient client(std::move(clientSock));
680   SNIServer server(std::move(serverSock),
681                    dfServerCtx,
682                    hskServerCtx,
683                    serverName);
684
685   eventBase.loop();
686
687   EXPECT_TRUE(client.serverNameMatch);
688   EXPECT_TRUE(server.serverNameMatch);
689 }
690
691 /**
692  * 1. Client sends TLSEXT_HOSTNAME in client hello.
693  * 2. Server cannot find a matching SSL_CTX and continue to use
694  *    the current SSL_CTX to do the handshake.
695  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
696  */
697 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
698   EventBase eventBase;
699   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
700   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
701   // Use the same SSLContext to continue the handshake after
702   // tlsext_hostname match.
703   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
704   const std::string clientRequestingServerName("foo.com");
705   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
706
707   int fds[2];
708   getfds(fds);
709   getctx(clientCtx, dfServerCtx);
710
711   AsyncSSLSocket::UniquePtr clientSock(
712     new AsyncSSLSocket(clientCtx,
713                         &eventBase,
714                         fds[0],
715                         clientRequestingServerName));
716   AsyncSSLSocket::UniquePtr serverSock(
717     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
718   SNIClient client(std::move(clientSock));
719   SNIServer server(std::move(serverSock),
720                    dfServerCtx,
721                    hskServerCtx,
722                    serverExpectedServerName);
723
724   eventBase.loop();
725
726   EXPECT_TRUE(!client.serverNameMatch);
727   EXPECT_TRUE(!server.serverNameMatch);
728 }
729 /**
730  * 1. Client sends TLSEXT_HOSTNAME in client hello.
731  * 2. We then change the serverName.
732  * 3. We expect that we get 'false' as the result for serNameMatch.
733  */
734
735 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
736    EventBase eventBase;
737   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
738   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
739   // Use the same SSLContext to continue the handshake after
740   // tlsext_hostname match.
741   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
742   const std::string serverName("xyz.newdev.facebook.com");
743   int fds[2];
744   getfds(fds);
745   getctx(clientCtx, dfServerCtx);
746
747   AsyncSSLSocket::UniquePtr clientSock(
748     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
749   //Change the server name
750   std::string newName("new.com");
751   clientSock->setServerName(newName);
752   AsyncSSLSocket::UniquePtr serverSock(
753     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
754   SNIClient client(std::move(clientSock));
755   SNIServer server(std::move(serverSock),
756                    dfServerCtx,
757                    hskServerCtx,
758                    serverName);
759
760   eventBase.loop();
761
762   EXPECT_TRUE(!client.serverNameMatch);
763 }
764
765 /**
766  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
767  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
768  */
769 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
770   EventBase eventBase;
771   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
772   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
773   // Use the same SSLContext to continue the handshake after
774   // tlsext_hostname match.
775   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
776   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
777
778   int fds[2];
779   getfds(fds);
780   getctx(clientCtx, dfServerCtx);
781
782   AsyncSSLSocket::UniquePtr clientSock(
783     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
784   AsyncSSLSocket::UniquePtr serverSock(
785     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
786   SNIClient client(std::move(clientSock));
787   SNIServer server(std::move(serverSock),
788                    dfServerCtx,
789                    hskServerCtx,
790                    serverExpectedServerName);
791
792   eventBase.loop();
793
794   EXPECT_TRUE(!client.serverNameMatch);
795   EXPECT_TRUE(!server.serverNameMatch);
796 }
797
798 #endif
799 /**
800  * Test SSL client socket
801  */
802 TEST(AsyncSSLSocketTest, SSLClientTest) {
803   // Start listening on a local port
804   WriteCallbackBase writeCallback;
805   ReadCallback readCallback(&writeCallback);
806   HandshakeCallback handshakeCallback(&readCallback);
807   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
808   TestSSLServer server(&acceptCallback);
809
810   // Set up SSL client
811   EventBase eventBase;
812   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
813
814   client->connect();
815   EventBaseAborter eba(&eventBase, 3000);
816   eventBase.loop();
817
818   EXPECT_EQ(client->getMiss(), 1);
819   EXPECT_EQ(client->getHit(), 0);
820
821   cerr << "SSLClientTest test completed" << endl;
822 }
823
824
825 /**
826  * Test SSL client socket session re-use
827  */
828 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
829   // Start listening on a local port
830   WriteCallbackBase writeCallback;
831   ReadCallback readCallback(&writeCallback);
832   HandshakeCallback handshakeCallback(&readCallback);
833   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
834   TestSSLServer server(&acceptCallback);
835
836   // Set up SSL client
837   EventBase eventBase;
838   auto client =
839       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
840
841   client->connect();
842   EventBaseAborter eba(&eventBase, 3000);
843   eventBase.loop();
844
845   EXPECT_EQ(client->getMiss(), 1);
846   EXPECT_EQ(client->getHit(), 9);
847
848   cerr << "SSLClientTestReuse test completed" << endl;
849 }
850
851 /**
852  * Test SSL client socket timeout
853  */
854 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
855   // Start listening on a local port
856   EmptyReadCallback readCallback;
857   HandshakeCallback handshakeCallback(&readCallback,
858                                       HandshakeCallback::EXPECT_ERROR);
859   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
860   TestSSLServer server(&acceptCallback);
861
862   // Set up SSL client
863   EventBase eventBase;
864   auto client =
865       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
866   client->connect(true /* write before connect completes */);
867   EventBaseAborter eba(&eventBase, 3000);
868   eventBase.loop();
869
870   usleep(100000);
871   // This is checking that the connectError callback precedes any queued
872   // writeError callbacks.  This matches AsyncSocket's behavior
873   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
874   EXPECT_EQ(client->getErrors(), 1);
875   EXPECT_EQ(client->getMiss(), 0);
876   EXPECT_EQ(client->getHit(), 0);
877
878   cerr << "SSLClientTimeoutTest test completed" << endl;
879 }
880
881
882 /**
883  * Test SSL server async cache
884  */
885 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
886   // Start listening on a local port
887   WriteCallbackBase writeCallback;
888   ReadCallback readCallback(&writeCallback);
889   HandshakeCallback handshakeCallback(&readCallback);
890   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
891   TestSSLAsyncCacheServer server(&acceptCallback);
892
893   // Set up SSL client
894   EventBase eventBase;
895   auto client =
896       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
897
898   client->connect();
899   EventBaseAborter eba(&eventBase, 3000);
900   eventBase.loop();
901
902   EXPECT_EQ(server.getAsyncCallbacks(), 18);
903   EXPECT_EQ(server.getAsyncLookups(), 9);
904   EXPECT_EQ(client->getMiss(), 10);
905   EXPECT_EQ(client->getHit(), 0);
906
907   cerr << "SSLServerAsyncCacheTest test completed" << endl;
908 }
909
910
911 /**
912  * Test SSL server accept timeout with cache path
913  */
914 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
915   // Start listening on a local port
916   WriteCallbackBase writeCallback;
917   ReadCallback readCallback(&writeCallback);
918   EmptyReadCallback clientReadCallback;
919   HandshakeCallback handshakeCallback(&readCallback);
920   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
921   TestSSLAsyncCacheServer server(&acceptCallback);
922
923   // Set up SSL client
924   EventBase eventBase;
925   // only do a TCP connect
926   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
927   sock->connect(nullptr, server.getAddress());
928   clientReadCallback.tcpSocket_ = sock;
929   sock->setReadCB(&clientReadCallback);
930
931   EventBaseAborter eba(&eventBase, 3000);
932   eventBase.loop();
933
934   EXPECT_EQ(readCallback.state, STATE_WAITING);
935
936   cerr << "SSLServerTimeoutTest test completed" << endl;
937 }
938
939 /**
940  * Test SSL server accept timeout with cache path
941  */
942 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
943   // Start listening on a local port
944   WriteCallbackBase writeCallback;
945   ReadCallback readCallback(&writeCallback);
946   HandshakeCallback handshakeCallback(&readCallback);
947   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
948   TestSSLAsyncCacheServer server(&acceptCallback);
949
950   // Set up SSL client
951   EventBase eventBase;
952   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
953
954   client->connect();
955   EventBaseAborter eba(&eventBase, 3000);
956   eventBase.loop();
957
958   EXPECT_EQ(server.getAsyncCallbacks(), 1);
959   EXPECT_EQ(server.getAsyncLookups(), 1);
960   EXPECT_EQ(client->getErrors(), 1);
961   EXPECT_EQ(client->getMiss(), 1);
962   EXPECT_EQ(client->getHit(), 0);
963
964   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
965 }
966
967 /**
968  * Test SSL server accept timeout with cache path
969  */
970 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
971   // Start listening on a local port
972   WriteCallbackBase writeCallback;
973   ReadCallback readCallback(&writeCallback);
974   HandshakeCallback handshakeCallback(&readCallback,
975                                       HandshakeCallback::EXPECT_ERROR);
976   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
977   TestSSLAsyncCacheServer server(&acceptCallback, 500);
978
979   // Set up SSL client
980   EventBase eventBase;
981   auto client =
982       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
983
984   client->connect();
985   EventBaseAborter eba(&eventBase, 3000);
986   eventBase.loop();
987
988   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
989       handshakeCallback.closeSocket();});
990   // give time for the cache lookup to come back and find it closed
991   handshakeCallback.waitForHandshake();
992
993   EXPECT_EQ(server.getAsyncCallbacks(), 1);
994   EXPECT_EQ(server.getAsyncLookups(), 1);
995   EXPECT_EQ(client->getErrors(), 1);
996   EXPECT_EQ(client->getMiss(), 1);
997   EXPECT_EQ(client->getHit(), 0);
998
999   cerr << "SSLServerCacheCloseTest test completed" << endl;
1000 }
1001
1002 /**
1003  * Verify Client Ciphers obtained using SSL MSG Callback.
1004  */
1005 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
1006   EventBase eventBase;
1007   auto clientCtx = std::make_shared<SSLContext>();
1008   auto serverCtx = std::make_shared<SSLContext>();
1009   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1010   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
1011   serverCtx->loadPrivateKey(testKey);
1012   serverCtx->loadCertificate(testCert);
1013   serverCtx->loadTrustedCertificates(testCA);
1014   serverCtx->loadClientCAList(testCA);
1015
1016   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1017   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
1018   clientCtx->loadPrivateKey(testKey);
1019   clientCtx->loadCertificate(testCert);
1020   clientCtx->loadTrustedCertificates(testCA);
1021
1022   int fds[2];
1023   getfds(fds);
1024
1025   AsyncSSLSocket::UniquePtr clientSock(
1026       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1027   AsyncSSLSocket::UniquePtr serverSock(
1028       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1029
1030   SSLHandshakeClient client(std::move(clientSock), true, true);
1031   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1032
1033   eventBase.loop();
1034
1035   EXPECT_EQ(server.clientCiphers_,
1036             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1037   EXPECT_TRUE(client.handshakeVerify_);
1038   EXPECT_TRUE(client.handshakeSuccess_);
1039   EXPECT_TRUE(!client.handshakeError_);
1040   EXPECT_TRUE(server.handshakeVerify_);
1041   EXPECT_TRUE(server.handshakeSuccess_);
1042   EXPECT_TRUE(!server.handshakeError_);
1043 }
1044
1045 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1046   EventBase eventBase;
1047   auto ctx = std::make_shared<SSLContext>();
1048
1049   int fds[2];
1050   getfds(fds);
1051
1052   int bufLen = 42;
1053   uint8_t majorVersion = 18;
1054   uint8_t minorVersion = 25;
1055
1056   // Create callback buf
1057   auto buf = IOBuf::create(bufLen);
1058   buf->append(bufLen);
1059   folly::io::RWPrivateCursor cursor(buf.get());
1060   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1061   cursor.write<uint16_t>(0);
1062   cursor.write<uint8_t>(38);
1063   cursor.write<uint8_t>(majorVersion);
1064   cursor.write<uint8_t>(minorVersion);
1065   cursor.skip(32);
1066   cursor.write<uint32_t>(0);
1067
1068   SSL* ssl = ctx->createSSL();
1069   SCOPE_EXIT { SSL_free(ssl); };
1070   AsyncSSLSocket::UniquePtr sock(
1071       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1072   sock->enableClientHelloParsing();
1073
1074   // Test client hello parsing in one packet
1075   AsyncSSLSocket::clientHelloParsingCallback(
1076       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1077   buf.reset();
1078
1079   auto parsedClientHello = sock->getClientHelloInfo();
1080   EXPECT_TRUE(parsedClientHello != nullptr);
1081   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1082   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1083 }
1084
1085 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1086   EventBase eventBase;
1087   auto ctx = std::make_shared<SSLContext>();
1088
1089   int fds[2];
1090   getfds(fds);
1091
1092   int bufLen = 42;
1093   uint8_t majorVersion = 18;
1094   uint8_t minorVersion = 25;
1095
1096   // Create callback buf
1097   auto buf = IOBuf::create(bufLen);
1098   buf->append(bufLen);
1099   folly::io::RWPrivateCursor cursor(buf.get());
1100   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1101   cursor.write<uint16_t>(0);
1102   cursor.write<uint8_t>(38);
1103   cursor.write<uint8_t>(majorVersion);
1104   cursor.write<uint8_t>(minorVersion);
1105   cursor.skip(32);
1106   cursor.write<uint32_t>(0);
1107
1108   SSL* ssl = ctx->createSSL();
1109   SCOPE_EXIT { SSL_free(ssl); };
1110   AsyncSSLSocket::UniquePtr sock(
1111       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1112   sock->enableClientHelloParsing();
1113
1114   // Test parsing with two packets with first packet size < 3
1115   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1116   AsyncSSLSocket::clientHelloParsingCallback(
1117       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1118       ssl, sock.get());
1119   bufCopy.reset();
1120   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1121   AsyncSSLSocket::clientHelloParsingCallback(
1122       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1123       ssl, sock.get());
1124   bufCopy.reset();
1125
1126   auto parsedClientHello = sock->getClientHelloInfo();
1127   EXPECT_TRUE(parsedClientHello != nullptr);
1128   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1129   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1130 }
1131
1132 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1133   EventBase eventBase;
1134   auto ctx = std::make_shared<SSLContext>();
1135
1136   int fds[2];
1137   getfds(fds);
1138
1139   int bufLen = 42;
1140   uint8_t majorVersion = 18;
1141   uint8_t minorVersion = 25;
1142
1143   // Create callback buf
1144   auto buf = IOBuf::create(bufLen);
1145   buf->append(bufLen);
1146   folly::io::RWPrivateCursor cursor(buf.get());
1147   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1148   cursor.write<uint16_t>(0);
1149   cursor.write<uint8_t>(38);
1150   cursor.write<uint8_t>(majorVersion);
1151   cursor.write<uint8_t>(minorVersion);
1152   cursor.skip(32);
1153   cursor.write<uint32_t>(0);
1154
1155   SSL* ssl = ctx->createSSL();
1156   SCOPE_EXIT { SSL_free(ssl); };
1157   AsyncSSLSocket::UniquePtr sock(
1158       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1159   sock->enableClientHelloParsing();
1160
1161   // Test parsing with multiple small packets
1162   for (uint64_t i = 0; i < buf->length(); i += 3) {
1163     auto bufCopy = folly::IOBuf::copyBuffer(
1164         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1165     AsyncSSLSocket::clientHelloParsingCallback(
1166         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1167         ssl, sock.get());
1168     bufCopy.reset();
1169   }
1170
1171   auto parsedClientHello = sock->getClientHelloInfo();
1172   EXPECT_TRUE(parsedClientHello != nullptr);
1173   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1174   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1175 }
1176
1177 /**
1178  * Verify sucessful behavior of SSL certificate validation.
1179  */
1180 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1181   EventBase eventBase;
1182   auto clientCtx = std::make_shared<SSLContext>();
1183   auto dfServerCtx = std::make_shared<SSLContext>();
1184
1185   int fds[2];
1186   getfds(fds);
1187   getctx(clientCtx, dfServerCtx);
1188
1189   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1190   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1191
1192   AsyncSSLSocket::UniquePtr clientSock(
1193     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1194   AsyncSSLSocket::UniquePtr serverSock(
1195     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1196
1197   SSLHandshakeClient client(std::move(clientSock), true, true);
1198   clientCtx->loadTrustedCertificates(testCA);
1199
1200   SSLHandshakeServer server(std::move(serverSock), true, true);
1201
1202   eventBase.loop();
1203
1204   EXPECT_TRUE(client.handshakeVerify_);
1205   EXPECT_TRUE(client.handshakeSuccess_);
1206   EXPECT_TRUE(!client.handshakeError_);
1207   EXPECT_LE(0, client.handshakeTime.count());
1208   EXPECT_TRUE(!server.handshakeVerify_);
1209   EXPECT_TRUE(server.handshakeSuccess_);
1210   EXPECT_TRUE(!server.handshakeError_);
1211   EXPECT_LE(0, server.handshakeTime.count());
1212 }
1213
1214 /**
1215  * Verify that the client's verification callback is able to fail SSL
1216  * connection establishment.
1217  */
1218 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1219   EventBase eventBase;
1220   auto clientCtx = std::make_shared<SSLContext>();
1221   auto dfServerCtx = std::make_shared<SSLContext>();
1222
1223   int fds[2];
1224   getfds(fds);
1225   getctx(clientCtx, dfServerCtx);
1226
1227   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1228   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1229
1230   AsyncSSLSocket::UniquePtr clientSock(
1231     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1232   AsyncSSLSocket::UniquePtr serverSock(
1233     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1234
1235   SSLHandshakeClient client(std::move(clientSock), true, false);
1236   clientCtx->loadTrustedCertificates(testCA);
1237
1238   SSLHandshakeServer server(std::move(serverSock), true, true);
1239
1240   eventBase.loop();
1241
1242   EXPECT_TRUE(client.handshakeVerify_);
1243   EXPECT_TRUE(!client.handshakeSuccess_);
1244   EXPECT_TRUE(client.handshakeError_);
1245   EXPECT_LE(0, client.handshakeTime.count());
1246   EXPECT_TRUE(!server.handshakeVerify_);
1247   EXPECT_TRUE(!server.handshakeSuccess_);
1248   EXPECT_TRUE(server.handshakeError_);
1249   EXPECT_LE(0, server.handshakeTime.count());
1250 }
1251
1252 /**
1253  * Verify that the options in SSLContext can be overridden in
1254  * sslConnect/Accept.i.e specifying that no validation should be performed
1255  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1256  * the validation callback.
1257  */
1258 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1259   EventBase eventBase;
1260   auto clientCtx = std::make_shared<SSLContext>();
1261   auto dfServerCtx = std::make_shared<SSLContext>();
1262
1263   int fds[2];
1264   getfds(fds);
1265   getctx(clientCtx, dfServerCtx);
1266
1267   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1268   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1269
1270   AsyncSSLSocket::UniquePtr clientSock(
1271     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1272   AsyncSSLSocket::UniquePtr serverSock(
1273     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1274
1275   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1276   clientCtx->loadTrustedCertificates(testCA);
1277
1278   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1279
1280   eventBase.loop();
1281
1282   EXPECT_TRUE(!client.handshakeVerify_);
1283   EXPECT_TRUE(client.handshakeSuccess_);
1284   EXPECT_TRUE(!client.handshakeError_);
1285   EXPECT_LE(0, client.handshakeTime.count());
1286   EXPECT_TRUE(!server.handshakeVerify_);
1287   EXPECT_TRUE(server.handshakeSuccess_);
1288   EXPECT_TRUE(!server.handshakeError_);
1289   EXPECT_LE(0, server.handshakeTime.count());
1290 }
1291
1292 /**
1293  * Verify that the options in SSLContext can be overridden in
1294  * sslConnect/Accept. Enable verification even if context says otherwise.
1295  * Test requireClientCert with client cert
1296  */
1297 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1298   EventBase eventBase;
1299   auto clientCtx = std::make_shared<SSLContext>();
1300   auto serverCtx = std::make_shared<SSLContext>();
1301   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1302   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1303   serverCtx->loadPrivateKey(testKey);
1304   serverCtx->loadCertificate(testCert);
1305   serverCtx->loadTrustedCertificates(testCA);
1306   serverCtx->loadClientCAList(testCA);
1307
1308   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1309   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1310   clientCtx->loadPrivateKey(testKey);
1311   clientCtx->loadCertificate(testCert);
1312   clientCtx->loadTrustedCertificates(testCA);
1313
1314   int fds[2];
1315   getfds(fds);
1316
1317   AsyncSSLSocket::UniquePtr clientSock(
1318       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1319   AsyncSSLSocket::UniquePtr serverSock(
1320       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1321
1322   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1323   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1324
1325   eventBase.loop();
1326
1327   EXPECT_TRUE(client.handshakeVerify_);
1328   EXPECT_TRUE(client.handshakeSuccess_);
1329   EXPECT_FALSE(client.handshakeError_);
1330   EXPECT_LE(0, client.handshakeTime.count());
1331   EXPECT_TRUE(server.handshakeVerify_);
1332   EXPECT_TRUE(server.handshakeSuccess_);
1333   EXPECT_FALSE(server.handshakeError_);
1334   EXPECT_LE(0, server.handshakeTime.count());
1335 }
1336
1337 /**
1338  * Verify that the client's verification callback is able to override
1339  * the preverification failure and allow a successful connection.
1340  */
1341 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1342   EventBase eventBase;
1343   auto clientCtx = std::make_shared<SSLContext>();
1344   auto dfServerCtx = std::make_shared<SSLContext>();
1345
1346   int fds[2];
1347   getfds(fds);
1348   getctx(clientCtx, dfServerCtx);
1349
1350   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1351   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1352
1353   AsyncSSLSocket::UniquePtr clientSock(
1354     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1355   AsyncSSLSocket::UniquePtr serverSock(
1356     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1357
1358   SSLHandshakeClient client(std::move(clientSock), false, true);
1359   SSLHandshakeServer server(std::move(serverSock), true, true);
1360
1361   eventBase.loop();
1362
1363   EXPECT_TRUE(client.handshakeVerify_);
1364   EXPECT_TRUE(client.handshakeSuccess_);
1365   EXPECT_TRUE(!client.handshakeError_);
1366   EXPECT_LE(0, client.handshakeTime.count());
1367   EXPECT_TRUE(!server.handshakeVerify_);
1368   EXPECT_TRUE(server.handshakeSuccess_);
1369   EXPECT_TRUE(!server.handshakeError_);
1370   EXPECT_LE(0, server.handshakeTime.count());
1371 }
1372
1373 /**
1374  * Verify that specifying that no validation should be performed allows an
1375  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1376  * callback.
1377  */
1378 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1379   EventBase eventBase;
1380   auto clientCtx = std::make_shared<SSLContext>();
1381   auto dfServerCtx = std::make_shared<SSLContext>();
1382
1383   int fds[2];
1384   getfds(fds);
1385   getctx(clientCtx, dfServerCtx);
1386
1387   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1388   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1389
1390   AsyncSSLSocket::UniquePtr clientSock(
1391     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1392   AsyncSSLSocket::UniquePtr serverSock(
1393     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1394
1395   SSLHandshakeClient client(std::move(clientSock), false, false);
1396   SSLHandshakeServer server(std::move(serverSock), false, false);
1397
1398   eventBase.loop();
1399
1400   EXPECT_TRUE(!client.handshakeVerify_);
1401   EXPECT_TRUE(client.handshakeSuccess_);
1402   EXPECT_TRUE(!client.handshakeError_);
1403   EXPECT_LE(0, client.handshakeTime.count());
1404   EXPECT_TRUE(!server.handshakeVerify_);
1405   EXPECT_TRUE(server.handshakeSuccess_);
1406   EXPECT_TRUE(!server.handshakeError_);
1407   EXPECT_LE(0, server.handshakeTime.count());
1408 }
1409
1410 /**
1411  * Test requireClientCert with client cert
1412  */
1413 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1414   EventBase eventBase;
1415   auto clientCtx = std::make_shared<SSLContext>();
1416   auto serverCtx = std::make_shared<SSLContext>();
1417   serverCtx->setVerificationOption(
1418       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1419   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1420   serverCtx->loadPrivateKey(testKey);
1421   serverCtx->loadCertificate(testCert);
1422   serverCtx->loadTrustedCertificates(testCA);
1423   serverCtx->loadClientCAList(testCA);
1424
1425   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1426   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1427   clientCtx->loadPrivateKey(testKey);
1428   clientCtx->loadCertificate(testCert);
1429   clientCtx->loadTrustedCertificates(testCA);
1430
1431   int fds[2];
1432   getfds(fds);
1433
1434   AsyncSSLSocket::UniquePtr clientSock(
1435       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1436   AsyncSSLSocket::UniquePtr serverSock(
1437       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1438
1439   SSLHandshakeClient client(std::move(clientSock), true, true);
1440   SSLHandshakeServer server(std::move(serverSock), true, true);
1441
1442   eventBase.loop();
1443
1444   EXPECT_TRUE(client.handshakeVerify_);
1445   EXPECT_TRUE(client.handshakeSuccess_);
1446   EXPECT_FALSE(client.handshakeError_);
1447   EXPECT_LE(0, client.handshakeTime.count());
1448   EXPECT_TRUE(server.handshakeVerify_);
1449   EXPECT_TRUE(server.handshakeSuccess_);
1450   EXPECT_FALSE(server.handshakeError_);
1451   EXPECT_LE(0, server.handshakeTime.count());
1452 }
1453
1454
1455 /**
1456  * Test requireClientCert with no client cert
1457  */
1458 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1459   EventBase eventBase;
1460   auto clientCtx = std::make_shared<SSLContext>();
1461   auto serverCtx = std::make_shared<SSLContext>();
1462   serverCtx->setVerificationOption(
1463       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1464   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1465   serverCtx->loadPrivateKey(testKey);
1466   serverCtx->loadCertificate(testCert);
1467   serverCtx->loadTrustedCertificates(testCA);
1468   serverCtx->loadClientCAList(testCA);
1469   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1470   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1471
1472   int fds[2];
1473   getfds(fds);
1474
1475   AsyncSSLSocket::UniquePtr clientSock(
1476       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1477   AsyncSSLSocket::UniquePtr serverSock(
1478       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1479
1480   SSLHandshakeClient client(std::move(clientSock), false, false);
1481   SSLHandshakeServer server(std::move(serverSock), false, false);
1482
1483   eventBase.loop();
1484
1485   EXPECT_FALSE(server.handshakeVerify_);
1486   EXPECT_FALSE(server.handshakeSuccess_);
1487   EXPECT_TRUE(server.handshakeError_);
1488   EXPECT_LE(0, client.handshakeTime.count());
1489   EXPECT_LE(0, server.handshakeTime.count());
1490 }
1491
1492 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1493   auto cert = getFileAsBuf(testCert);
1494   auto key = getFileAsBuf(testKey);
1495
1496   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1497   BIO_write(certBio.get(), cert.data(), cert.size());
1498   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1499   BIO_write(keyBio.get(), key.data(), key.size());
1500
1501   // Create SSL structs from buffers to get properties
1502   ssl::X509UniquePtr certStruct(
1503       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1504   ssl::EvpPkeyUniquePtr keyStruct(
1505       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1506   certBio = nullptr;
1507   keyBio = nullptr;
1508
1509   auto origCommonName = getCommonName(certStruct.get());
1510   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1511   certStruct = nullptr;
1512   keyStruct = nullptr;
1513
1514   auto ctx = std::make_shared<SSLContext>();
1515   ctx->loadPrivateKeyFromBufferPEM(key);
1516   ctx->loadCertificateFromBufferPEM(cert);
1517   ctx->loadTrustedCertificates(testCA);
1518
1519   ssl::SSLUniquePtr ssl(ctx->createSSL());
1520
1521   auto newCert = SSL_get_certificate(ssl.get());
1522   auto newKey = SSL_get_privatekey(ssl.get());
1523
1524   // Get properties from SSL struct
1525   auto newCommonName = getCommonName(newCert);
1526   auto newKeySize = EVP_PKEY_bits(newKey);
1527
1528   // Check that the key and cert have the expected properties
1529   EXPECT_EQ(origCommonName, newCommonName);
1530   EXPECT_EQ(origKeySize, newKeySize);
1531 }
1532
1533 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1534   EventBase eb;
1535
1536   // Set up SSL context.
1537   auto sslContext = std::make_shared<SSLContext>();
1538   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1539
1540   // create SSL socket
1541   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1542
1543   EXPECT_EQ(1500, socket->getMinWriteSize());
1544
1545   socket->setMinWriteSize(0);
1546   EXPECT_EQ(0, socket->getMinWriteSize());
1547   socket->setMinWriteSize(50000);
1548   EXPECT_EQ(50000, socket->getMinWriteSize());
1549 }
1550
1551 class ReadCallbackTerminator : public ReadCallback {
1552  public:
1553   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1554       : ReadCallback(wcb)
1555       , base_(base) {}
1556
1557   // Do not write data back, terminate the loop.
1558   void readDataAvailable(size_t len) noexcept override {
1559     std::cerr << "readDataAvailable, len " << len << std::endl;
1560
1561     currentBuffer.length = len;
1562
1563     buffers.push_back(currentBuffer);
1564     currentBuffer.reset();
1565     state = STATE_SUCCEEDED;
1566
1567     socket_->setReadCB(nullptr);
1568     base_->terminateLoopSoon();
1569   }
1570  private:
1571   EventBase* base_;
1572 };
1573
1574
1575 /**
1576  * Test a full unencrypted codepath
1577  */
1578 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1579   EventBase base;
1580
1581   auto clientCtx = std::make_shared<folly::SSLContext>();
1582   auto serverCtx = std::make_shared<folly::SSLContext>();
1583   int fds[2];
1584   getfds(fds);
1585   getctx(clientCtx, serverCtx);
1586   auto client = AsyncSSLSocket::newSocket(
1587                   clientCtx, &base, fds[0], false, true);
1588   auto server = AsyncSSLSocket::newSocket(
1589                   serverCtx, &base, fds[1], true, true);
1590
1591   ReadCallbackTerminator readCallback(&base, nullptr);
1592   server->setReadCB(&readCallback);
1593   readCallback.setSocket(server);
1594
1595   uint8_t buf[128];
1596   memset(buf, 'a', sizeof(buf));
1597   client->write(nullptr, buf, sizeof(buf));
1598
1599   // Check that bytes are unencrypted
1600   char c;
1601   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1602   EXPECT_EQ('a', c);
1603
1604   EventBaseAborter eba(&base, 3000);
1605   base.loop();
1606
1607   EXPECT_EQ(1, readCallback.buffers.size());
1608   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1609
1610   server->setReadCB(&readCallback);
1611
1612   // Unencrypted
1613   server->sslAccept(nullptr);
1614   client->sslConn(nullptr);
1615
1616   // Do NOT wait for handshake, writing should be queued and happen after
1617
1618   client->write(nullptr, buf, sizeof(buf));
1619
1620   // Check that bytes are *not* unencrypted
1621   char c2;
1622   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1623   EXPECT_NE('a', c2);
1624
1625
1626   base.loop();
1627
1628   EXPECT_EQ(2, readCallback.buffers.size());
1629   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1630 }
1631
1632 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1633   // Start listening on a local port
1634   WriteCallbackBase writeCallback;
1635   WriteErrorCallback readCallback(&writeCallback);
1636   HandshakeCallback handshakeCallback(&readCallback,
1637                                       HandshakeCallback::EXPECT_ERROR);
1638   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1639   TestSSLServer server(&acceptCallback);
1640
1641   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1642   socket->open();
1643   uint8_t buf[3] = {0x16, 0x03, 0x01};
1644   socket->write(buf, sizeof(buf));
1645   socket->closeWithReset();
1646
1647   handshakeCallback.waitForHandshake();
1648   EXPECT_NE(
1649       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1650   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1651 }
1652
1653 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1654   // Start listening on a local port
1655   WriteCallbackBase writeCallback;
1656   WriteErrorCallback readCallback(&writeCallback);
1657   HandshakeCallback handshakeCallback(&readCallback,
1658                                       HandshakeCallback::EXPECT_ERROR);
1659   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1660   TestSSLServer server(&acceptCallback);
1661
1662   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1663   socket->open();
1664   uint8_t buf[3] = {0x16, 0x03, 0x01};
1665   socket->write(buf, sizeof(buf));
1666   socket->close();
1667
1668   handshakeCallback.waitForHandshake();
1669   EXPECT_NE(
1670       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1671   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1672 }
1673
1674 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1675   // Start listening on a local port
1676   WriteCallbackBase writeCallback;
1677   WriteErrorCallback readCallback(&writeCallback);
1678   HandshakeCallback handshakeCallback(&readCallback,
1679                                       HandshakeCallback::EXPECT_ERROR);
1680   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1681   TestSSLServer server(&acceptCallback);
1682
1683   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1684   socket->open();
1685   uint8_t buf[256] = {0x16, 0x03};
1686   memset(buf + 2, 'a', sizeof(buf) - 2);
1687   socket->write(buf, sizeof(buf));
1688   socket->close();
1689
1690   handshakeCallback.waitForHandshake();
1691   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1692             std::string::npos);
1693   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1694             std::string::npos);
1695 }
1696
1697 #if FOLLY_ALLOW_TFO
1698
1699 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1700  public:
1701   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1702
1703   explicit MockAsyncTFOSSLSocket(
1704       std::shared_ptr<folly::SSLContext> sslCtx,
1705       EventBase* evb)
1706       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1707
1708   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1709 };
1710
1711 /**
1712  * Test connecting to, writing to, reading from, and closing the
1713  * connection to the SSL server with TFO.
1714  */
1715 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1716   // Start listening on a local port
1717   WriteCallbackBase writeCallback;
1718   ReadCallback readCallback(&writeCallback);
1719   HandshakeCallback handshakeCallback(&readCallback);
1720   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1721   TestSSLServer server(&acceptCallback, true);
1722
1723   // Set up SSL context.
1724   auto sslContext = std::make_shared<SSLContext>();
1725
1726   // connect
1727   auto socket =
1728       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1729   socket->enableTFO();
1730   socket->open();
1731
1732   // write()
1733   std::array<uint8_t, 128> buf;
1734   memset(buf.data(), 'a', buf.size());
1735   socket->write(buf.data(), buf.size());
1736
1737   // read()
1738   std::array<uint8_t, 128> readbuf;
1739   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1740   EXPECT_EQ(bytesRead, 128);
1741   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1742
1743   // close()
1744   socket->close();
1745 }
1746
1747 /**
1748  * Test connecting to, writing to, reading from, and closing the
1749  * connection to the SSL server with TFO.
1750  */
1751 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1752   // Start listening on a local port
1753   WriteCallbackBase writeCallback;
1754   ReadCallback readCallback(&writeCallback);
1755   HandshakeCallback handshakeCallback(&readCallback);
1756   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1757   TestSSLServer server(&acceptCallback, false);
1758
1759   // Set up SSL context.
1760   auto sslContext = std::make_shared<SSLContext>();
1761
1762   // connect
1763   auto socket =
1764       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1765   socket->enableTFO();
1766   socket->open();
1767
1768   // write()
1769   std::array<uint8_t, 128> buf;
1770   memset(buf.data(), 'a', buf.size());
1771   socket->write(buf.data(), buf.size());
1772
1773   // read()
1774   std::array<uint8_t, 128> readbuf;
1775   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1776   EXPECT_EQ(bytesRead, 128);
1777   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1778
1779   // close()
1780   socket->close();
1781 }
1782
1783 class ConnCallback : public AsyncSocket::ConnectCallback {
1784  public:
1785   virtual void connectSuccess() noexcept override {
1786     state = State::SUCCESS;
1787   }
1788
1789   virtual void connectErr(const AsyncSocketException&) noexcept override {
1790     state = State::ERROR;
1791   }
1792
1793   enum class State { WAITING, SUCCESS, ERROR };
1794
1795   State state{State::WAITING};
1796 };
1797
1798 template <class Cardinality>
1799 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1800     EventBase* evb,
1801     const SocketAddress& address,
1802     Cardinality cardinality) {
1803   // Set up SSL context.
1804   auto sslContext = std::make_shared<SSLContext>();
1805
1806   // connect
1807   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1808       new MockAsyncTFOSSLSocket(sslContext, evb));
1809   socket->enableTFO();
1810
1811   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1812       .Times(cardinality)
1813       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1814         sockaddr_storage addr;
1815         auto len = address.getAddress(&addr);
1816         return connect(fd, (const struct sockaddr*)&addr, len);
1817       }));
1818   return socket;
1819 }
1820
1821 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1822   // Start listening on a local port
1823   WriteCallbackBase writeCallback;
1824   ReadCallback readCallback(&writeCallback);
1825   HandshakeCallback handshakeCallback(&readCallback);
1826   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1827   TestSSLServer server(&acceptCallback, true);
1828
1829   EventBase evb;
1830
1831   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1832   ConnCallback ccb;
1833   socket->connect(&ccb, server.getAddress(), 30);
1834
1835   evb.loop();
1836   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1837
1838   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1839   evb.loop();
1840
1841   BlockingSocket sock(std::move(socket));
1842   // write()
1843   std::array<uint8_t, 128> buf;
1844   memset(buf.data(), 'a', buf.size());
1845   sock.write(buf.data(), buf.size());
1846
1847   // read()
1848   std::array<uint8_t, 128> readbuf;
1849   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1850   EXPECT_EQ(bytesRead, 128);
1851   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1852
1853   // close()
1854   sock.close();
1855 }
1856
1857 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1858   // Start listening on a local port
1859   ConnectTimeoutCallback acceptCallback;
1860   TestSSLServer server(&acceptCallback, true);
1861
1862   // Set up SSL context.
1863   auto sslContext = std::make_shared<SSLContext>();
1864
1865   // connect
1866   auto socket =
1867       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1868   socket->enableTFO();
1869   EXPECT_THROW(
1870       socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
1871 }
1872
1873 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1874   // Start listening on a local port
1875   ConnectTimeoutCallback acceptCallback;
1876   TestSSLServer server(&acceptCallback, true);
1877
1878   EventBase evb;
1879
1880   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1881   ConnCallback ccb;
1882   // Set a short timeout
1883   socket->connect(&ccb, server.getAddress(), 1);
1884
1885   evb.loop();
1886   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1887 }
1888
1889 #endif
1890
1891 } // namespace
1892
1893 ///////////////////////////////////////////////////////////////////////////
1894 // init_unit_test_suite
1895 ///////////////////////////////////////////////////////////////////////////
1896 namespace {
1897 struct Initializer {
1898   Initializer() {
1899     signal(SIGPIPE, SIG_IGN);
1900   }
1901 };
1902 Initializer initializer;
1903 } // anonymous