Add TFO support to AsyncSSLSocket
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19 #include <signal.h>
20
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/EventBase.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/Unistd.h>
26
27 #include <folly/io/async/test/BlockingSocket.h>
28
29 #include <fcntl.h>
30 #include <folly/io/Cursor.h>
31 #include <gtest/gtest.h>
32 #include <openssl/bio.h>
33 #include <sys/types.h>
34 #include <fstream>
35 #include <iostream>
36 #include <list>
37 #include <set>
38
39 #include <gmock/gmock.h>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
56 const char* testKey = "folly/io/async/test/certs/tests-key.pem";
57 const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
58
59 constexpr size_t SSLClient::kMaxReadBufferSz;
60 constexpr size_t SSLClient::kMaxReadsPerEvent;
61
62 TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase* acb, bool enableTFO)
63     : ctx_(new folly::SSLContext),
64       acb_(acb),
65       socket_(folly::AsyncServerSocket::newSocket(&evb_)) {
66   // Set up the SSL context
67   ctx_->loadCertificate(testCert);
68   ctx_->loadPrivateKey(testKey);
69   ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
70
71   acb_->ctx_ = ctx_;
72   acb_->base_ = &evb_;
73
74   // Enable TFO
75   if (enableTFO) {
76     LOG(INFO) << "server TFO enabled";
77     socket_->setTFOEnabled(true, 1000);
78   }
79
80   // set up the listening socket
81   socket_->bind(0);
82   socket_->getAddress(&address_);
83   socket_->listen(100);
84   socket_->addAcceptCallback(acb_, &evb_);
85   socket_->startAccepting();
86
87   int ret = pthread_create(&thread_, nullptr, Main, this);
88   assert(ret == 0);
89   (void)ret;
90
91   std::cerr << "Accepting connections on " << address_ << std::endl;
92 }
93
94 void getfds(int fds[2]) {
95   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
96     FAIL() << "failed to create socketpair: " << strerror(errno);
97   }
98   for (int idx = 0; idx < 2; ++idx) {
99     int flags = fcntl(fds[idx], F_GETFL, 0);
100     if (flags == -1) {
101       FAIL() << "failed to get flags for socket " << idx << ": "
102              << strerror(errno);
103     }
104     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
105       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
106              << strerror(errno);
107     }
108   }
109 }
110
111 void getctx(
112   std::shared_ptr<folly::SSLContext> clientCtx,
113   std::shared_ptr<folly::SSLContext> serverCtx) {
114   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
115
116   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
117   serverCtx->loadCertificate(
118       testCert);
119   serverCtx->loadPrivateKey(
120       testKey);
121 }
122
123 void sslsocketpair(
124   EventBase* eventBase,
125   AsyncSSLSocket::UniquePtr* clientSock,
126   AsyncSSLSocket::UniquePtr* serverSock) {
127   auto clientCtx = std::make_shared<folly::SSLContext>();
128   auto serverCtx = std::make_shared<folly::SSLContext>();
129   int fds[2];
130   getfds(fds);
131   getctx(clientCtx, serverCtx);
132   clientSock->reset(new AsyncSSLSocket(
133                       clientCtx, eventBase, fds[0], false));
134   serverSock->reset(new AsyncSSLSocket(
135                       serverCtx, eventBase, fds[1], true));
136
137   // (*clientSock)->setSendTimeout(100);
138   // (*serverSock)->setSendTimeout(100);
139 }
140
141 // client protocol filters
142 bool clientProtoFilterPickPony(unsigned char** client,
143   unsigned int* client_len, const unsigned char*, unsigned int ) {
144   //the protocol string in length prefixed byte string. the
145   //length byte is not included in the length
146   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
147   *client = p;
148   *client_len = 7;
149   return true;
150 }
151
152 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
153   const unsigned char*, unsigned int) {
154   return false;
155 }
156
157 std::string getFileAsBuf(const char* fileName) {
158   std::string buffer;
159   folly::readFile(fileName, buffer);
160   return buffer;
161 }
162
163 std::string getCommonName(X509* cert) {
164   X509_NAME* subject = X509_get_subject_name(cert);
165   std::string cn;
166   cn.resize(ub_common_name);
167   X509_NAME_get_text_by_NID(
168       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
169   return cn;
170 }
171
172 /**
173  * Test connecting to, writing to, reading from, and closing the
174  * connection to the SSL server.
175  */
176 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
177   // Start listening on a local port
178   WriteCallbackBase writeCallback;
179   ReadCallback readCallback(&writeCallback);
180   HandshakeCallback handshakeCallback(&readCallback);
181   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
182   TestSSLServer server(&acceptCallback);
183
184   // Set up SSL context.
185   std::shared_ptr<SSLContext> sslContext(new SSLContext());
186   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
187   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
188   //sslContext->authenticate(true, false);
189
190   // connect
191   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
192                                                  sslContext);
193   socket->open();
194
195   // write()
196   uint8_t buf[128];
197   memset(buf, 'a', sizeof(buf));
198   socket->write(buf, sizeof(buf));
199
200   // read()
201   uint8_t readbuf[128];
202   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
203   EXPECT_EQ(bytesRead, 128);
204   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
205
206   // close()
207   socket->close();
208
209   cerr << "ConnectWriteReadClose test completed" << endl;
210 }
211
212 /**
213  * Test reading after server close.
214  */
215 TEST(AsyncSSLSocketTest, ReadAfterClose) {
216   // Start listening on a local port
217   WriteCallbackBase writeCallback;
218   ReadEOFCallback readCallback(&writeCallback);
219   HandshakeCallback handshakeCallback(&readCallback);
220   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
221   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
222
223   // Set up SSL context.
224   auto sslContext = std::make_shared<SSLContext>();
225   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
226
227   auto socket =
228       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
229   socket->open();
230
231   // This should trigger an EOF on the client.
232   auto evb = handshakeCallback.getSocket()->getEventBase();
233   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
234   std::array<uint8_t, 128> readbuf;
235   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
236   EXPECT_EQ(0, bytesRead);
237 }
238
239 /**
240  * Test bad renegotiation
241  */
242 TEST(AsyncSSLSocketTest, Renegotiate) {
243   EventBase eventBase;
244   auto clientCtx = std::make_shared<SSLContext>();
245   auto dfServerCtx = std::make_shared<SSLContext>();
246   std::array<int, 2> fds;
247   getfds(fds.data());
248   getctx(clientCtx, dfServerCtx);
249
250   AsyncSSLSocket::UniquePtr clientSock(
251       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
252   AsyncSSLSocket::UniquePtr serverSock(
253       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
254   SSLHandshakeClient client(std::move(clientSock), true, true);
255   RenegotiatingServer server(std::move(serverSock));
256
257   while (!client.handshakeSuccess_ && !client.handshakeError_) {
258     eventBase.loopOnce();
259   }
260
261   ASSERT_TRUE(client.handshakeSuccess_);
262
263   auto sslSock = std::move(client).moveSocket();
264   sslSock->detachEventBase();
265   // This is nasty, however we don't want to add support for
266   // renegotiation in AsyncSSLSocket.
267   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
268
269   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
270
271   std::thread t([&]() { eventBase.loopForever(); });
272
273   // Trigger the renegotiation.
274   std::array<uint8_t, 128> buf;
275   memset(buf.data(), 'a', buf.size());
276   try {
277     socket->write(buf.data(), buf.size());
278   } catch (AsyncSocketException& e) {
279     LOG(INFO) << "client got error " << e.what();
280   }
281   eventBase.terminateLoopSoon();
282   t.join();
283
284   eventBase.loop();
285   ASSERT_TRUE(server.renegotiationError_);
286 }
287
288 /**
289  * Negative test for handshakeError().
290  */
291 TEST(AsyncSSLSocketTest, HandshakeError) {
292   // Start listening on a local port
293   WriteCallbackBase writeCallback;
294   WriteErrorCallback readCallback(&writeCallback);
295   HandshakeCallback handshakeCallback(&readCallback);
296   HandshakeErrorCallback acceptCallback(&handshakeCallback);
297   TestSSLServer server(&acceptCallback);
298
299   // Set up SSL context.
300   std::shared_ptr<SSLContext> sslContext(new SSLContext());
301   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
302
303   // connect
304   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
305                                                  sslContext);
306   // read()
307   bool ex = false;
308   try {
309     socket->open();
310
311     uint8_t readbuf[128];
312     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
313     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
314   } catch (AsyncSocketException &e) {
315     ex = true;
316   }
317   EXPECT_TRUE(ex);
318
319   // close()
320   socket->close();
321   cerr << "HandshakeError test completed" << endl;
322 }
323
324 /**
325  * Negative test for readError().
326  */
327 TEST(AsyncSSLSocketTest, ReadError) {
328   // Start listening on a local port
329   WriteCallbackBase writeCallback;
330   ReadErrorCallback readCallback(&writeCallback);
331   HandshakeCallback handshakeCallback(&readCallback);
332   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
333   TestSSLServer server(&acceptCallback);
334
335   // Set up SSL context.
336   std::shared_ptr<SSLContext> sslContext(new SSLContext());
337   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
338
339   // connect
340   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
341                                                  sslContext);
342   socket->open();
343
344   // write something to trigger ssl handshake
345   uint8_t buf[128];
346   memset(buf, 'a', sizeof(buf));
347   socket->write(buf, sizeof(buf));
348
349   socket->close();
350   cerr << "ReadError test completed" << endl;
351 }
352
353 /**
354  * Negative test for writeError().
355  */
356 TEST(AsyncSSLSocketTest, WriteError) {
357   // Start listening on a local port
358   WriteCallbackBase writeCallback;
359   WriteErrorCallback readCallback(&writeCallback);
360   HandshakeCallback handshakeCallback(&readCallback);
361   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
362   TestSSLServer server(&acceptCallback);
363
364   // Set up SSL context.
365   std::shared_ptr<SSLContext> sslContext(new SSLContext());
366   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
367
368   // connect
369   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
370                                                  sslContext);
371   socket->open();
372
373   // write something to trigger ssl handshake
374   uint8_t buf[128];
375   memset(buf, 'a', sizeof(buf));
376   socket->write(buf, sizeof(buf));
377
378   socket->close();
379   cerr << "WriteError test completed" << endl;
380 }
381
382 /**
383  * Test a socket with TCP_NODELAY unset.
384  */
385 TEST(AsyncSSLSocketTest, SocketWithDelay) {
386   // Start listening on a local port
387   WriteCallbackBase writeCallback;
388   ReadCallback readCallback(&writeCallback);
389   HandshakeCallback handshakeCallback(&readCallback);
390   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
391   TestSSLServer server(&acceptCallback);
392
393   // Set up SSL context.
394   std::shared_ptr<SSLContext> sslContext(new SSLContext());
395   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
396
397   // connect
398   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
399                                                  sslContext);
400   socket->open();
401
402   // write()
403   uint8_t buf[128];
404   memset(buf, 'a', sizeof(buf));
405   socket->write(buf, sizeof(buf));
406
407   // read()
408   uint8_t readbuf[128];
409   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
410   EXPECT_EQ(bytesRead, 128);
411   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
412
413   // close()
414   socket->close();
415
416   cerr << "SocketWithDelay test completed" << endl;
417 }
418
419 using NextProtocolTypePair =
420     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
421
422 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
423   // For matching protos
424  public:
425   void SetUp() override { getctx(clientCtx, serverCtx); }
426
427   void connect(bool unset = false) {
428     getfds(fds);
429
430     if (unset) {
431       // unsetting NPN for any of [client, server] is enough to make NPN not
432       // work
433       clientCtx->unsetNextProtocols();
434     }
435
436     AsyncSSLSocket::UniquePtr clientSock(
437       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
438     AsyncSSLSocket::UniquePtr serverSock(
439       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
440     client = folly::make_unique<NpnClient>(std::move(clientSock));
441     server = folly::make_unique<NpnServer>(std::move(serverSock));
442
443     eventBase.loop();
444   }
445
446   void expectProtocol(const std::string& proto) {
447     EXPECT_NE(client->nextProtoLength, 0);
448     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
449     EXPECT_EQ(
450         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
451         0);
452     string selected((const char*)client->nextProto, client->nextProtoLength);
453     EXPECT_EQ(proto, selected);
454   }
455
456   void expectNoProtocol() {
457     EXPECT_EQ(client->nextProtoLength, 0);
458     EXPECT_EQ(server->nextProtoLength, 0);
459     EXPECT_EQ(client->nextProto, nullptr);
460     EXPECT_EQ(server->nextProto, nullptr);
461   }
462
463   void expectProtocolType() {
464     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
465         GetParam().second == SSLContext::NextProtocolType::ANY) {
466       EXPECT_EQ(client->protocolType, server->protocolType);
467     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
468                GetParam().second == SSLContext::NextProtocolType::ANY) {
469       // Well not much we can say
470     } else {
471       expectProtocolType(GetParam());
472     }
473   }
474
475   void expectProtocolType(NextProtocolTypePair expected) {
476     EXPECT_EQ(client->protocolType, expected.first);
477     EXPECT_EQ(server->protocolType, expected.second);
478   }
479
480   EventBase eventBase;
481   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
482   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
483   int fds[2];
484   std::unique_ptr<NpnClient> client;
485   std::unique_ptr<NpnServer> server;
486 };
487
488 class NextProtocolNPNOnlyTest : public NextProtocolTest {
489   // For mismatching protos
490 };
491
492 class NextProtocolMismatchTest : public NextProtocolTest {
493   // For mismatching protos
494 };
495
496 TEST_P(NextProtocolTest, NpnTestOverlap) {
497   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
498   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
499                                         GetParam().second);
500
501   connect();
502
503   expectProtocol("baz");
504   expectProtocolType();
505 }
506
507 TEST_P(NextProtocolTest, NpnTestUnset) {
508   // Identical to above test, except that we want unset NPN before
509   // looping.
510   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
511   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
512                                         GetParam().second);
513
514   connect(true /* unset */);
515
516   // if alpn negotiation fails, type will appear as npn
517   expectNoProtocol();
518   EXPECT_EQ(client->protocolType, server->protocolType);
519 }
520
521 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
522   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
523   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
524                                         GetParam().second);
525
526   connect();
527
528   expectNoProtocol();
529   expectProtocolType(
530       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
531 }
532
533 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
534 // will fail on 1.0.2 before that.
535 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
536   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
537   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
538                                         GetParam().second);
539
540   connect();
541
542   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
543       GetParam().second == SSLContext::NextProtocolType::ALPN) {
544     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
545     // mismatch should result in a fatal alert, but this is OpenSSL's current
546     // behavior and we want to know if it changes.
547     expectNoProtocol();
548   } else {
549     expectProtocol("blub");
550     expectProtocolType(
551         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
552   }
553 }
554
555 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
556   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
557   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
558   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
559                                         GetParam().second);
560
561   connect();
562
563   expectProtocol("ponies");
564   expectProtocolType();
565 }
566
567 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
568   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
569   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
570   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
571                                         GetParam().second);
572
573   connect();
574
575   expectProtocol("blub");
576   expectProtocolType();
577 }
578
579 TEST_P(NextProtocolTest, RandomizedNpnTest) {
580   // Probability that this test will fail is 2^-64, which could be considered
581   // as negligible.
582   const int kTries = 64;
583
584   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
585                                         GetParam().first);
586   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
587                                                   GetParam().second);
588
589   std::set<string> selectedProtocols;
590   for (int i = 0; i < kTries; ++i) {
591     connect();
592
593     EXPECT_NE(client->nextProtoLength, 0);
594     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
595     EXPECT_EQ(
596         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
597         0);
598     string selected((const char*)client->nextProto, client->nextProtoLength);
599     selectedProtocols.insert(selected);
600     expectProtocolType();
601   }
602   EXPECT_EQ(selectedProtocols.size(), 2);
603 }
604
605 INSTANTIATE_TEST_CASE_P(
606     AsyncSSLSocketTest,
607     NextProtocolTest,
608     ::testing::Values(
609         NextProtocolTypePair(
610             SSLContext::NextProtocolType::NPN,
611             SSLContext::NextProtocolType::NPN),
612 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
613         NextProtocolTypePair(
614             SSLContext::NextProtocolType::ALPN,
615             SSLContext::NextProtocolType::ALPN),
616         NextProtocolTypePair(
617             SSLContext::NextProtocolType::ALPN,
618             SSLContext::NextProtocolType::ANY),
619         NextProtocolTypePair(
620             SSLContext::NextProtocolType::ANY,
621             SSLContext::NextProtocolType::ALPN),
622 #endif
623         NextProtocolTypePair(
624             SSLContext::NextProtocolType::NPN,
625             SSLContext::NextProtocolType::ANY),
626         NextProtocolTypePair(
627             SSLContext::NextProtocolType::ANY,
628             SSLContext::NextProtocolType::ANY)));
629
630 INSTANTIATE_TEST_CASE_P(
631     AsyncSSLSocketTest,
632     NextProtocolNPNOnlyTest,
633     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
634                                            SSLContext::NextProtocolType::NPN)));
635
636 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
637 INSTANTIATE_TEST_CASE_P(
638     AsyncSSLSocketTest,
639     NextProtocolMismatchTest,
640     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
641                                            SSLContext::NextProtocolType::ALPN),
642                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
643                                            SSLContext::NextProtocolType::NPN)));
644 #endif
645
646 #ifndef OPENSSL_NO_TLSEXT
647 /**
648  * 1. Client sends TLSEXT_HOSTNAME in client hello.
649  * 2. Server found a match SSL_CTX and use this SSL_CTX to
650  *    continue the SSL handshake.
651  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
652  */
653 TEST(AsyncSSLSocketTest, SNITestMatch) {
654   EventBase eventBase;
655   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
656   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
657   // Use the same SSLContext to continue the handshake after
658   // tlsext_hostname match.
659   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
660   const std::string serverName("xyz.newdev.facebook.com");
661   int fds[2];
662   getfds(fds);
663   getctx(clientCtx, dfServerCtx);
664
665   AsyncSSLSocket::UniquePtr clientSock(
666     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
667   AsyncSSLSocket::UniquePtr serverSock(
668     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
669   SNIClient client(std::move(clientSock));
670   SNIServer server(std::move(serverSock),
671                    dfServerCtx,
672                    hskServerCtx,
673                    serverName);
674
675   eventBase.loop();
676
677   EXPECT_TRUE(client.serverNameMatch);
678   EXPECT_TRUE(server.serverNameMatch);
679 }
680
681 /**
682  * 1. Client sends TLSEXT_HOSTNAME in client hello.
683  * 2. Server cannot find a matching SSL_CTX and continue to use
684  *    the current SSL_CTX to do the handshake.
685  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
686  */
687 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
688   EventBase eventBase;
689   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
690   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
691   // Use the same SSLContext to continue the handshake after
692   // tlsext_hostname match.
693   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
694   const std::string clientRequestingServerName("foo.com");
695   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
696
697   int fds[2];
698   getfds(fds);
699   getctx(clientCtx, dfServerCtx);
700
701   AsyncSSLSocket::UniquePtr clientSock(
702     new AsyncSSLSocket(clientCtx,
703                         &eventBase,
704                         fds[0],
705                         clientRequestingServerName));
706   AsyncSSLSocket::UniquePtr serverSock(
707     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
708   SNIClient client(std::move(clientSock));
709   SNIServer server(std::move(serverSock),
710                    dfServerCtx,
711                    hskServerCtx,
712                    serverExpectedServerName);
713
714   eventBase.loop();
715
716   EXPECT_TRUE(!client.serverNameMatch);
717   EXPECT_TRUE(!server.serverNameMatch);
718 }
719 /**
720  * 1. Client sends TLSEXT_HOSTNAME in client hello.
721  * 2. We then change the serverName.
722  * 3. We expect that we get 'false' as the result for serNameMatch.
723  */
724
725 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
726    EventBase eventBase;
727   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
728   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
729   // Use the same SSLContext to continue the handshake after
730   // tlsext_hostname match.
731   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
732   const std::string serverName("xyz.newdev.facebook.com");
733   int fds[2];
734   getfds(fds);
735   getctx(clientCtx, dfServerCtx);
736
737   AsyncSSLSocket::UniquePtr clientSock(
738     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
739   //Change the server name
740   std::string newName("new.com");
741   clientSock->setServerName(newName);
742   AsyncSSLSocket::UniquePtr serverSock(
743     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
744   SNIClient client(std::move(clientSock));
745   SNIServer server(std::move(serverSock),
746                    dfServerCtx,
747                    hskServerCtx,
748                    serverName);
749
750   eventBase.loop();
751
752   EXPECT_TRUE(!client.serverNameMatch);
753 }
754
755 /**
756  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
757  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
758  */
759 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
760   EventBase eventBase;
761   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
762   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
763   // Use the same SSLContext to continue the handshake after
764   // tlsext_hostname match.
765   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
766   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
767
768   int fds[2];
769   getfds(fds);
770   getctx(clientCtx, dfServerCtx);
771
772   AsyncSSLSocket::UniquePtr clientSock(
773     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
774   AsyncSSLSocket::UniquePtr serverSock(
775     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
776   SNIClient client(std::move(clientSock));
777   SNIServer server(std::move(serverSock),
778                    dfServerCtx,
779                    hskServerCtx,
780                    serverExpectedServerName);
781
782   eventBase.loop();
783
784   EXPECT_TRUE(!client.serverNameMatch);
785   EXPECT_TRUE(!server.serverNameMatch);
786 }
787
788 #endif
789 /**
790  * Test SSL client socket
791  */
792 TEST(AsyncSSLSocketTest, SSLClientTest) {
793   // Start listening on a local port
794   WriteCallbackBase writeCallback;
795   ReadCallback readCallback(&writeCallback);
796   HandshakeCallback handshakeCallback(&readCallback);
797   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
798   TestSSLServer server(&acceptCallback);
799
800   // Set up SSL client
801   EventBase eventBase;
802   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
803
804   client->connect();
805   EventBaseAborter eba(&eventBase, 3000);
806   eventBase.loop();
807
808   EXPECT_EQ(client->getMiss(), 1);
809   EXPECT_EQ(client->getHit(), 0);
810
811   cerr << "SSLClientTest test completed" << endl;
812 }
813
814
815 /**
816  * Test SSL client socket session re-use
817  */
818 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
819   // Start listening on a local port
820   WriteCallbackBase writeCallback;
821   ReadCallback readCallback(&writeCallback);
822   HandshakeCallback handshakeCallback(&readCallback);
823   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
824   TestSSLServer server(&acceptCallback);
825
826   // Set up SSL client
827   EventBase eventBase;
828   auto client =
829       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
830
831   client->connect();
832   EventBaseAborter eba(&eventBase, 3000);
833   eventBase.loop();
834
835   EXPECT_EQ(client->getMiss(), 1);
836   EXPECT_EQ(client->getHit(), 9);
837
838   cerr << "SSLClientTestReuse test completed" << endl;
839 }
840
841 /**
842  * Test SSL client socket timeout
843  */
844 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
845   // Start listening on a local port
846   EmptyReadCallback readCallback;
847   HandshakeCallback handshakeCallback(&readCallback,
848                                       HandshakeCallback::EXPECT_ERROR);
849   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
850   TestSSLServer server(&acceptCallback);
851
852   // Set up SSL client
853   EventBase eventBase;
854   auto client =
855       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
856   client->connect(true /* write before connect completes */);
857   EventBaseAborter eba(&eventBase, 3000);
858   eventBase.loop();
859
860   usleep(100000);
861   // This is checking that the connectError callback precedes any queued
862   // writeError callbacks.  This matches AsyncSocket's behavior
863   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
864   EXPECT_EQ(client->getErrors(), 1);
865   EXPECT_EQ(client->getMiss(), 0);
866   EXPECT_EQ(client->getHit(), 0);
867
868   cerr << "SSLClientTimeoutTest test completed" << endl;
869 }
870
871
872 /**
873  * Test SSL server async cache
874  */
875 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
876   // Start listening on a local port
877   WriteCallbackBase writeCallback;
878   ReadCallback readCallback(&writeCallback);
879   HandshakeCallback handshakeCallback(&readCallback);
880   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
881   TestSSLAsyncCacheServer server(&acceptCallback);
882
883   // Set up SSL client
884   EventBase eventBase;
885   auto client =
886       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
887
888   client->connect();
889   EventBaseAborter eba(&eventBase, 3000);
890   eventBase.loop();
891
892   EXPECT_EQ(server.getAsyncCallbacks(), 18);
893   EXPECT_EQ(server.getAsyncLookups(), 9);
894   EXPECT_EQ(client->getMiss(), 10);
895   EXPECT_EQ(client->getHit(), 0);
896
897   cerr << "SSLServerAsyncCacheTest test completed" << endl;
898 }
899
900
901 /**
902  * Test SSL server accept timeout with cache path
903  */
904 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
905   // Start listening on a local port
906   WriteCallbackBase writeCallback;
907   ReadCallback readCallback(&writeCallback);
908   EmptyReadCallback clientReadCallback;
909   HandshakeCallback handshakeCallback(&readCallback);
910   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
911   TestSSLAsyncCacheServer server(&acceptCallback);
912
913   // Set up SSL client
914   EventBase eventBase;
915   // only do a TCP connect
916   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
917   sock->connect(nullptr, server.getAddress());
918   clientReadCallback.tcpSocket_ = sock;
919   sock->setReadCB(&clientReadCallback);
920
921   EventBaseAborter eba(&eventBase, 3000);
922   eventBase.loop();
923
924   EXPECT_EQ(readCallback.state, STATE_WAITING);
925
926   cerr << "SSLServerTimeoutTest test completed" << endl;
927 }
928
929 /**
930  * Test SSL server accept timeout with cache path
931  */
932 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
933   // Start listening on a local port
934   WriteCallbackBase writeCallback;
935   ReadCallback readCallback(&writeCallback);
936   HandshakeCallback handshakeCallback(&readCallback);
937   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
938   TestSSLAsyncCacheServer server(&acceptCallback);
939
940   // Set up SSL client
941   EventBase eventBase;
942   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
943
944   client->connect();
945   EventBaseAborter eba(&eventBase, 3000);
946   eventBase.loop();
947
948   EXPECT_EQ(server.getAsyncCallbacks(), 1);
949   EXPECT_EQ(server.getAsyncLookups(), 1);
950   EXPECT_EQ(client->getErrors(), 1);
951   EXPECT_EQ(client->getMiss(), 1);
952   EXPECT_EQ(client->getHit(), 0);
953
954   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
955 }
956
957 /**
958  * Test SSL server accept timeout with cache path
959  */
960 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
961   // Start listening on a local port
962   WriteCallbackBase writeCallback;
963   ReadCallback readCallback(&writeCallback);
964   HandshakeCallback handshakeCallback(&readCallback,
965                                       HandshakeCallback::EXPECT_ERROR);
966   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
967   TestSSLAsyncCacheServer server(&acceptCallback, 500);
968
969   // Set up SSL client
970   EventBase eventBase;
971   auto client =
972       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
973
974   client->connect();
975   EventBaseAborter eba(&eventBase, 3000);
976   eventBase.loop();
977
978   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
979       handshakeCallback.closeSocket();});
980   // give time for the cache lookup to come back and find it closed
981   handshakeCallback.waitForHandshake();
982
983   EXPECT_EQ(server.getAsyncCallbacks(), 1);
984   EXPECT_EQ(server.getAsyncLookups(), 1);
985   EXPECT_EQ(client->getErrors(), 1);
986   EXPECT_EQ(client->getMiss(), 1);
987   EXPECT_EQ(client->getHit(), 0);
988
989   cerr << "SSLServerCacheCloseTest test completed" << endl;
990 }
991
992 /**
993  * Verify Client Ciphers obtained using SSL MSG Callback.
994  */
995 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
996   EventBase eventBase;
997   auto clientCtx = std::make_shared<SSLContext>();
998   auto serverCtx = std::make_shared<SSLContext>();
999   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1000   serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
1001   serverCtx->loadPrivateKey(testKey);
1002   serverCtx->loadCertificate(testCert);
1003   serverCtx->loadTrustedCertificates(testCA);
1004   serverCtx->loadClientCAList(testCA);
1005
1006   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1007   clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
1008   clientCtx->loadPrivateKey(testKey);
1009   clientCtx->loadCertificate(testCert);
1010   clientCtx->loadTrustedCertificates(testCA);
1011
1012   int fds[2];
1013   getfds(fds);
1014
1015   AsyncSSLSocket::UniquePtr clientSock(
1016       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1017   AsyncSSLSocket::UniquePtr serverSock(
1018       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1019
1020   SSLHandshakeClient client(std::move(clientSock), true, true);
1021   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1022
1023   eventBase.loop();
1024
1025   EXPECT_EQ(server.clientCiphers_,
1026             "RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
1027   EXPECT_TRUE(client.handshakeVerify_);
1028   EXPECT_TRUE(client.handshakeSuccess_);
1029   EXPECT_TRUE(!client.handshakeError_);
1030   EXPECT_TRUE(server.handshakeVerify_);
1031   EXPECT_TRUE(server.handshakeSuccess_);
1032   EXPECT_TRUE(!server.handshakeError_);
1033 }
1034
1035 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1036   EventBase eventBase;
1037   auto ctx = std::make_shared<SSLContext>();
1038
1039   int fds[2];
1040   getfds(fds);
1041
1042   int bufLen = 42;
1043   uint8_t majorVersion = 18;
1044   uint8_t minorVersion = 25;
1045
1046   // Create callback buf
1047   auto buf = IOBuf::create(bufLen);
1048   buf->append(bufLen);
1049   folly::io::RWPrivateCursor cursor(buf.get());
1050   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1051   cursor.write<uint16_t>(0);
1052   cursor.write<uint8_t>(38);
1053   cursor.write<uint8_t>(majorVersion);
1054   cursor.write<uint8_t>(minorVersion);
1055   cursor.skip(32);
1056   cursor.write<uint32_t>(0);
1057
1058   SSL* ssl = ctx->createSSL();
1059   SCOPE_EXIT { SSL_free(ssl); };
1060   AsyncSSLSocket::UniquePtr sock(
1061       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1062   sock->enableClientHelloParsing();
1063
1064   // Test client hello parsing in one packet
1065   AsyncSSLSocket::clientHelloParsingCallback(
1066       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1067   buf.reset();
1068
1069   auto parsedClientHello = sock->getClientHelloInfo();
1070   EXPECT_TRUE(parsedClientHello != nullptr);
1071   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1072   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1073 }
1074
1075 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1076   EventBase eventBase;
1077   auto ctx = std::make_shared<SSLContext>();
1078
1079   int fds[2];
1080   getfds(fds);
1081
1082   int bufLen = 42;
1083   uint8_t majorVersion = 18;
1084   uint8_t minorVersion = 25;
1085
1086   // Create callback buf
1087   auto buf = IOBuf::create(bufLen);
1088   buf->append(bufLen);
1089   folly::io::RWPrivateCursor cursor(buf.get());
1090   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1091   cursor.write<uint16_t>(0);
1092   cursor.write<uint8_t>(38);
1093   cursor.write<uint8_t>(majorVersion);
1094   cursor.write<uint8_t>(minorVersion);
1095   cursor.skip(32);
1096   cursor.write<uint32_t>(0);
1097
1098   SSL* ssl = ctx->createSSL();
1099   SCOPE_EXIT { SSL_free(ssl); };
1100   AsyncSSLSocket::UniquePtr sock(
1101       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1102   sock->enableClientHelloParsing();
1103
1104   // Test parsing with two packets with first packet size < 3
1105   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1106   AsyncSSLSocket::clientHelloParsingCallback(
1107       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1108       ssl, sock.get());
1109   bufCopy.reset();
1110   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1111   AsyncSSLSocket::clientHelloParsingCallback(
1112       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1113       ssl, sock.get());
1114   bufCopy.reset();
1115
1116   auto parsedClientHello = sock->getClientHelloInfo();
1117   EXPECT_TRUE(parsedClientHello != nullptr);
1118   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1119   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1120 }
1121
1122 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1123   EventBase eventBase;
1124   auto ctx = std::make_shared<SSLContext>();
1125
1126   int fds[2];
1127   getfds(fds);
1128
1129   int bufLen = 42;
1130   uint8_t majorVersion = 18;
1131   uint8_t minorVersion = 25;
1132
1133   // Create callback buf
1134   auto buf = IOBuf::create(bufLen);
1135   buf->append(bufLen);
1136   folly::io::RWPrivateCursor cursor(buf.get());
1137   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1138   cursor.write<uint16_t>(0);
1139   cursor.write<uint8_t>(38);
1140   cursor.write<uint8_t>(majorVersion);
1141   cursor.write<uint8_t>(minorVersion);
1142   cursor.skip(32);
1143   cursor.write<uint32_t>(0);
1144
1145   SSL* ssl = ctx->createSSL();
1146   SCOPE_EXIT { SSL_free(ssl); };
1147   AsyncSSLSocket::UniquePtr sock(
1148       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1149   sock->enableClientHelloParsing();
1150
1151   // Test parsing with multiple small packets
1152   for (uint64_t i = 0; i < buf->length(); i += 3) {
1153     auto bufCopy = folly::IOBuf::copyBuffer(
1154         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1155     AsyncSSLSocket::clientHelloParsingCallback(
1156         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1157         ssl, sock.get());
1158     bufCopy.reset();
1159   }
1160
1161   auto parsedClientHello = sock->getClientHelloInfo();
1162   EXPECT_TRUE(parsedClientHello != nullptr);
1163   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1164   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1165 }
1166
1167 /**
1168  * Verify sucessful behavior of SSL certificate validation.
1169  */
1170 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1171   EventBase eventBase;
1172   auto clientCtx = std::make_shared<SSLContext>();
1173   auto dfServerCtx = std::make_shared<SSLContext>();
1174
1175   int fds[2];
1176   getfds(fds);
1177   getctx(clientCtx, dfServerCtx);
1178
1179   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1180   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1181
1182   AsyncSSLSocket::UniquePtr clientSock(
1183     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1184   AsyncSSLSocket::UniquePtr serverSock(
1185     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1186
1187   SSLHandshakeClient client(std::move(clientSock), true, true);
1188   clientCtx->loadTrustedCertificates(testCA);
1189
1190   SSLHandshakeServer server(std::move(serverSock), true, true);
1191
1192   eventBase.loop();
1193
1194   EXPECT_TRUE(client.handshakeVerify_);
1195   EXPECT_TRUE(client.handshakeSuccess_);
1196   EXPECT_TRUE(!client.handshakeError_);
1197   EXPECT_LE(0, client.handshakeTime.count());
1198   EXPECT_TRUE(!server.handshakeVerify_);
1199   EXPECT_TRUE(server.handshakeSuccess_);
1200   EXPECT_TRUE(!server.handshakeError_);
1201   EXPECT_LE(0, server.handshakeTime.count());
1202 }
1203
1204 /**
1205  * Verify that the client's verification callback is able to fail SSL
1206  * connection establishment.
1207  */
1208 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1209   EventBase eventBase;
1210   auto clientCtx = std::make_shared<SSLContext>();
1211   auto dfServerCtx = std::make_shared<SSLContext>();
1212
1213   int fds[2];
1214   getfds(fds);
1215   getctx(clientCtx, dfServerCtx);
1216
1217   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1218   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1219
1220   AsyncSSLSocket::UniquePtr clientSock(
1221     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1222   AsyncSSLSocket::UniquePtr serverSock(
1223     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1224
1225   SSLHandshakeClient client(std::move(clientSock), true, false);
1226   clientCtx->loadTrustedCertificates(testCA);
1227
1228   SSLHandshakeServer server(std::move(serverSock), true, true);
1229
1230   eventBase.loop();
1231
1232   EXPECT_TRUE(client.handshakeVerify_);
1233   EXPECT_TRUE(!client.handshakeSuccess_);
1234   EXPECT_TRUE(client.handshakeError_);
1235   EXPECT_LE(0, client.handshakeTime.count());
1236   EXPECT_TRUE(!server.handshakeVerify_);
1237   EXPECT_TRUE(!server.handshakeSuccess_);
1238   EXPECT_TRUE(server.handshakeError_);
1239   EXPECT_LE(0, server.handshakeTime.count());
1240 }
1241
1242 /**
1243  * Verify that the options in SSLContext can be overridden in
1244  * sslConnect/Accept.i.e specifying that no validation should be performed
1245  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1246  * the validation callback.
1247  */
1248 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1249   EventBase eventBase;
1250   auto clientCtx = std::make_shared<SSLContext>();
1251   auto dfServerCtx = std::make_shared<SSLContext>();
1252
1253   int fds[2];
1254   getfds(fds);
1255   getctx(clientCtx, dfServerCtx);
1256
1257   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1258   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1259
1260   AsyncSSLSocket::UniquePtr clientSock(
1261     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1262   AsyncSSLSocket::UniquePtr serverSock(
1263     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1264
1265   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1266   clientCtx->loadTrustedCertificates(testCA);
1267
1268   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1269
1270   eventBase.loop();
1271
1272   EXPECT_TRUE(!client.handshakeVerify_);
1273   EXPECT_TRUE(client.handshakeSuccess_);
1274   EXPECT_TRUE(!client.handshakeError_);
1275   EXPECT_LE(0, client.handshakeTime.count());
1276   EXPECT_TRUE(!server.handshakeVerify_);
1277   EXPECT_TRUE(server.handshakeSuccess_);
1278   EXPECT_TRUE(!server.handshakeError_);
1279   EXPECT_LE(0, server.handshakeTime.count());
1280 }
1281
1282 /**
1283  * Verify that the options in SSLContext can be overridden in
1284  * sslConnect/Accept. Enable verification even if context says otherwise.
1285  * Test requireClientCert with client cert
1286  */
1287 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1288   EventBase eventBase;
1289   auto clientCtx = std::make_shared<SSLContext>();
1290   auto serverCtx = std::make_shared<SSLContext>();
1291   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1292   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1293   serverCtx->loadPrivateKey(testKey);
1294   serverCtx->loadCertificate(testCert);
1295   serverCtx->loadTrustedCertificates(testCA);
1296   serverCtx->loadClientCAList(testCA);
1297
1298   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1299   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1300   clientCtx->loadPrivateKey(testKey);
1301   clientCtx->loadCertificate(testCert);
1302   clientCtx->loadTrustedCertificates(testCA);
1303
1304   int fds[2];
1305   getfds(fds);
1306
1307   AsyncSSLSocket::UniquePtr clientSock(
1308       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1309   AsyncSSLSocket::UniquePtr serverSock(
1310       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1311
1312   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1313   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1314
1315   eventBase.loop();
1316
1317   EXPECT_TRUE(client.handshakeVerify_);
1318   EXPECT_TRUE(client.handshakeSuccess_);
1319   EXPECT_FALSE(client.handshakeError_);
1320   EXPECT_LE(0, client.handshakeTime.count());
1321   EXPECT_TRUE(server.handshakeVerify_);
1322   EXPECT_TRUE(server.handshakeSuccess_);
1323   EXPECT_FALSE(server.handshakeError_);
1324   EXPECT_LE(0, server.handshakeTime.count());
1325 }
1326
1327 /**
1328  * Verify that the client's verification callback is able to override
1329  * the preverification failure and allow a successful connection.
1330  */
1331 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1332   EventBase eventBase;
1333   auto clientCtx = std::make_shared<SSLContext>();
1334   auto dfServerCtx = std::make_shared<SSLContext>();
1335
1336   int fds[2];
1337   getfds(fds);
1338   getctx(clientCtx, dfServerCtx);
1339
1340   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1341   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1342
1343   AsyncSSLSocket::UniquePtr clientSock(
1344     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1345   AsyncSSLSocket::UniquePtr serverSock(
1346     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1347
1348   SSLHandshakeClient client(std::move(clientSock), false, true);
1349   SSLHandshakeServer server(std::move(serverSock), true, true);
1350
1351   eventBase.loop();
1352
1353   EXPECT_TRUE(client.handshakeVerify_);
1354   EXPECT_TRUE(client.handshakeSuccess_);
1355   EXPECT_TRUE(!client.handshakeError_);
1356   EXPECT_LE(0, client.handshakeTime.count());
1357   EXPECT_TRUE(!server.handshakeVerify_);
1358   EXPECT_TRUE(server.handshakeSuccess_);
1359   EXPECT_TRUE(!server.handshakeError_);
1360   EXPECT_LE(0, server.handshakeTime.count());
1361 }
1362
1363 /**
1364  * Verify that specifying that no validation should be performed allows an
1365  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1366  * callback.
1367  */
1368 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1369   EventBase eventBase;
1370   auto clientCtx = std::make_shared<SSLContext>();
1371   auto dfServerCtx = std::make_shared<SSLContext>();
1372
1373   int fds[2];
1374   getfds(fds);
1375   getctx(clientCtx, dfServerCtx);
1376
1377   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1378   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1379
1380   AsyncSSLSocket::UniquePtr clientSock(
1381     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1382   AsyncSSLSocket::UniquePtr serverSock(
1383     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1384
1385   SSLHandshakeClient client(std::move(clientSock), false, false);
1386   SSLHandshakeServer server(std::move(serverSock), false, false);
1387
1388   eventBase.loop();
1389
1390   EXPECT_TRUE(!client.handshakeVerify_);
1391   EXPECT_TRUE(client.handshakeSuccess_);
1392   EXPECT_TRUE(!client.handshakeError_);
1393   EXPECT_LE(0, client.handshakeTime.count());
1394   EXPECT_TRUE(!server.handshakeVerify_);
1395   EXPECT_TRUE(server.handshakeSuccess_);
1396   EXPECT_TRUE(!server.handshakeError_);
1397   EXPECT_LE(0, server.handshakeTime.count());
1398 }
1399
1400 /**
1401  * Test requireClientCert with client cert
1402  */
1403 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1404   EventBase eventBase;
1405   auto clientCtx = std::make_shared<SSLContext>();
1406   auto serverCtx = std::make_shared<SSLContext>();
1407   serverCtx->setVerificationOption(
1408       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1409   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1410   serverCtx->loadPrivateKey(testKey);
1411   serverCtx->loadCertificate(testCert);
1412   serverCtx->loadTrustedCertificates(testCA);
1413   serverCtx->loadClientCAList(testCA);
1414
1415   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1416   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1417   clientCtx->loadPrivateKey(testKey);
1418   clientCtx->loadCertificate(testCert);
1419   clientCtx->loadTrustedCertificates(testCA);
1420
1421   int fds[2];
1422   getfds(fds);
1423
1424   AsyncSSLSocket::UniquePtr clientSock(
1425       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1426   AsyncSSLSocket::UniquePtr serverSock(
1427       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1428
1429   SSLHandshakeClient client(std::move(clientSock), true, true);
1430   SSLHandshakeServer server(std::move(serverSock), true, true);
1431
1432   eventBase.loop();
1433
1434   EXPECT_TRUE(client.handshakeVerify_);
1435   EXPECT_TRUE(client.handshakeSuccess_);
1436   EXPECT_FALSE(client.handshakeError_);
1437   EXPECT_LE(0, client.handshakeTime.count());
1438   EXPECT_TRUE(server.handshakeVerify_);
1439   EXPECT_TRUE(server.handshakeSuccess_);
1440   EXPECT_FALSE(server.handshakeError_);
1441   EXPECT_LE(0, server.handshakeTime.count());
1442 }
1443
1444
1445 /**
1446  * Test requireClientCert with no client cert
1447  */
1448 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1449   EventBase eventBase;
1450   auto clientCtx = std::make_shared<SSLContext>();
1451   auto serverCtx = std::make_shared<SSLContext>();
1452   serverCtx->setVerificationOption(
1453       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1454   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1455   serverCtx->loadPrivateKey(testKey);
1456   serverCtx->loadCertificate(testCert);
1457   serverCtx->loadTrustedCertificates(testCA);
1458   serverCtx->loadClientCAList(testCA);
1459   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1460   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1461
1462   int fds[2];
1463   getfds(fds);
1464
1465   AsyncSSLSocket::UniquePtr clientSock(
1466       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1467   AsyncSSLSocket::UniquePtr serverSock(
1468       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1469
1470   SSLHandshakeClient client(std::move(clientSock), false, false);
1471   SSLHandshakeServer server(std::move(serverSock), false, false);
1472
1473   eventBase.loop();
1474
1475   EXPECT_FALSE(server.handshakeVerify_);
1476   EXPECT_FALSE(server.handshakeSuccess_);
1477   EXPECT_TRUE(server.handshakeError_);
1478   EXPECT_LE(0, client.handshakeTime.count());
1479   EXPECT_LE(0, server.handshakeTime.count());
1480 }
1481
1482 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1483   auto cert = getFileAsBuf(testCert);
1484   auto key = getFileAsBuf(testKey);
1485
1486   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1487   BIO_write(certBio.get(), cert.data(), cert.size());
1488   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1489   BIO_write(keyBio.get(), key.data(), key.size());
1490
1491   // Create SSL structs from buffers to get properties
1492   ssl::X509UniquePtr certStruct(
1493       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1494   ssl::EvpPkeyUniquePtr keyStruct(
1495       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1496   certBio = nullptr;
1497   keyBio = nullptr;
1498
1499   auto origCommonName = getCommonName(certStruct.get());
1500   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1501   certStruct = nullptr;
1502   keyStruct = nullptr;
1503
1504   auto ctx = std::make_shared<SSLContext>();
1505   ctx->loadPrivateKeyFromBufferPEM(key);
1506   ctx->loadCertificateFromBufferPEM(cert);
1507   ctx->loadTrustedCertificates(testCA);
1508
1509   ssl::SSLUniquePtr ssl(ctx->createSSL());
1510
1511   auto newCert = SSL_get_certificate(ssl.get());
1512   auto newKey = SSL_get_privatekey(ssl.get());
1513
1514   // Get properties from SSL struct
1515   auto newCommonName = getCommonName(newCert);
1516   auto newKeySize = EVP_PKEY_bits(newKey);
1517
1518   // Check that the key and cert have the expected properties
1519   EXPECT_EQ(origCommonName, newCommonName);
1520   EXPECT_EQ(origKeySize, newKeySize);
1521 }
1522
1523 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1524   EventBase eb;
1525
1526   // Set up SSL context.
1527   auto sslContext = std::make_shared<SSLContext>();
1528   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1529
1530   // create SSL socket
1531   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1532
1533   EXPECT_EQ(1500, socket->getMinWriteSize());
1534
1535   socket->setMinWriteSize(0);
1536   EXPECT_EQ(0, socket->getMinWriteSize());
1537   socket->setMinWriteSize(50000);
1538   EXPECT_EQ(50000, socket->getMinWriteSize());
1539 }
1540
1541 class ReadCallbackTerminator : public ReadCallback {
1542  public:
1543   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1544       : ReadCallback(wcb)
1545       , base_(base) {}
1546
1547   // Do not write data back, terminate the loop.
1548   void readDataAvailable(size_t len) noexcept override {
1549     std::cerr << "readDataAvailable, len " << len << std::endl;
1550
1551     currentBuffer.length = len;
1552
1553     buffers.push_back(currentBuffer);
1554     currentBuffer.reset();
1555     state = STATE_SUCCEEDED;
1556
1557     socket_->setReadCB(nullptr);
1558     base_->terminateLoopSoon();
1559   }
1560  private:
1561   EventBase* base_;
1562 };
1563
1564
1565 /**
1566  * Test a full unencrypted codepath
1567  */
1568 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1569   EventBase base;
1570
1571   auto clientCtx = std::make_shared<folly::SSLContext>();
1572   auto serverCtx = std::make_shared<folly::SSLContext>();
1573   int fds[2];
1574   getfds(fds);
1575   getctx(clientCtx, serverCtx);
1576   auto client = AsyncSSLSocket::newSocket(
1577                   clientCtx, &base, fds[0], false, true);
1578   auto server = AsyncSSLSocket::newSocket(
1579                   serverCtx, &base, fds[1], true, true);
1580
1581   ReadCallbackTerminator readCallback(&base, nullptr);
1582   server->setReadCB(&readCallback);
1583   readCallback.setSocket(server);
1584
1585   uint8_t buf[128];
1586   memset(buf, 'a', sizeof(buf));
1587   client->write(nullptr, buf, sizeof(buf));
1588
1589   // Check that bytes are unencrypted
1590   char c;
1591   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1592   EXPECT_EQ('a', c);
1593
1594   EventBaseAborter eba(&base, 3000);
1595   base.loop();
1596
1597   EXPECT_EQ(1, readCallback.buffers.size());
1598   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1599
1600   server->setReadCB(&readCallback);
1601
1602   // Unencrypted
1603   server->sslAccept(nullptr);
1604   client->sslConn(nullptr);
1605
1606   // Do NOT wait for handshake, writing should be queued and happen after
1607
1608   client->write(nullptr, buf, sizeof(buf));
1609
1610   // Check that bytes are *not* unencrypted
1611   char c2;
1612   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1613   EXPECT_NE('a', c2);
1614
1615
1616   base.loop();
1617
1618   EXPECT_EQ(2, readCallback.buffers.size());
1619   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1620 }
1621
1622 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1623   // Start listening on a local port
1624   WriteCallbackBase writeCallback;
1625   WriteErrorCallback readCallback(&writeCallback);
1626   HandshakeCallback handshakeCallback(&readCallback,
1627                                       HandshakeCallback::EXPECT_ERROR);
1628   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1629   TestSSLServer server(&acceptCallback);
1630
1631   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1632   socket->open();
1633   uint8_t buf[3] = {0x16, 0x03, 0x01};
1634   socket->write(buf, sizeof(buf));
1635   socket->closeWithReset();
1636
1637   handshakeCallback.waitForHandshake();
1638   EXPECT_NE(
1639       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1640   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1641 }
1642
1643 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1644   // Start listening on a local port
1645   WriteCallbackBase writeCallback;
1646   WriteErrorCallback readCallback(&writeCallback);
1647   HandshakeCallback handshakeCallback(&readCallback,
1648                                       HandshakeCallback::EXPECT_ERROR);
1649   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1650   TestSSLServer server(&acceptCallback);
1651
1652   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1653   socket->open();
1654   uint8_t buf[3] = {0x16, 0x03, 0x01};
1655   socket->write(buf, sizeof(buf));
1656   socket->close();
1657
1658   handshakeCallback.waitForHandshake();
1659   EXPECT_NE(
1660       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1661   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1662 }
1663
1664 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1665   // Start listening on a local port
1666   WriteCallbackBase writeCallback;
1667   WriteErrorCallback readCallback(&writeCallback);
1668   HandshakeCallback handshakeCallback(&readCallback,
1669                                       HandshakeCallback::EXPECT_ERROR);
1670   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1671   TestSSLServer server(&acceptCallback);
1672
1673   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1674   socket->open();
1675   uint8_t buf[256] = {0x16, 0x03};
1676   memset(buf + 2, 'a', sizeof(buf) - 2);
1677   socket->write(buf, sizeof(buf));
1678   socket->close();
1679
1680   handshakeCallback.waitForHandshake();
1681   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1682             std::string::npos);
1683   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1684             std::string::npos);
1685 }
1686
1687 #if FOLLY_ALLOW_TFO
1688
1689 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1690  public:
1691   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1692
1693   explicit MockAsyncTFOSSLSocket(
1694       std::shared_ptr<folly::SSLContext> sslCtx,
1695       EventBase* evb)
1696       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1697
1698   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1699 };
1700
1701 /**
1702  * Test connecting to, writing to, reading from, and closing the
1703  * connection to the SSL server with TFO.
1704  */
1705 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1706   // Start listening on a local port
1707   WriteCallbackBase writeCallback;
1708   ReadCallback readCallback(&writeCallback);
1709   HandshakeCallback handshakeCallback(&readCallback);
1710   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1711   TestSSLServer server(&acceptCallback, true);
1712
1713   // Set up SSL context.
1714   auto sslContext = std::make_shared<SSLContext>();
1715
1716   // connect
1717   auto socket =
1718       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1719   socket->enableTFO();
1720   socket->open();
1721
1722   // write()
1723   std::array<uint8_t, 128> buf;
1724   memset(buf.data(), 'a', buf.size());
1725   socket->write(buf.data(), buf.size());
1726
1727   // read()
1728   std::array<uint8_t, 128> readbuf;
1729   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1730   EXPECT_EQ(bytesRead, 128);
1731   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1732
1733   // close()
1734   socket->close();
1735 }
1736
1737 /**
1738  * Test connecting to, writing to, reading from, and closing the
1739  * connection to the SSL server with TFO.
1740  */
1741 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1742   // Start listening on a local port
1743   WriteCallbackBase writeCallback;
1744   ReadCallback readCallback(&writeCallback);
1745   HandshakeCallback handshakeCallback(&readCallback);
1746   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1747   TestSSLServer server(&acceptCallback, false);
1748
1749   // Set up SSL context.
1750   auto sslContext = std::make_shared<SSLContext>();
1751
1752   // connect
1753   auto socket =
1754       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1755   socket->enableTFO();
1756   socket->open();
1757
1758   // write()
1759   std::array<uint8_t, 128> buf;
1760   memset(buf.data(), 'a', buf.size());
1761   socket->write(buf.data(), buf.size());
1762
1763   // read()
1764   std::array<uint8_t, 128> readbuf;
1765   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1766   EXPECT_EQ(bytesRead, 128);
1767   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1768
1769   // close()
1770   socket->close();
1771 }
1772
1773 class ConnCallback : public AsyncSocket::ConnectCallback {
1774  public:
1775   virtual void connectSuccess() noexcept override {
1776     state = State::SUCCESS;
1777   }
1778
1779   virtual void connectErr(const AsyncSocketException&) noexcept override {
1780     state = State::ERROR;
1781   }
1782
1783   enum class State { WAITING, SUCCESS, ERROR };
1784
1785   State state{State::WAITING};
1786 };
1787
1788 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1789     EventBase* evb,
1790     const SocketAddress& address) {
1791   // Set up SSL context.
1792   auto sslContext = std::make_shared<SSLContext>();
1793
1794   // connect
1795   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1796       new MockAsyncTFOSSLSocket(sslContext, evb));
1797   socket->enableTFO();
1798
1799   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1800       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1801         sockaddr_storage addr;
1802         auto len = address.getAddress(&addr);
1803         return connect(fd, (const struct sockaddr*)&addr, len);
1804       }));
1805   return socket;
1806 }
1807
1808 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1809   // Start listening on a local port
1810   WriteCallbackBase writeCallback;
1811   ReadCallback readCallback(&writeCallback);
1812   HandshakeCallback handshakeCallback(&readCallback);
1813   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1814   TestSSLServer server(&acceptCallback, true);
1815
1816   EventBase evb;
1817
1818   auto socket = setupSocketWithFallback(&evb, server.getAddress());
1819   ConnCallback ccb;
1820   socket->connect(&ccb, server.getAddress(), 30);
1821
1822   evb.loop();
1823   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1824
1825   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1826   evb.loop();
1827
1828   BlockingSocket sock(std::move(socket));
1829   // write()
1830   std::array<uint8_t, 128> buf;
1831   memset(buf.data(), 'a', buf.size());
1832   sock.write(buf.data(), buf.size());
1833
1834   // read()
1835   std::array<uint8_t, 128> readbuf;
1836   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1837   EXPECT_EQ(bytesRead, 128);
1838   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1839
1840   // close()
1841   sock.close();
1842 }
1843
1844 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1845   // Start listening on a local port
1846   WriteCallbackBase writeCallback;
1847   ReadErrorCallback readCallback(&writeCallback);
1848   HandshakeCallback handshakeCallback(&readCallback);
1849   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1850   TestSSLServer server(&acceptCallback, true);
1851
1852   // Set up SSL context.
1853   auto sslContext = std::make_shared<SSLContext>();
1854
1855   // connect
1856   auto socket =
1857       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1858   socket->enableTFO();
1859   EXPECT_THROW(
1860       socket->open(std::chrono::milliseconds(1)), AsyncSocketException);
1861 }
1862
1863 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1864   // Start listening on a local port
1865   WriteCallbackBase writeCallback;
1866   ReadErrorCallback readCallback(&writeCallback);
1867   HandshakeCallback handshakeCallback(&readCallback);
1868   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1869   TestSSLServer server(&acceptCallback, true);
1870
1871   EventBase evb;
1872
1873   auto socket = setupSocketWithFallback(&evb, server.getAddress());
1874   ConnCallback ccb;
1875   // Set a short timeout
1876   socket->connect(&ccb, server.getAddress(), 1);
1877
1878   evb.loop();
1879   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1880 }
1881
1882 #endif
1883
1884 } // namespace
1885
1886 ///////////////////////////////////////////////////////////////////////////
1887 // init_unit_test_suite
1888 ///////////////////////////////////////////////////////////////////////////
1889 namespace {
1890 struct Initializer {
1891   Initializer() {
1892     signal(SIGPIPE, SIG_IGN);
1893   }
1894 };
1895 Initializer initializer;
1896 } // anonymous