Implementing a callback interface for folly::AsyncSocket allowing to supply an ancill...
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <signal.h>
19
20 #include <folly/SocketAddress.h>
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/portability/GMock.h>
24 #include <folly/portability/GTest.h>
25 #include <folly/portability/OpenSSL.h>
26 #include <folly/portability/Sockets.h>
27 #include <folly/portability/Unistd.h>
28
29 #include <folly/io/async/test/BlockingSocket.h>
30
31 #include <fcntl.h>
32 #include <folly/io/Cursor.h>
33 #include <openssl/bio.h>
34 #include <sys/types.h>
35 #include <sys/utsname.h>
36 #include <fstream>
37 #include <iostream>
38 #include <list>
39 #include <set>
40 #include <thread>
41
42 using std::string;
43 using std::vector;
44 using std::min;
45 using std::cerr;
46 using std::endl;
47 using std::list;
48
49 using namespace testing;
50
51 namespace folly {
52 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
53 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
54 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
55
56 constexpr size_t SSLClient::kMaxReadBufferSz;
57 constexpr size_t SSLClient::kMaxReadsPerEvent;
58
59 void getfds(int fds[2]) {
60   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
61     FAIL() << "failed to create socketpair: " << strerror(errno);
62   }
63   for (int idx = 0; idx < 2; ++idx) {
64     int flags = fcntl(fds[idx], F_GETFL, 0);
65     if (flags == -1) {
66       FAIL() << "failed to get flags for socket " << idx << ": "
67              << strerror(errno);
68     }
69     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
70       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
71              << strerror(errno);
72     }
73   }
74 }
75
76 void getctx(
77   std::shared_ptr<folly::SSLContext> clientCtx,
78   std::shared_ptr<folly::SSLContext> serverCtx) {
79   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
80
81   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
82   serverCtx->loadCertificate(kTestCert);
83   serverCtx->loadPrivateKey(kTestKey);
84 }
85
86 void sslsocketpair(
87   EventBase* eventBase,
88   AsyncSSLSocket::UniquePtr* clientSock,
89   AsyncSSLSocket::UniquePtr* serverSock) {
90   auto clientCtx = std::make_shared<folly::SSLContext>();
91   auto serverCtx = std::make_shared<folly::SSLContext>();
92   int fds[2];
93   getfds(fds);
94   getctx(clientCtx, serverCtx);
95   clientSock->reset(new AsyncSSLSocket(
96                       clientCtx, eventBase, fds[0], false));
97   serverSock->reset(new AsyncSSLSocket(
98                       serverCtx, eventBase, fds[1], true));
99
100   // (*clientSock)->setSendTimeout(100);
101   // (*serverSock)->setSendTimeout(100);
102 }
103
104 // client protocol filters
105 bool clientProtoFilterPickPony(unsigned char** client,
106   unsigned int* client_len, const unsigned char*, unsigned int ) {
107   //the protocol string in length prefixed byte string. the
108   //length byte is not included in the length
109   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
110   *client = p;
111   *client_len = 7;
112   return true;
113 }
114
115 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
116   const unsigned char*, unsigned int) {
117   return false;
118 }
119
120 std::string getFileAsBuf(const char* fileName) {
121   std::string buffer;
122   folly::readFile(fileName, buffer);
123   return buffer;
124 }
125
126 std::string getCommonName(X509* cert) {
127   X509_NAME* subject = X509_get_subject_name(cert);
128   std::string cn;
129   cn.resize(ub_common_name);
130   X509_NAME_get_text_by_NID(
131       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
132   return cn;
133 }
134
135 /**
136  * Test connecting to, writing to, reading from, and closing the
137  * connection to the SSL server.
138  */
139 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
140   // Start listening on a local port
141   WriteCallbackBase writeCallback;
142   ReadCallback readCallback(&writeCallback);
143   HandshakeCallback handshakeCallback(&readCallback);
144   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
145   TestSSLServer server(&acceptCallback);
146
147   // Set up SSL context.
148   std::shared_ptr<SSLContext> sslContext(new SSLContext());
149   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
150   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
151   //sslContext->authenticate(true, false);
152
153   // connect
154   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
155                                                  sslContext);
156   socket->open(std::chrono::milliseconds(10000));
157
158   // write()
159   uint8_t buf[128];
160   memset(buf, 'a', sizeof(buf));
161   socket->write(buf, sizeof(buf));
162
163   // read()
164   uint8_t readbuf[128];
165   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
166   EXPECT_EQ(bytesRead, 128);
167   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
168
169   // close()
170   socket->close();
171
172   cerr << "ConnectWriteReadClose test completed" << endl;
173   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
174 }
175
176 /**
177  * Test reading after server close.
178  */
179 TEST(AsyncSSLSocketTest, ReadAfterClose) {
180   // Start listening on a local port
181   WriteCallbackBase writeCallback;
182   ReadEOFCallback readCallback(&writeCallback);
183   HandshakeCallback handshakeCallback(&readCallback);
184   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
185   auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
186
187   // Set up SSL context.
188   auto sslContext = std::make_shared<SSLContext>();
189   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
190
191   auto socket =
192       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
193   socket->open();
194
195   // This should trigger an EOF on the client.
196   auto evb = handshakeCallback.getSocket()->getEventBase();
197   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
198   std::array<uint8_t, 128> readbuf;
199   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
200   EXPECT_EQ(0, bytesRead);
201 }
202
203 /**
204  * Test bad renegotiation
205  */
206 #if !defined(OPENSSL_IS_BORINGSSL)
207 TEST(AsyncSSLSocketTest, Renegotiate) {
208   EventBase eventBase;
209   auto clientCtx = std::make_shared<SSLContext>();
210   auto dfServerCtx = std::make_shared<SSLContext>();
211   std::array<int, 2> fds;
212   getfds(fds.data());
213   getctx(clientCtx, dfServerCtx);
214
215   AsyncSSLSocket::UniquePtr clientSock(
216       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
217   AsyncSSLSocket::UniquePtr serverSock(
218       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
219   SSLHandshakeClient client(std::move(clientSock), true, true);
220   RenegotiatingServer server(std::move(serverSock));
221
222   while (!client.handshakeSuccess_ && !client.handshakeError_) {
223     eventBase.loopOnce();
224   }
225
226   ASSERT_TRUE(client.handshakeSuccess_);
227
228   auto sslSock = std::move(client).moveSocket();
229   sslSock->detachEventBase();
230   // This is nasty, however we don't want to add support for
231   // renegotiation in AsyncSSLSocket.
232   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
233
234   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
235
236   std::thread t([&]() { eventBase.loopForever(); });
237
238   // Trigger the renegotiation.
239   std::array<uint8_t, 128> buf;
240   memset(buf.data(), 'a', buf.size());
241   try {
242     socket->write(buf.data(), buf.size());
243   } catch (AsyncSocketException& e) {
244     LOG(INFO) << "client got error " << e.what();
245   }
246   eventBase.terminateLoopSoon();
247   t.join();
248
249   eventBase.loop();
250   ASSERT_TRUE(server.renegotiationError_);
251 }
252 #endif
253
254 /**
255  * Negative test for handshakeError().
256  */
257 TEST(AsyncSSLSocketTest, HandshakeError) {
258   // Start listening on a local port
259   WriteCallbackBase writeCallback;
260   WriteErrorCallback readCallback(&writeCallback);
261   HandshakeCallback handshakeCallback(&readCallback);
262   HandshakeErrorCallback acceptCallback(&handshakeCallback);
263   TestSSLServer server(&acceptCallback);
264
265   // Set up SSL context.
266   std::shared_ptr<SSLContext> sslContext(new SSLContext());
267   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
268
269   // connect
270   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
271                                                  sslContext);
272   // read()
273   bool ex = false;
274   try {
275     socket->open();
276
277     uint8_t readbuf[128];
278     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
279     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
280   } catch (AsyncSocketException&) {
281     ex = true;
282   }
283   EXPECT_TRUE(ex);
284
285   // close()
286   socket->close();
287   cerr << "HandshakeError test completed" << endl;
288 }
289
290 /**
291  * Negative test for readError().
292  */
293 TEST(AsyncSSLSocketTest, ReadError) {
294   // Start listening on a local port
295   WriteCallbackBase writeCallback;
296   ReadErrorCallback readCallback(&writeCallback);
297   HandshakeCallback handshakeCallback(&readCallback);
298   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
299   TestSSLServer server(&acceptCallback);
300
301   // Set up SSL context.
302   std::shared_ptr<SSLContext> sslContext(new SSLContext());
303   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
304
305   // connect
306   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
307                                                  sslContext);
308   socket->open();
309
310   // write something to trigger ssl handshake
311   uint8_t buf[128];
312   memset(buf, 'a', sizeof(buf));
313   socket->write(buf, sizeof(buf));
314
315   socket->close();
316   cerr << "ReadError test completed" << endl;
317 }
318
319 /**
320  * Negative test for writeError().
321  */
322 TEST(AsyncSSLSocketTest, WriteError) {
323   // Start listening on a local port
324   WriteCallbackBase writeCallback;
325   WriteErrorCallback readCallback(&writeCallback);
326   HandshakeCallback handshakeCallback(&readCallback);
327   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
328   TestSSLServer server(&acceptCallback);
329
330   // Set up SSL context.
331   std::shared_ptr<SSLContext> sslContext(new SSLContext());
332   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
333
334   // connect
335   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
336                                                  sslContext);
337   socket->open();
338
339   // write something to trigger ssl handshake
340   uint8_t buf[128];
341   memset(buf, 'a', sizeof(buf));
342   socket->write(buf, sizeof(buf));
343
344   socket->close();
345   cerr << "WriteError test completed" << endl;
346 }
347
348 /**
349  * Test a socket with TCP_NODELAY unset.
350  */
351 TEST(AsyncSSLSocketTest, SocketWithDelay) {
352   // Start listening on a local port
353   WriteCallbackBase writeCallback;
354   ReadCallback readCallback(&writeCallback);
355   HandshakeCallback handshakeCallback(&readCallback);
356   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
357   TestSSLServer server(&acceptCallback);
358
359   // Set up SSL context.
360   std::shared_ptr<SSLContext> sslContext(new SSLContext());
361   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
362
363   // connect
364   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
365                                                  sslContext);
366   socket->open();
367
368   // write()
369   uint8_t buf[128];
370   memset(buf, 'a', sizeof(buf));
371   socket->write(buf, sizeof(buf));
372
373   // read()
374   uint8_t readbuf[128];
375   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
376   EXPECT_EQ(bytesRead, 128);
377   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
378
379   // close()
380   socket->close();
381
382   cerr << "SocketWithDelay test completed" << endl;
383 }
384
385 using NextProtocolTypePair =
386     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
387
388 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
389   // For matching protos
390  public:
391   void SetUp() override { getctx(clientCtx, serverCtx); }
392
393   void connect(bool unset = false) {
394     getfds(fds);
395
396     if (unset) {
397       // unsetting NPN for any of [client, server] is enough to make NPN not
398       // work
399       clientCtx->unsetNextProtocols();
400     }
401
402     AsyncSSLSocket::UniquePtr clientSock(
403       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
404     AsyncSSLSocket::UniquePtr serverSock(
405       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
406     client = folly::make_unique<NpnClient>(std::move(clientSock));
407     server = folly::make_unique<NpnServer>(std::move(serverSock));
408
409     eventBase.loop();
410   }
411
412   void expectProtocol(const std::string& proto) {
413     EXPECT_NE(client->nextProtoLength, 0);
414     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
415     EXPECT_EQ(
416         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
417         0);
418     string selected((const char*)client->nextProto, client->nextProtoLength);
419     EXPECT_EQ(proto, selected);
420   }
421
422   void expectNoProtocol() {
423     EXPECT_EQ(client->nextProtoLength, 0);
424     EXPECT_EQ(server->nextProtoLength, 0);
425     EXPECT_EQ(client->nextProto, nullptr);
426     EXPECT_EQ(server->nextProto, nullptr);
427   }
428
429   void expectProtocolType() {
430     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
431         GetParam().second == SSLContext::NextProtocolType::ANY) {
432       EXPECT_EQ(client->protocolType, server->protocolType);
433     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
434                GetParam().second == SSLContext::NextProtocolType::ANY) {
435       // Well not much we can say
436     } else {
437       expectProtocolType(GetParam());
438     }
439   }
440
441   void expectProtocolType(NextProtocolTypePair expected) {
442     EXPECT_EQ(client->protocolType, expected.first);
443     EXPECT_EQ(server->protocolType, expected.second);
444   }
445
446   EventBase eventBase;
447   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
448   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
449   int fds[2];
450   std::unique_ptr<NpnClient> client;
451   std::unique_ptr<NpnServer> server;
452 };
453
454 class NextProtocolTLSExtTest : public NextProtocolTest {
455   // For extended TLS protos
456 };
457
458 class NextProtocolNPNOnlyTest : public NextProtocolTest {
459   // For mismatching protos
460 };
461
462 class NextProtocolMismatchTest : public NextProtocolTest {
463   // For mismatching protos
464 };
465
466 TEST_P(NextProtocolTest, NpnTestOverlap) {
467   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
468   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
469                                         GetParam().second);
470
471   connect();
472
473   expectProtocol("baz");
474   expectProtocolType();
475 }
476
477 TEST_P(NextProtocolTest, NpnTestUnset) {
478   // Identical to above test, except that we want unset NPN before
479   // looping.
480   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
481   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
482                                         GetParam().second);
483
484   connect(true /* unset */);
485
486   // if alpn negotiation fails, type will appear as npn
487   expectNoProtocol();
488   EXPECT_EQ(client->protocolType, server->protocolType);
489 }
490
491 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
492   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
493   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
494                                         GetParam().second);
495
496   connect();
497
498   expectNoProtocol();
499   expectProtocolType(
500       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
501 }
502
503 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
504 // will fail on 1.0.2 before that.
505 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
506   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
507   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
508                                         GetParam().second);
509
510   connect();
511
512   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
513       GetParam().second == SSLContext::NextProtocolType::ALPN) {
514     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
515     // mismatch should result in a fatal alert, but this is OpenSSL's current
516     // behavior and we want to know if it changes.
517     expectNoProtocol();
518   }
519 #if defined(OPENSSL_IS_BORINGSSL)
520   // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
521   // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
522   // NPN *and* ALPN
523   else if (
524       GetParam().first == SSLContext::NextProtocolType::ANY &&
525       GetParam().second == SSLContext::NextProtocolType::ANY) {
526     expectNoProtocol();
527   }
528 #endif
529   else {
530     expectProtocol("blub");
531     expectProtocolType(
532         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
533   }
534 }
535
536 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
537   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
538   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
539   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
540                                         GetParam().second);
541
542   connect();
543
544   expectProtocol("ponies");
545   expectProtocolType();
546 }
547
548 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
549   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
550   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
551   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
552                                         GetParam().second);
553
554   connect();
555
556   expectProtocol("blub");
557   expectProtocolType();
558 }
559
560 TEST_P(NextProtocolTest, RandomizedNpnTest) {
561   // Probability that this test will fail is 2^-64, which could be considered
562   // as negligible.
563   const int kTries = 64;
564
565   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
566                                         GetParam().first);
567   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
568                                                   GetParam().second);
569
570   std::set<string> selectedProtocols;
571   for (int i = 0; i < kTries; ++i) {
572     connect();
573
574     EXPECT_NE(client->nextProtoLength, 0);
575     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
576     EXPECT_EQ(
577         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
578         0);
579     string selected((const char*)client->nextProto, client->nextProtoLength);
580     selectedProtocols.insert(selected);
581     expectProtocolType();
582   }
583   EXPECT_EQ(selectedProtocols.size(), 2);
584 }
585
586 INSTANTIATE_TEST_CASE_P(
587     AsyncSSLSocketTest,
588     NextProtocolTest,
589     ::testing::Values(
590         NextProtocolTypePair(
591             SSLContext::NextProtocolType::NPN,
592             SSLContext::NextProtocolType::NPN),
593         NextProtocolTypePair(
594             SSLContext::NextProtocolType::NPN,
595             SSLContext::NextProtocolType::ANY),
596         NextProtocolTypePair(
597             SSLContext::NextProtocolType::ANY,
598             SSLContext::NextProtocolType::ANY)));
599
600 #if FOLLY_OPENSSL_HAS_ALPN
601 INSTANTIATE_TEST_CASE_P(
602     AsyncSSLSocketTest,
603     NextProtocolTLSExtTest,
604     ::testing::Values(
605         NextProtocolTypePair(
606             SSLContext::NextProtocolType::ALPN,
607             SSLContext::NextProtocolType::ALPN),
608         NextProtocolTypePair(
609             SSLContext::NextProtocolType::ALPN,
610             SSLContext::NextProtocolType::ANY),
611         NextProtocolTypePair(
612             SSLContext::NextProtocolType::ANY,
613             SSLContext::NextProtocolType::ALPN)));
614 #endif
615
616 INSTANTIATE_TEST_CASE_P(
617     AsyncSSLSocketTest,
618     NextProtocolNPNOnlyTest,
619     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
620                                            SSLContext::NextProtocolType::NPN)));
621
622 #if FOLLY_OPENSSL_HAS_ALPN
623 INSTANTIATE_TEST_CASE_P(
624     AsyncSSLSocketTest,
625     NextProtocolMismatchTest,
626     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
627                                            SSLContext::NextProtocolType::ALPN),
628                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
629                                            SSLContext::NextProtocolType::NPN)));
630 #endif
631
632 #ifndef OPENSSL_NO_TLSEXT
633 /**
634  * 1. Client sends TLSEXT_HOSTNAME in client hello.
635  * 2. Server found a match SSL_CTX and use this SSL_CTX to
636  *    continue the SSL handshake.
637  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
638  */
639 TEST(AsyncSSLSocketTest, SNITestMatch) {
640   EventBase eventBase;
641   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
642   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
643   // Use the same SSLContext to continue the handshake after
644   // tlsext_hostname match.
645   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
646   const std::string serverName("xyz.newdev.facebook.com");
647   int fds[2];
648   getfds(fds);
649   getctx(clientCtx, dfServerCtx);
650
651   AsyncSSLSocket::UniquePtr clientSock(
652     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
653   AsyncSSLSocket::UniquePtr serverSock(
654     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
655   SNIClient client(std::move(clientSock));
656   SNIServer server(std::move(serverSock),
657                    dfServerCtx,
658                    hskServerCtx,
659                    serverName);
660
661   eventBase.loop();
662
663   EXPECT_TRUE(client.serverNameMatch);
664   EXPECT_TRUE(server.serverNameMatch);
665 }
666
667 /**
668  * 1. Client sends TLSEXT_HOSTNAME in client hello.
669  * 2. Server cannot find a matching SSL_CTX and continue to use
670  *    the current SSL_CTX to do the handshake.
671  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
672  */
673 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
674   EventBase eventBase;
675   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
676   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
677   // Use the same SSLContext to continue the handshake after
678   // tlsext_hostname match.
679   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
680   const std::string clientRequestingServerName("foo.com");
681   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
682
683   int fds[2];
684   getfds(fds);
685   getctx(clientCtx, dfServerCtx);
686
687   AsyncSSLSocket::UniquePtr clientSock(
688     new AsyncSSLSocket(clientCtx,
689                         &eventBase,
690                         fds[0],
691                         clientRequestingServerName));
692   AsyncSSLSocket::UniquePtr serverSock(
693     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
694   SNIClient client(std::move(clientSock));
695   SNIServer server(std::move(serverSock),
696                    dfServerCtx,
697                    hskServerCtx,
698                    serverExpectedServerName);
699
700   eventBase.loop();
701
702   EXPECT_TRUE(!client.serverNameMatch);
703   EXPECT_TRUE(!server.serverNameMatch);
704 }
705 /**
706  * 1. Client sends TLSEXT_HOSTNAME in client hello.
707  * 2. We then change the serverName.
708  * 3. We expect that we get 'false' as the result for serNameMatch.
709  */
710
711 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
712    EventBase eventBase;
713   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
714   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
715   // Use the same SSLContext to continue the handshake after
716   // tlsext_hostname match.
717   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
718   const std::string serverName("xyz.newdev.facebook.com");
719   int fds[2];
720   getfds(fds);
721   getctx(clientCtx, dfServerCtx);
722
723   AsyncSSLSocket::UniquePtr clientSock(
724     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
725   //Change the server name
726   std::string newName("new.com");
727   clientSock->setServerName(newName);
728   AsyncSSLSocket::UniquePtr serverSock(
729     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
730   SNIClient client(std::move(clientSock));
731   SNIServer server(std::move(serverSock),
732                    dfServerCtx,
733                    hskServerCtx,
734                    serverName);
735
736   eventBase.loop();
737
738   EXPECT_TRUE(!client.serverNameMatch);
739 }
740
741 /**
742  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
743  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
744  */
745 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
746   EventBase eventBase;
747   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
748   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
749   // Use the same SSLContext to continue the handshake after
750   // tlsext_hostname match.
751   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
752   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
753
754   int fds[2];
755   getfds(fds);
756   getctx(clientCtx, dfServerCtx);
757
758   AsyncSSLSocket::UniquePtr clientSock(
759     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
760   AsyncSSLSocket::UniquePtr serverSock(
761     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
762   SNIClient client(std::move(clientSock));
763   SNIServer server(std::move(serverSock),
764                    dfServerCtx,
765                    hskServerCtx,
766                    serverExpectedServerName);
767
768   eventBase.loop();
769
770   EXPECT_TRUE(!client.serverNameMatch);
771   EXPECT_TRUE(!server.serverNameMatch);
772 }
773
774 #endif
775 /**
776  * Test SSL client socket
777  */
778 TEST(AsyncSSLSocketTest, SSLClientTest) {
779   // Start listening on a local port
780   WriteCallbackBase writeCallback;
781   ReadCallback readCallback(&writeCallback);
782   HandshakeCallback handshakeCallback(&readCallback);
783   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
784   TestSSLServer server(&acceptCallback);
785
786   // Set up SSL client
787   EventBase eventBase;
788   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
789
790   client->connect();
791   EventBaseAborter eba(&eventBase, 3000);
792   eventBase.loop();
793
794   EXPECT_EQ(client->getMiss(), 1);
795   EXPECT_EQ(client->getHit(), 0);
796
797   cerr << "SSLClientTest test completed" << endl;
798 }
799
800
801 /**
802  * Test SSL client socket session re-use
803  */
804 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
805   // Start listening on a local port
806   WriteCallbackBase writeCallback;
807   ReadCallback readCallback(&writeCallback);
808   HandshakeCallback handshakeCallback(&readCallback);
809   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
810   TestSSLServer server(&acceptCallback);
811
812   // Set up SSL client
813   EventBase eventBase;
814   auto client =
815       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
816
817   client->connect();
818   EventBaseAborter eba(&eventBase, 3000);
819   eventBase.loop();
820
821   EXPECT_EQ(client->getMiss(), 1);
822   EXPECT_EQ(client->getHit(), 9);
823
824   cerr << "SSLClientTestReuse test completed" << endl;
825 }
826
827 /**
828  * Test SSL client socket timeout
829  */
830 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
831   // Start listening on a local port
832   EmptyReadCallback readCallback;
833   HandshakeCallback handshakeCallback(&readCallback,
834                                       HandshakeCallback::EXPECT_ERROR);
835   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
836   TestSSLServer server(&acceptCallback);
837
838   // Set up SSL client
839   EventBase eventBase;
840   auto client =
841       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
842   client->connect(true /* write before connect completes */);
843   EventBaseAborter eba(&eventBase, 3000);
844   eventBase.loop();
845
846   usleep(100000);
847   // This is checking that the connectError callback precedes any queued
848   // writeError callbacks.  This matches AsyncSocket's behavior
849   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
850   EXPECT_EQ(client->getErrors(), 1);
851   EXPECT_EQ(client->getMiss(), 0);
852   EXPECT_EQ(client->getHit(), 0);
853
854   cerr << "SSLClientTimeoutTest test completed" << endl;
855 }
856
857 // The next 3 tests need an FB-only extension, and will fail without it
858 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
859 /**
860  * Test SSL server async cache
861  */
862 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
863   // Start listening on a local port
864   WriteCallbackBase writeCallback;
865   ReadCallback readCallback(&writeCallback);
866   HandshakeCallback handshakeCallback(&readCallback);
867   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
868   TestSSLAsyncCacheServer server(&acceptCallback);
869
870   // Set up SSL client
871   EventBase eventBase;
872   auto client =
873       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
874
875   client->connect();
876   EventBaseAborter eba(&eventBase, 3000);
877   eventBase.loop();
878
879   EXPECT_EQ(server.getAsyncCallbacks(), 18);
880   EXPECT_EQ(server.getAsyncLookups(), 9);
881   EXPECT_EQ(client->getMiss(), 10);
882   EXPECT_EQ(client->getHit(), 0);
883
884   cerr << "SSLServerAsyncCacheTest test completed" << endl;
885 }
886
887 /**
888  * Test SSL server accept timeout with cache path
889  */
890 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
891   // Start listening on a local port
892   WriteCallbackBase writeCallback;
893   ReadCallback readCallback(&writeCallback);
894   HandshakeCallback handshakeCallback(&readCallback);
895   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
896   TestSSLAsyncCacheServer server(&acceptCallback);
897
898   // Set up SSL client
899   EventBase eventBase;
900   // only do a TCP connect
901   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
902   sock->connect(nullptr, server.getAddress());
903
904   EmptyReadCallback clientReadCallback;
905   clientReadCallback.tcpSocket_ = sock;
906   sock->setReadCB(&clientReadCallback);
907
908   EventBaseAborter eba(&eventBase, 3000);
909   eventBase.loop();
910
911   EXPECT_EQ(readCallback.state, STATE_WAITING);
912
913   cerr << "SSLServerTimeoutTest test completed" << endl;
914 }
915
916 /**
917  * Test SSL server accept timeout with cache path
918  */
919 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
920   // Start listening on a local port
921   WriteCallbackBase writeCallback;
922   ReadCallback readCallback(&writeCallback);
923   HandshakeCallback handshakeCallback(&readCallback);
924   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
925   TestSSLAsyncCacheServer server(&acceptCallback);
926
927   // Set up SSL client
928   EventBase eventBase;
929   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
930
931   client->connect();
932   EventBaseAborter eba(&eventBase, 3000);
933   eventBase.loop();
934
935   EXPECT_EQ(server.getAsyncCallbacks(), 1);
936   EXPECT_EQ(server.getAsyncLookups(), 1);
937   EXPECT_EQ(client->getErrors(), 1);
938   EXPECT_EQ(client->getMiss(), 1);
939   EXPECT_EQ(client->getHit(), 0);
940
941   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
942 }
943
944 /**
945  * Test SSL server accept timeout with cache path
946  */
947 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
948   // Start listening on a local port
949   WriteCallbackBase writeCallback;
950   ReadCallback readCallback(&writeCallback);
951   HandshakeCallback handshakeCallback(&readCallback,
952                                       HandshakeCallback::EXPECT_ERROR);
953   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
954   TestSSLAsyncCacheServer server(&acceptCallback, 500);
955
956   // Set up SSL client
957   EventBase eventBase;
958   auto client =
959       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
960
961   client->connect();
962   EventBaseAborter eba(&eventBase, 3000);
963   eventBase.loop();
964
965   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
966       handshakeCallback.closeSocket();});
967   // give time for the cache lookup to come back and find it closed
968   handshakeCallback.waitForHandshake();
969
970   EXPECT_EQ(server.getAsyncCallbacks(), 1);
971   EXPECT_EQ(server.getAsyncLookups(), 1);
972   EXPECT_EQ(client->getErrors(), 1);
973   EXPECT_EQ(client->getMiss(), 1);
974   EXPECT_EQ(client->getHit(), 0);
975
976   cerr << "SSLServerCacheCloseTest test completed" << endl;
977 }
978 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
979
980 /**
981  * Verify Client Ciphers obtained using SSL MSG Callback.
982  */
983 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
984   EventBase eventBase;
985   auto clientCtx = std::make_shared<SSLContext>();
986   auto serverCtx = std::make_shared<SSLContext>();
987   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
988   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
989   serverCtx->loadPrivateKey(kTestKey);
990   serverCtx->loadCertificate(kTestCert);
991   serverCtx->loadTrustedCertificates(kTestCA);
992   serverCtx->loadClientCAList(kTestCA);
993
994   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
995   clientCtx->ciphers("AES256-SHA:AES128-SHA");
996   clientCtx->loadPrivateKey(kTestKey);
997   clientCtx->loadCertificate(kTestCert);
998   clientCtx->loadTrustedCertificates(kTestCA);
999
1000   int fds[2];
1001   getfds(fds);
1002
1003   AsyncSSLSocket::UniquePtr clientSock(
1004       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1005   AsyncSSLSocket::UniquePtr serverSock(
1006       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1007
1008   SSLHandshakeClient client(std::move(clientSock), true, true);
1009   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1010
1011   eventBase.loop();
1012
1013 #if defined(OPENSSL_IS_BORINGSSL)
1014   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1015 #else
1016   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1017 #endif
1018   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1019   EXPECT_TRUE(client.handshakeVerify_);
1020   EXPECT_TRUE(client.handshakeSuccess_);
1021   EXPECT_TRUE(!client.handshakeError_);
1022   EXPECT_TRUE(server.handshakeVerify_);
1023   EXPECT_TRUE(server.handshakeSuccess_);
1024   EXPECT_TRUE(!server.handshakeError_);
1025 }
1026
1027 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1028   EventBase eventBase;
1029   auto ctx = std::make_shared<SSLContext>();
1030
1031   int fds[2];
1032   getfds(fds);
1033
1034   int bufLen = 42;
1035   uint8_t majorVersion = 18;
1036   uint8_t minorVersion = 25;
1037
1038   // Create callback buf
1039   auto buf = IOBuf::create(bufLen);
1040   buf->append(bufLen);
1041   folly::io::RWPrivateCursor cursor(buf.get());
1042   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1043   cursor.write<uint16_t>(0);
1044   cursor.write<uint8_t>(38);
1045   cursor.write<uint8_t>(majorVersion);
1046   cursor.write<uint8_t>(minorVersion);
1047   cursor.skip(32);
1048   cursor.write<uint32_t>(0);
1049
1050   SSL* ssl = ctx->createSSL();
1051   SCOPE_EXIT { SSL_free(ssl); };
1052   AsyncSSLSocket::UniquePtr sock(
1053       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1054   sock->enableClientHelloParsing();
1055
1056   // Test client hello parsing in one packet
1057   AsyncSSLSocket::clientHelloParsingCallback(
1058       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1059   buf.reset();
1060
1061   auto parsedClientHello = sock->getClientHelloInfo();
1062   EXPECT_TRUE(parsedClientHello != nullptr);
1063   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1064   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1065 }
1066
1067 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1068   EventBase eventBase;
1069   auto ctx = std::make_shared<SSLContext>();
1070
1071   int fds[2];
1072   getfds(fds);
1073
1074   int bufLen = 42;
1075   uint8_t majorVersion = 18;
1076   uint8_t minorVersion = 25;
1077
1078   // Create callback buf
1079   auto buf = IOBuf::create(bufLen);
1080   buf->append(bufLen);
1081   folly::io::RWPrivateCursor cursor(buf.get());
1082   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1083   cursor.write<uint16_t>(0);
1084   cursor.write<uint8_t>(38);
1085   cursor.write<uint8_t>(majorVersion);
1086   cursor.write<uint8_t>(minorVersion);
1087   cursor.skip(32);
1088   cursor.write<uint32_t>(0);
1089
1090   SSL* ssl = ctx->createSSL();
1091   SCOPE_EXIT { SSL_free(ssl); };
1092   AsyncSSLSocket::UniquePtr sock(
1093       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1094   sock->enableClientHelloParsing();
1095
1096   // Test parsing with two packets with first packet size < 3
1097   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1098   AsyncSSLSocket::clientHelloParsingCallback(
1099       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1100       ssl, sock.get());
1101   bufCopy.reset();
1102   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1103   AsyncSSLSocket::clientHelloParsingCallback(
1104       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1105       ssl, sock.get());
1106   bufCopy.reset();
1107
1108   auto parsedClientHello = sock->getClientHelloInfo();
1109   EXPECT_TRUE(parsedClientHello != nullptr);
1110   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1111   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1112 }
1113
1114 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1115   EventBase eventBase;
1116   auto ctx = std::make_shared<SSLContext>();
1117
1118   int fds[2];
1119   getfds(fds);
1120
1121   int bufLen = 42;
1122   uint8_t majorVersion = 18;
1123   uint8_t minorVersion = 25;
1124
1125   // Create callback buf
1126   auto buf = IOBuf::create(bufLen);
1127   buf->append(bufLen);
1128   folly::io::RWPrivateCursor cursor(buf.get());
1129   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1130   cursor.write<uint16_t>(0);
1131   cursor.write<uint8_t>(38);
1132   cursor.write<uint8_t>(majorVersion);
1133   cursor.write<uint8_t>(minorVersion);
1134   cursor.skip(32);
1135   cursor.write<uint32_t>(0);
1136
1137   SSL* ssl = ctx->createSSL();
1138   SCOPE_EXIT { SSL_free(ssl); };
1139   AsyncSSLSocket::UniquePtr sock(
1140       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1141   sock->enableClientHelloParsing();
1142
1143   // Test parsing with multiple small packets
1144   for (uint64_t i = 0; i < buf->length(); i += 3) {
1145     auto bufCopy = folly::IOBuf::copyBuffer(
1146         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1147     AsyncSSLSocket::clientHelloParsingCallback(
1148         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1149         ssl, sock.get());
1150     bufCopy.reset();
1151   }
1152
1153   auto parsedClientHello = sock->getClientHelloInfo();
1154   EXPECT_TRUE(parsedClientHello != nullptr);
1155   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1156   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1157 }
1158
1159 /**
1160  * Verify sucessful behavior of SSL certificate validation.
1161  */
1162 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1163   EventBase eventBase;
1164   auto clientCtx = std::make_shared<SSLContext>();
1165   auto dfServerCtx = std::make_shared<SSLContext>();
1166
1167   int fds[2];
1168   getfds(fds);
1169   getctx(clientCtx, dfServerCtx);
1170
1171   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1172   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1173
1174   AsyncSSLSocket::UniquePtr clientSock(
1175     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1176   AsyncSSLSocket::UniquePtr serverSock(
1177     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1178
1179   SSLHandshakeClient client(std::move(clientSock), true, true);
1180   clientCtx->loadTrustedCertificates(kTestCA);
1181
1182   SSLHandshakeServer server(std::move(serverSock), true, true);
1183
1184   eventBase.loop();
1185
1186   EXPECT_TRUE(client.handshakeVerify_);
1187   EXPECT_TRUE(client.handshakeSuccess_);
1188   EXPECT_TRUE(!client.handshakeError_);
1189   EXPECT_LE(0, client.handshakeTime.count());
1190   EXPECT_TRUE(!server.handshakeVerify_);
1191   EXPECT_TRUE(server.handshakeSuccess_);
1192   EXPECT_TRUE(!server.handshakeError_);
1193   EXPECT_LE(0, server.handshakeTime.count());
1194 }
1195
1196 /**
1197  * Verify that the client's verification callback is able to fail SSL
1198  * connection establishment.
1199  */
1200 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1201   EventBase eventBase;
1202   auto clientCtx = std::make_shared<SSLContext>();
1203   auto dfServerCtx = std::make_shared<SSLContext>();
1204
1205   int fds[2];
1206   getfds(fds);
1207   getctx(clientCtx, dfServerCtx);
1208
1209   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1210   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1211
1212   AsyncSSLSocket::UniquePtr clientSock(
1213     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1214   AsyncSSLSocket::UniquePtr serverSock(
1215     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1216
1217   SSLHandshakeClient client(std::move(clientSock), true, false);
1218   clientCtx->loadTrustedCertificates(kTestCA);
1219
1220   SSLHandshakeServer server(std::move(serverSock), true, true);
1221
1222   eventBase.loop();
1223
1224   EXPECT_TRUE(client.handshakeVerify_);
1225   EXPECT_TRUE(!client.handshakeSuccess_);
1226   EXPECT_TRUE(client.handshakeError_);
1227   EXPECT_LE(0, client.handshakeTime.count());
1228   EXPECT_TRUE(!server.handshakeVerify_);
1229   EXPECT_TRUE(!server.handshakeSuccess_);
1230   EXPECT_TRUE(server.handshakeError_);
1231   EXPECT_LE(0, server.handshakeTime.count());
1232 }
1233
1234 /**
1235  * Verify that the options in SSLContext can be overridden in
1236  * sslConnect/Accept.i.e specifying that no validation should be performed
1237  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1238  * the validation callback.
1239  */
1240 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1241   EventBase eventBase;
1242   auto clientCtx = std::make_shared<SSLContext>();
1243   auto dfServerCtx = std::make_shared<SSLContext>();
1244
1245   int fds[2];
1246   getfds(fds);
1247   getctx(clientCtx, dfServerCtx);
1248
1249   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1250   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1251
1252   AsyncSSLSocket::UniquePtr clientSock(
1253     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1254   AsyncSSLSocket::UniquePtr serverSock(
1255     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1256
1257   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1258   clientCtx->loadTrustedCertificates(kTestCA);
1259
1260   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1261
1262   eventBase.loop();
1263
1264   EXPECT_TRUE(!client.handshakeVerify_);
1265   EXPECT_TRUE(client.handshakeSuccess_);
1266   EXPECT_TRUE(!client.handshakeError_);
1267   EXPECT_LE(0, client.handshakeTime.count());
1268   EXPECT_TRUE(!server.handshakeVerify_);
1269   EXPECT_TRUE(server.handshakeSuccess_);
1270   EXPECT_TRUE(!server.handshakeError_);
1271   EXPECT_LE(0, server.handshakeTime.count());
1272 }
1273
1274 /**
1275  * Verify that the options in SSLContext can be overridden in
1276  * sslConnect/Accept. Enable verification even if context says otherwise.
1277  * Test requireClientCert with client cert
1278  */
1279 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1280   EventBase eventBase;
1281   auto clientCtx = std::make_shared<SSLContext>();
1282   auto serverCtx = std::make_shared<SSLContext>();
1283   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1284   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1285   serverCtx->loadPrivateKey(kTestKey);
1286   serverCtx->loadCertificate(kTestCert);
1287   serverCtx->loadTrustedCertificates(kTestCA);
1288   serverCtx->loadClientCAList(kTestCA);
1289
1290   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1291   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1292   clientCtx->loadPrivateKey(kTestKey);
1293   clientCtx->loadCertificate(kTestCert);
1294   clientCtx->loadTrustedCertificates(kTestCA);
1295
1296   int fds[2];
1297   getfds(fds);
1298
1299   AsyncSSLSocket::UniquePtr clientSock(
1300       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1301   AsyncSSLSocket::UniquePtr serverSock(
1302       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1303
1304   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1305   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1306
1307   eventBase.loop();
1308
1309   EXPECT_TRUE(client.handshakeVerify_);
1310   EXPECT_TRUE(client.handshakeSuccess_);
1311   EXPECT_FALSE(client.handshakeError_);
1312   EXPECT_LE(0, client.handshakeTime.count());
1313   EXPECT_TRUE(server.handshakeVerify_);
1314   EXPECT_TRUE(server.handshakeSuccess_);
1315   EXPECT_FALSE(server.handshakeError_);
1316   EXPECT_LE(0, server.handshakeTime.count());
1317 }
1318
1319 /**
1320  * Verify that the client's verification callback is able to override
1321  * the preverification failure and allow a successful connection.
1322  */
1323 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1324   EventBase eventBase;
1325   auto clientCtx = std::make_shared<SSLContext>();
1326   auto dfServerCtx = std::make_shared<SSLContext>();
1327
1328   int fds[2];
1329   getfds(fds);
1330   getctx(clientCtx, dfServerCtx);
1331
1332   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1333   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1334
1335   AsyncSSLSocket::UniquePtr clientSock(
1336     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1337   AsyncSSLSocket::UniquePtr serverSock(
1338     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1339
1340   SSLHandshakeClient client(std::move(clientSock), false, true);
1341   SSLHandshakeServer server(std::move(serverSock), true, true);
1342
1343   eventBase.loop();
1344
1345   EXPECT_TRUE(client.handshakeVerify_);
1346   EXPECT_TRUE(client.handshakeSuccess_);
1347   EXPECT_TRUE(!client.handshakeError_);
1348   EXPECT_LE(0, client.handshakeTime.count());
1349   EXPECT_TRUE(!server.handshakeVerify_);
1350   EXPECT_TRUE(server.handshakeSuccess_);
1351   EXPECT_TRUE(!server.handshakeError_);
1352   EXPECT_LE(0, server.handshakeTime.count());
1353 }
1354
1355 /**
1356  * Verify that specifying that no validation should be performed allows an
1357  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1358  * callback.
1359  */
1360 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1361   EventBase eventBase;
1362   auto clientCtx = std::make_shared<SSLContext>();
1363   auto dfServerCtx = std::make_shared<SSLContext>();
1364
1365   int fds[2];
1366   getfds(fds);
1367   getctx(clientCtx, dfServerCtx);
1368
1369   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1370   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1371
1372   AsyncSSLSocket::UniquePtr clientSock(
1373     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1374   AsyncSSLSocket::UniquePtr serverSock(
1375     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1376
1377   SSLHandshakeClient client(std::move(clientSock), false, false);
1378   SSLHandshakeServer server(std::move(serverSock), false, false);
1379
1380   eventBase.loop();
1381
1382   EXPECT_TRUE(!client.handshakeVerify_);
1383   EXPECT_TRUE(client.handshakeSuccess_);
1384   EXPECT_TRUE(!client.handshakeError_);
1385   EXPECT_LE(0, client.handshakeTime.count());
1386   EXPECT_TRUE(!server.handshakeVerify_);
1387   EXPECT_TRUE(server.handshakeSuccess_);
1388   EXPECT_TRUE(!server.handshakeError_);
1389   EXPECT_LE(0, server.handshakeTime.count());
1390 }
1391
1392 /**
1393  * Test requireClientCert with client cert
1394  */
1395 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1396   EventBase eventBase;
1397   auto clientCtx = std::make_shared<SSLContext>();
1398   auto serverCtx = std::make_shared<SSLContext>();
1399   serverCtx->setVerificationOption(
1400       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1401   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1402   serverCtx->loadPrivateKey(kTestKey);
1403   serverCtx->loadCertificate(kTestCert);
1404   serverCtx->loadTrustedCertificates(kTestCA);
1405   serverCtx->loadClientCAList(kTestCA);
1406
1407   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1408   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1409   clientCtx->loadPrivateKey(kTestKey);
1410   clientCtx->loadCertificate(kTestCert);
1411   clientCtx->loadTrustedCertificates(kTestCA);
1412
1413   int fds[2];
1414   getfds(fds);
1415
1416   AsyncSSLSocket::UniquePtr clientSock(
1417       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1418   AsyncSSLSocket::UniquePtr serverSock(
1419       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1420
1421   SSLHandshakeClient client(std::move(clientSock), true, true);
1422   SSLHandshakeServer server(std::move(serverSock), true, true);
1423
1424   eventBase.loop();
1425
1426   EXPECT_TRUE(client.handshakeVerify_);
1427   EXPECT_TRUE(client.handshakeSuccess_);
1428   EXPECT_FALSE(client.handshakeError_);
1429   EXPECT_LE(0, client.handshakeTime.count());
1430   EXPECT_TRUE(server.handshakeVerify_);
1431   EXPECT_TRUE(server.handshakeSuccess_);
1432   EXPECT_FALSE(server.handshakeError_);
1433   EXPECT_LE(0, server.handshakeTime.count());
1434 }
1435
1436
1437 /**
1438  * Test requireClientCert with no client cert
1439  */
1440 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1441   EventBase eventBase;
1442   auto clientCtx = std::make_shared<SSLContext>();
1443   auto serverCtx = std::make_shared<SSLContext>();
1444   serverCtx->setVerificationOption(
1445       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1446   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1447   serverCtx->loadPrivateKey(kTestKey);
1448   serverCtx->loadCertificate(kTestCert);
1449   serverCtx->loadTrustedCertificates(kTestCA);
1450   serverCtx->loadClientCAList(kTestCA);
1451   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1452   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1453
1454   int fds[2];
1455   getfds(fds);
1456
1457   AsyncSSLSocket::UniquePtr clientSock(
1458       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1459   AsyncSSLSocket::UniquePtr serverSock(
1460       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1461
1462   SSLHandshakeClient client(std::move(clientSock), false, false);
1463   SSLHandshakeServer server(std::move(serverSock), false, false);
1464
1465   eventBase.loop();
1466
1467   EXPECT_FALSE(server.handshakeVerify_);
1468   EXPECT_FALSE(server.handshakeSuccess_);
1469   EXPECT_TRUE(server.handshakeError_);
1470   EXPECT_LE(0, client.handshakeTime.count());
1471   EXPECT_LE(0, server.handshakeTime.count());
1472 }
1473
1474 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1475   auto cert = getFileAsBuf(kTestCert);
1476   auto key = getFileAsBuf(kTestKey);
1477
1478   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1479   BIO_write(certBio.get(), cert.data(), cert.size());
1480   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1481   BIO_write(keyBio.get(), key.data(), key.size());
1482
1483   // Create SSL structs from buffers to get properties
1484   ssl::X509UniquePtr certStruct(
1485       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1486   ssl::EvpPkeyUniquePtr keyStruct(
1487       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1488   certBio = nullptr;
1489   keyBio = nullptr;
1490
1491   auto origCommonName = getCommonName(certStruct.get());
1492   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1493   certStruct = nullptr;
1494   keyStruct = nullptr;
1495
1496   auto ctx = std::make_shared<SSLContext>();
1497   ctx->loadPrivateKeyFromBufferPEM(key);
1498   ctx->loadCertificateFromBufferPEM(cert);
1499   ctx->loadTrustedCertificates(kTestCA);
1500
1501   ssl::SSLUniquePtr ssl(ctx->createSSL());
1502
1503   auto newCert = SSL_get_certificate(ssl.get());
1504   auto newKey = SSL_get_privatekey(ssl.get());
1505
1506   // Get properties from SSL struct
1507   auto newCommonName = getCommonName(newCert);
1508   auto newKeySize = EVP_PKEY_bits(newKey);
1509
1510   // Check that the key and cert have the expected properties
1511   EXPECT_EQ(origCommonName, newCommonName);
1512   EXPECT_EQ(origKeySize, newKeySize);
1513 }
1514
1515 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1516   EventBase eb;
1517
1518   // Set up SSL context.
1519   auto sslContext = std::make_shared<SSLContext>();
1520   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1521
1522   // create SSL socket
1523   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1524
1525   EXPECT_EQ(1500, socket->getMinWriteSize());
1526
1527   socket->setMinWriteSize(0);
1528   EXPECT_EQ(0, socket->getMinWriteSize());
1529   socket->setMinWriteSize(50000);
1530   EXPECT_EQ(50000, socket->getMinWriteSize());
1531 }
1532
1533 class ReadCallbackTerminator : public ReadCallback {
1534  public:
1535   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1536       : ReadCallback(wcb)
1537       , base_(base) {}
1538
1539   // Do not write data back, terminate the loop.
1540   void readDataAvailable(size_t len) noexcept override {
1541     std::cerr << "readDataAvailable, len " << len << std::endl;
1542
1543     currentBuffer.length = len;
1544
1545     buffers.push_back(currentBuffer);
1546     currentBuffer.reset();
1547     state = STATE_SUCCEEDED;
1548
1549     socket_->setReadCB(nullptr);
1550     base_->terminateLoopSoon();
1551   }
1552  private:
1553   EventBase* base_;
1554 };
1555
1556
1557 /**
1558  * Test a full unencrypted codepath
1559  */
1560 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1561   EventBase base;
1562
1563   auto clientCtx = std::make_shared<folly::SSLContext>();
1564   auto serverCtx = std::make_shared<folly::SSLContext>();
1565   int fds[2];
1566   getfds(fds);
1567   getctx(clientCtx, serverCtx);
1568   auto client = AsyncSSLSocket::newSocket(
1569                   clientCtx, &base, fds[0], false, true);
1570   auto server = AsyncSSLSocket::newSocket(
1571                   serverCtx, &base, fds[1], true, true);
1572
1573   ReadCallbackTerminator readCallback(&base, nullptr);
1574   server->setReadCB(&readCallback);
1575   readCallback.setSocket(server);
1576
1577   uint8_t buf[128];
1578   memset(buf, 'a', sizeof(buf));
1579   client->write(nullptr, buf, sizeof(buf));
1580
1581   // Check that bytes are unencrypted
1582   char c;
1583   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1584   EXPECT_EQ('a', c);
1585
1586   EventBaseAborter eba(&base, 3000);
1587   base.loop();
1588
1589   EXPECT_EQ(1, readCallback.buffers.size());
1590   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1591
1592   server->setReadCB(&readCallback);
1593
1594   // Unencrypted
1595   server->sslAccept(nullptr);
1596   client->sslConn(nullptr);
1597
1598   // Do NOT wait for handshake, writing should be queued and happen after
1599
1600   client->write(nullptr, buf, sizeof(buf));
1601
1602   // Check that bytes are *not* unencrypted
1603   char c2;
1604   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1605   EXPECT_NE('a', c2);
1606
1607
1608   base.loop();
1609
1610   EXPECT_EQ(2, readCallback.buffers.size());
1611   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1612 }
1613
1614 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1615   // Start listening on a local port
1616   WriteCallbackBase writeCallback;
1617   WriteErrorCallback readCallback(&writeCallback);
1618   HandshakeCallback handshakeCallback(&readCallback,
1619                                       HandshakeCallback::EXPECT_ERROR);
1620   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1621   TestSSLServer server(&acceptCallback);
1622
1623   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1624   socket->open();
1625   uint8_t buf[3] = {0x16, 0x03, 0x01};
1626   socket->write(buf, sizeof(buf));
1627   socket->closeWithReset();
1628
1629   handshakeCallback.waitForHandshake();
1630   EXPECT_NE(
1631       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1632   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1633 }
1634
1635 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1636   // Start listening on a local port
1637   WriteCallbackBase writeCallback;
1638   WriteErrorCallback readCallback(&writeCallback);
1639   HandshakeCallback handshakeCallback(&readCallback,
1640                                       HandshakeCallback::EXPECT_ERROR);
1641   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1642   TestSSLServer server(&acceptCallback);
1643
1644   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1645   socket->open();
1646   uint8_t buf[3] = {0x16, 0x03, 0x01};
1647   socket->write(buf, sizeof(buf));
1648   socket->close();
1649
1650   handshakeCallback.waitForHandshake();
1651   EXPECT_NE(
1652       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1653   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1654 }
1655
1656 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1657   // Start listening on a local port
1658   WriteCallbackBase writeCallback;
1659   WriteErrorCallback readCallback(&writeCallback);
1660   HandshakeCallback handshakeCallback(&readCallback,
1661                                       HandshakeCallback::EXPECT_ERROR);
1662   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1663   TestSSLServer server(&acceptCallback);
1664
1665   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1666   socket->open();
1667   uint8_t buf[256] = {0x16, 0x03};
1668   memset(buf + 2, 'a', sizeof(buf) - 2);
1669   socket->write(buf, sizeof(buf));
1670   socket->close();
1671
1672   handshakeCallback.waitForHandshake();
1673   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1674             std::string::npos);
1675 #if defined(OPENSSL_IS_BORINGSSL)
1676   EXPECT_NE(
1677       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1678       std::string::npos);
1679 #else
1680   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1681             std::string::npos);
1682 #endif
1683 }
1684
1685 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1686   using folly::ssl::OpenSSLUtils;
1687   EXPECT_EQ(
1688       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1689   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1690   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1691   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1692   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1693 }
1694
1695 #if FOLLY_ALLOW_TFO
1696
1697 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1698  public:
1699   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1700
1701   explicit MockAsyncTFOSSLSocket(
1702       std::shared_ptr<folly::SSLContext> sslCtx,
1703       EventBase* evb)
1704       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1705
1706   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1707 };
1708
1709 /**
1710  * Test connecting to, writing to, reading from, and closing the
1711  * connection to the SSL server with TFO.
1712  */
1713 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1714   // Start listening on a local port
1715   WriteCallbackBase writeCallback;
1716   ReadCallback readCallback(&writeCallback);
1717   HandshakeCallback handshakeCallback(&readCallback);
1718   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1719   TestSSLServer server(&acceptCallback, true);
1720
1721   // Set up SSL context.
1722   auto sslContext = std::make_shared<SSLContext>();
1723
1724   // connect
1725   auto socket =
1726       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1727   socket->enableTFO();
1728   socket->open();
1729
1730   // write()
1731   std::array<uint8_t, 128> buf;
1732   memset(buf.data(), 'a', buf.size());
1733   socket->write(buf.data(), buf.size());
1734
1735   // read()
1736   std::array<uint8_t, 128> readbuf;
1737   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1738   EXPECT_EQ(bytesRead, 128);
1739   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1740
1741   // close()
1742   socket->close();
1743 }
1744
1745 /**
1746  * Test connecting to, writing to, reading from, and closing the
1747  * connection to the SSL server with TFO.
1748  */
1749 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1750   // Start listening on a local port
1751   WriteCallbackBase writeCallback;
1752   ReadCallback readCallback(&writeCallback);
1753   HandshakeCallback handshakeCallback(&readCallback);
1754   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1755   TestSSLServer server(&acceptCallback, false);
1756
1757   // Set up SSL context.
1758   auto sslContext = std::make_shared<SSLContext>();
1759
1760   // connect
1761   auto socket =
1762       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1763   socket->enableTFO();
1764   socket->open();
1765
1766   // write()
1767   std::array<uint8_t, 128> buf;
1768   memset(buf.data(), 'a', buf.size());
1769   socket->write(buf.data(), buf.size());
1770
1771   // read()
1772   std::array<uint8_t, 128> readbuf;
1773   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1774   EXPECT_EQ(bytesRead, 128);
1775   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1776
1777   // close()
1778   socket->close();
1779 }
1780
1781 class ConnCallback : public AsyncSocket::ConnectCallback {
1782  public:
1783   virtual void connectSuccess() noexcept override {
1784     state = State::SUCCESS;
1785   }
1786
1787   virtual void connectErr(const AsyncSocketException& ex) noexcept override {
1788     state = State::ERROR;
1789     error = ex.what();
1790   }
1791
1792   enum class State { WAITING, SUCCESS, ERROR };
1793
1794   State state{State::WAITING};
1795   std::string error;
1796 };
1797
1798 template <class Cardinality>
1799 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1800     EventBase* evb,
1801     const SocketAddress& address,
1802     Cardinality cardinality) {
1803   // Set up SSL context.
1804   auto sslContext = std::make_shared<SSLContext>();
1805
1806   // connect
1807   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1808       new MockAsyncTFOSSLSocket(sslContext, evb));
1809   socket->enableTFO();
1810
1811   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1812       .Times(cardinality)
1813       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1814         sockaddr_storage addr;
1815         auto len = address.getAddress(&addr);
1816         return connect(fd, (const struct sockaddr*)&addr, len);
1817       }));
1818   return socket;
1819 }
1820
1821 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1822   // Start listening on a local port
1823   WriteCallbackBase writeCallback;
1824   ReadCallback readCallback(&writeCallback);
1825   HandshakeCallback handshakeCallback(&readCallback);
1826   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1827   TestSSLServer server(&acceptCallback, true);
1828
1829   EventBase evb;
1830
1831   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1832   ConnCallback ccb;
1833   socket->connect(&ccb, server.getAddress(), 30);
1834
1835   evb.loop();
1836   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1837
1838   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1839   evb.loop();
1840
1841   BlockingSocket sock(std::move(socket));
1842   // write()
1843   std::array<uint8_t, 128> buf;
1844   memset(buf.data(), 'a', buf.size());
1845   sock.write(buf.data(), buf.size());
1846
1847   // read()
1848   std::array<uint8_t, 128> readbuf;
1849   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1850   EXPECT_EQ(bytesRead, 128);
1851   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1852
1853   // close()
1854   sock.close();
1855 }
1856
1857 #if !defined(OPENSSL_IS_BORINGSSL)
1858 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1859   // Start listening on a local port
1860   ConnectTimeoutCallback acceptCallback;
1861   TestSSLServer server(&acceptCallback, true);
1862
1863   // Set up SSL context.
1864   auto sslContext = std::make_shared<SSLContext>();
1865
1866   // connect
1867   auto socket =
1868       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1869   socket->enableTFO();
1870   EXPECT_THROW(
1871       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1872 }
1873 #endif
1874
1875 #if !defined(OPENSSL_IS_BORINGSSL)
1876 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1877   // Start listening on a local port
1878   ConnectTimeoutCallback acceptCallback;
1879   TestSSLServer server(&acceptCallback, true);
1880
1881   EventBase evb;
1882
1883   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1884   ConnCallback ccb;
1885   // Set a short timeout
1886   socket->connect(&ccb, server.getAddress(), 1);
1887
1888   evb.loop();
1889   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1890 }
1891 #endif
1892
1893 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1894   // Start listening on a local port
1895   EmptyReadCallback readCallback;
1896   HandshakeCallback handshakeCallback(
1897       &readCallback, HandshakeCallback::EXPECT_ERROR);
1898   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1899   TestSSLServer server(&acceptCallback, true);
1900
1901   EventBase evb;
1902
1903   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1904   ConnCallback ccb;
1905   socket->connect(&ccb, server.getAddress(), 100);
1906
1907   evb.loop();
1908   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1909   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1910 }
1911
1912 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1913   // Start listening on a local port
1914   EventBase evb;
1915
1916   // Hopefully nothing is listening on this address
1917   SocketAddress addr("127.0.0.1", 65535);
1918   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1919   ConnCallback ccb;
1920   socket->connect(&ccb, addr, 100);
1921
1922   evb.loop();
1923   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1924   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1925 }
1926
1927 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
1928   EventBase clientEventBase;
1929   EventBase serverEventBase;
1930   auto clientCtx = std::make_shared<SSLContext>();
1931   auto dfServerCtx = std::make_shared<SSLContext>();
1932   std::array<int, 2> fds;
1933   getfds(fds.data());
1934   getctx(clientCtx, dfServerCtx);
1935
1936   AsyncSSLSocket::UniquePtr clientSockPtr(
1937       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
1938   AsyncSSLSocket::UniquePtr serverSockPtr(
1939       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
1940   auto clientSock = clientSockPtr.get();
1941   auto serverSock = serverSockPtr.get();
1942   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
1943
1944   // Steal some data from the server.
1945   clientEventBase.loopOnce();
1946   std::array<uint8_t, 10> buf;
1947   recv(fds[1], buf.data(), buf.size(), 0);
1948
1949   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
1950   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
1951   while (!client.handshakeSuccess_ && !client.handshakeError_) {
1952     serverEventBase.loopOnce();
1953     clientEventBase.loopOnce();
1954   }
1955
1956   EXPECT_TRUE(client.handshakeSuccess_);
1957   EXPECT_TRUE(server.handshakeSuccess_);
1958   EXPECT_EQ(
1959       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
1960 }
1961
1962 /**
1963  * Test overriding the flags passed to "sendmsg()" system call,
1964  * and verifying that write requests fail properly.
1965  */
1966 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
1967   // Start listening on a local port
1968   SendMsgFlagsCallback msgCallback;
1969   ExpectWriteErrorCallback writeCallback(&msgCallback);
1970   ReadCallback readCallback(&writeCallback);
1971   HandshakeCallback handshakeCallback(&readCallback);
1972   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1973   TestSSLServer server(&acceptCallback);
1974
1975   // Set up SSL context.
1976   auto sslContext = std::make_shared<SSLContext>();
1977   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1978
1979   // connect
1980   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
1981                                                  sslContext);
1982   socket->open();
1983
1984   // Setting flags to "-1" to trigger "Invalid argument" error
1985   // on attempt to use this flags in sendmsg() system call.
1986   msgCallback.resetFlags(-1);
1987
1988   // write()
1989   std::vector<uint8_t> buf(128, 'a');
1990   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
1991
1992   // close()
1993   socket->close();
1994
1995   cerr << "SendMsgParamsCallback test completed" << endl;
1996 }
1997
1998 #ifdef MSG_ERRQUEUE
1999 /**
2000  * Test connecting to, writing to, reading from, and closing the
2001  * connection to the SSL server.
2002  */
2003 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2004   // This test requires Linux kernel v4.6 or later
2005   struct utsname s_uname;
2006   memset(&s_uname, 0, sizeof(s_uname));
2007   ASSERT_EQ(uname(&s_uname), 0);
2008   int major, minor;
2009   folly::StringPiece extra;
2010   if (folly::split<false>(
2011         '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2012     if (major < 4 || (major == 4 && minor < 6)) {
2013       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2014                 << "kernel ver. " << s_uname.release << " detected).";
2015       return;
2016     }
2017   }
2018
2019   // Start listening on a local port
2020   SendMsgDataCallback msgCallback;
2021   WriteCheckTimestampCallback writeCallback(&msgCallback);
2022   ReadCallback readCallback(&writeCallback);
2023   HandshakeCallback handshakeCallback(&readCallback);
2024   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2025   TestSSLServer server(&acceptCallback);
2026
2027   // Set up SSL context.
2028   auto sslContext = std::make_shared<SSLContext>();
2029   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2030
2031   // connect
2032   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2033                                                  sslContext);
2034   socket->open();
2035
2036   // Adding MSG_EOR flag to the message flags - it'll trigger
2037   // timestamp generation for the last byte of the message.
2038   msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
2039
2040   // Init ancillary data buffer to trigger timestamp notification
2041   union {
2042     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2043     struct cmsghdr cmsg;
2044   } u;
2045   u.cmsg.cmsg_level = SOL_SOCKET;
2046   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2047   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2048   uint32_t flags =
2049       SOF_TIMESTAMPING_TX_SCHED |
2050       SOF_TIMESTAMPING_TX_SOFTWARE |
2051       SOF_TIMESTAMPING_TX_ACK;
2052   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2053   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2054   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2055   msgCallback.resetData(std::move(ctrl));
2056
2057   // write()
2058   std::vector<uint8_t> buf(128, 'a');
2059   socket->write(buf.data(), buf.size());
2060
2061   // read()
2062   std::vector<uint8_t> readbuf(buf.size());
2063   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2064   EXPECT_EQ(bytesRead, buf.size());
2065   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2066
2067   writeCallback.checkForTimestampNotifications();
2068
2069   // close()
2070   socket->close();
2071
2072   cerr << "SendMsgDataCallback test completed" << endl;
2073 }
2074 #endif // MSG_ERRQUEUE
2075
2076 #endif
2077
2078 } // namespace
2079
2080 #ifdef SIGPIPE
2081 ///////////////////////////////////////////////////////////////////////////
2082 // init_unit_test_suite
2083 ///////////////////////////////////////////////////////////////////////////
2084 namespace {
2085 struct Initializer {
2086   Initializer() {
2087     signal(SIGPIPE, SIG_IGN);
2088   }
2089 };
2090 Initializer initializer;
2091 } // anonymous
2092 #endif