apply clang-tidy modernize-use-override
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/SocketAddress.h>
19 #include <folly/io/Cursor.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/portability/GMock.h>
23 #include <folly/portability/GTest.h>
24 #include <folly/portability/OpenSSL.h>
25 #include <folly/portability/Sockets.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <folly/io/async/test/BlockingSocket.h>
29
30 #include <fcntl.h>
31 #include <signal.h>
32 #include <sys/types.h>
33 #include <sys/utsname.h>
34
35 #include <fstream>
36 #include <iostream>
37 #include <list>
38 #include <set>
39 #include <thread>
40
41 using std::string;
42 using std::vector;
43 using std::min;
44 using std::cerr;
45 using std::endl;
46 using std::list;
47
48 using namespace testing;
49
50 namespace folly {
51 uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
52 uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
53 uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
54
55 constexpr size_t SSLClient::kMaxReadBufferSz;
56 constexpr size_t SSLClient::kMaxReadsPerEvent;
57
58 void getfds(int fds[2]) {
59   if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
60     FAIL() << "failed to create socketpair: " << strerror(errno);
61   }
62   for (int idx = 0; idx < 2; ++idx) {
63     int flags = fcntl(fds[idx], F_GETFL, 0);
64     if (flags == -1) {
65       FAIL() << "failed to get flags for socket " << idx << ": "
66              << strerror(errno);
67     }
68     if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
69       FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
70              << strerror(errno);
71     }
72   }
73 }
74
75 void getctx(
76   std::shared_ptr<folly::SSLContext> clientCtx,
77   std::shared_ptr<folly::SSLContext> serverCtx) {
78   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
79
80   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
81   serverCtx->loadCertificate(kTestCert);
82   serverCtx->loadPrivateKey(kTestKey);
83 }
84
85 void sslsocketpair(
86   EventBase* eventBase,
87   AsyncSSLSocket::UniquePtr* clientSock,
88   AsyncSSLSocket::UniquePtr* serverSock) {
89   auto clientCtx = std::make_shared<folly::SSLContext>();
90   auto serverCtx = std::make_shared<folly::SSLContext>();
91   int fds[2];
92   getfds(fds);
93   getctx(clientCtx, serverCtx);
94   clientSock->reset(new AsyncSSLSocket(
95                       clientCtx, eventBase, fds[0], false));
96   serverSock->reset(new AsyncSSLSocket(
97                       serverCtx, eventBase, fds[1], true));
98
99   // (*clientSock)->setSendTimeout(100);
100   // (*serverSock)->setSendTimeout(100);
101 }
102
103 // client protocol filters
104 bool clientProtoFilterPickPony(unsigned char** client,
105   unsigned int* client_len, const unsigned char*, unsigned int ) {
106   //the protocol string in length prefixed byte string. the
107   //length byte is not included in the length
108   static unsigned char p[7] = {6,'p','o','n','i','e','s'};
109   *client = p;
110   *client_len = 7;
111   return true;
112 }
113
114 bool clientProtoFilterPickNone(unsigned char**, unsigned int*,
115   const unsigned char*, unsigned int) {
116   return false;
117 }
118
119 std::string getFileAsBuf(const char* fileName) {
120   std::string buffer;
121   folly::readFile(fileName, buffer);
122   return buffer;
123 }
124
125 std::string getCommonName(X509* cert) {
126   X509_NAME* subject = X509_get_subject_name(cert);
127   std::string cn;
128   cn.resize(ub_common_name);
129   X509_NAME_get_text_by_NID(
130       subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
131   return cn;
132 }
133
134 /**
135  * Test connecting to, writing to, reading from, and closing the
136  * connection to the SSL server.
137  */
138 TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
139   // Start listening on a local port
140   WriteCallbackBase writeCallback;
141   ReadCallback readCallback(&writeCallback);
142   HandshakeCallback handshakeCallback(&readCallback);
143   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
144   TestSSLServer server(&acceptCallback);
145
146   // Set up SSL context.
147   std::shared_ptr<SSLContext> sslContext(new SSLContext());
148   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
149   //sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
150   //sslContext->authenticate(true, false);
151
152   // connect
153   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
154                                                  sslContext);
155   socket->open(std::chrono::milliseconds(10000));
156
157   // write()
158   uint8_t buf[128];
159   memset(buf, 'a', sizeof(buf));
160   socket->write(buf, sizeof(buf));
161
162   // read()
163   uint8_t readbuf[128];
164   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
165   EXPECT_EQ(bytesRead, 128);
166   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
167
168   // close()
169   socket->close();
170
171   cerr << "ConnectWriteReadClose test completed" << endl;
172   EXPECT_EQ(socket->getSSLSocket()->getTotalConnectTimeout().count(), 10000);
173 }
174
175 /**
176  * Test reading after server close.
177  */
178 TEST(AsyncSSLSocketTest, ReadAfterClose) {
179   // Start listening on a local port
180   WriteCallbackBase writeCallback;
181   ReadEOFCallback readCallback(&writeCallback);
182   HandshakeCallback handshakeCallback(&readCallback);
183   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
184   auto server = std::make_unique<TestSSLServer>(&acceptCallback);
185
186   // Set up SSL context.
187   auto sslContext = std::make_shared<SSLContext>();
188   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
189
190   auto socket =
191       std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
192   socket->open();
193
194   // This should trigger an EOF on the client.
195   auto evb = handshakeCallback.getSocket()->getEventBase();
196   evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
197   std::array<uint8_t, 128> readbuf;
198   auto bytesRead = socket->read(readbuf.data(), readbuf.size());
199   EXPECT_EQ(0, bytesRead);
200 }
201
202 /**
203  * Test bad renegotiation
204  */
205 #if !defined(OPENSSL_IS_BORINGSSL)
206 TEST(AsyncSSLSocketTest, Renegotiate) {
207   EventBase eventBase;
208   auto clientCtx = std::make_shared<SSLContext>();
209   auto dfServerCtx = std::make_shared<SSLContext>();
210   std::array<int, 2> fds;
211   getfds(fds.data());
212   getctx(clientCtx, dfServerCtx);
213
214   AsyncSSLSocket::UniquePtr clientSock(
215       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
216   AsyncSSLSocket::UniquePtr serverSock(
217       new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
218   SSLHandshakeClient client(std::move(clientSock), true, true);
219   RenegotiatingServer server(std::move(serverSock));
220
221   while (!client.handshakeSuccess_ && !client.handshakeError_) {
222     eventBase.loopOnce();
223   }
224
225   ASSERT_TRUE(client.handshakeSuccess_);
226
227   auto sslSock = std::move(client).moveSocket();
228   sslSock->detachEventBase();
229   // This is nasty, however we don't want to add support for
230   // renegotiation in AsyncSSLSocket.
231   SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
232
233   auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
234
235   std::thread t([&]() { eventBase.loopForever(); });
236
237   // Trigger the renegotiation.
238   std::array<uint8_t, 128> buf;
239   memset(buf.data(), 'a', buf.size());
240   try {
241     socket->write(buf.data(), buf.size());
242   } catch (AsyncSocketException& e) {
243     LOG(INFO) << "client got error " << e.what();
244   }
245   eventBase.terminateLoopSoon();
246   t.join();
247
248   eventBase.loop();
249   ASSERT_TRUE(server.renegotiationError_);
250 }
251 #endif
252
253 /**
254  * Negative test for handshakeError().
255  */
256 TEST(AsyncSSLSocketTest, HandshakeError) {
257   // Start listening on a local port
258   WriteCallbackBase writeCallback;
259   WriteErrorCallback readCallback(&writeCallback);
260   HandshakeCallback handshakeCallback(&readCallback);
261   HandshakeErrorCallback acceptCallback(&handshakeCallback);
262   TestSSLServer server(&acceptCallback);
263
264   // Set up SSL context.
265   std::shared_ptr<SSLContext> sslContext(new SSLContext());
266   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
267
268   // connect
269   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
270                                                  sslContext);
271   // read()
272   bool ex = false;
273   try {
274     socket->open();
275
276     uint8_t readbuf[128];
277     uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
278     LOG(ERROR) << "readAll returned " << bytesRead << " instead of throwing";
279   } catch (AsyncSocketException&) {
280     ex = true;
281   }
282   EXPECT_TRUE(ex);
283
284   // close()
285   socket->close();
286   cerr << "HandshakeError test completed" << endl;
287 }
288
289 /**
290  * Negative test for readError().
291  */
292 TEST(AsyncSSLSocketTest, ReadError) {
293   // Start listening on a local port
294   WriteCallbackBase writeCallback;
295   ReadErrorCallback readCallback(&writeCallback);
296   HandshakeCallback handshakeCallback(&readCallback);
297   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
298   TestSSLServer server(&acceptCallback);
299
300   // Set up SSL context.
301   std::shared_ptr<SSLContext> sslContext(new SSLContext());
302   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
303
304   // connect
305   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
306                                                  sslContext);
307   socket->open();
308
309   // write something to trigger ssl handshake
310   uint8_t buf[128];
311   memset(buf, 'a', sizeof(buf));
312   socket->write(buf, sizeof(buf));
313
314   socket->close();
315   cerr << "ReadError test completed" << endl;
316 }
317
318 /**
319  * Negative test for writeError().
320  */
321 TEST(AsyncSSLSocketTest, WriteError) {
322   // Start listening on a local port
323   WriteCallbackBase writeCallback;
324   WriteErrorCallback readCallback(&writeCallback);
325   HandshakeCallback handshakeCallback(&readCallback);
326   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
327   TestSSLServer server(&acceptCallback);
328
329   // Set up SSL context.
330   std::shared_ptr<SSLContext> sslContext(new SSLContext());
331   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
332
333   // connect
334   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
335                                                  sslContext);
336   socket->open();
337
338   // write something to trigger ssl handshake
339   uint8_t buf[128];
340   memset(buf, 'a', sizeof(buf));
341   socket->write(buf, sizeof(buf));
342
343   socket->close();
344   cerr << "WriteError test completed" << endl;
345 }
346
347 /**
348  * Test a socket with TCP_NODELAY unset.
349  */
350 TEST(AsyncSSLSocketTest, SocketWithDelay) {
351   // Start listening on a local port
352   WriteCallbackBase writeCallback;
353   ReadCallback readCallback(&writeCallback);
354   HandshakeCallback handshakeCallback(&readCallback);
355   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
356   TestSSLServer server(&acceptCallback);
357
358   // Set up SSL context.
359   std::shared_ptr<SSLContext> sslContext(new SSLContext());
360   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
361
362   // connect
363   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
364                                                  sslContext);
365   socket->open();
366
367   // write()
368   uint8_t buf[128];
369   memset(buf, 'a', sizeof(buf));
370   socket->write(buf, sizeof(buf));
371
372   // read()
373   uint8_t readbuf[128];
374   uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
375   EXPECT_EQ(bytesRead, 128);
376   EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
377
378   // close()
379   socket->close();
380
381   cerr << "SocketWithDelay test completed" << endl;
382 }
383
384 using NextProtocolTypePair =
385     std::pair<SSLContext::NextProtocolType, SSLContext::NextProtocolType>;
386
387 class NextProtocolTest : public testing::TestWithParam<NextProtocolTypePair> {
388   // For matching protos
389  public:
390   void SetUp() override { getctx(clientCtx, serverCtx); }
391
392   void connect(bool unset = false) {
393     getfds(fds);
394
395     if (unset) {
396       // unsetting NPN for any of [client, server] is enough to make NPN not
397       // work
398       clientCtx->unsetNextProtocols();
399     }
400
401     AsyncSSLSocket::UniquePtr clientSock(
402       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
403     AsyncSSLSocket::UniquePtr serverSock(
404       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
405     client = std::make_unique<NpnClient>(std::move(clientSock));
406     server = std::make_unique<NpnServer>(std::move(serverSock));
407
408     eventBase.loop();
409   }
410
411   void expectProtocol(const std::string& proto) {
412     EXPECT_NE(client->nextProtoLength, 0);
413     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
414     EXPECT_EQ(
415         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
416         0);
417     string selected((const char*)client->nextProto, client->nextProtoLength);
418     EXPECT_EQ(proto, selected);
419   }
420
421   void expectNoProtocol() {
422     EXPECT_EQ(client->nextProtoLength, 0);
423     EXPECT_EQ(server->nextProtoLength, 0);
424     EXPECT_EQ(client->nextProto, nullptr);
425     EXPECT_EQ(server->nextProto, nullptr);
426   }
427
428   void expectProtocolType() {
429     if (GetParam().first == SSLContext::NextProtocolType::ANY &&
430         GetParam().second == SSLContext::NextProtocolType::ANY) {
431       EXPECT_EQ(client->protocolType, server->protocolType);
432     } else if (GetParam().first == SSLContext::NextProtocolType::ANY ||
433                GetParam().second == SSLContext::NextProtocolType::ANY) {
434       // Well not much we can say
435     } else {
436       expectProtocolType(GetParam());
437     }
438   }
439
440   void expectProtocolType(NextProtocolTypePair expected) {
441     EXPECT_EQ(client->protocolType, expected.first);
442     EXPECT_EQ(server->protocolType, expected.second);
443   }
444
445   EventBase eventBase;
446   std::shared_ptr<SSLContext> clientCtx{std::make_shared<SSLContext>()};
447   std::shared_ptr<SSLContext> serverCtx{std::make_shared<SSLContext>()};
448   int fds[2];
449   std::unique_ptr<NpnClient> client;
450   std::unique_ptr<NpnServer> server;
451 };
452
453 class NextProtocolTLSExtTest : public NextProtocolTest {
454   // For extended TLS protos
455 };
456
457 class NextProtocolNPNOnlyTest : public NextProtocolTest {
458   // For mismatching protos
459 };
460
461 class NextProtocolMismatchTest : public NextProtocolTest {
462   // For mismatching protos
463 };
464
465 TEST_P(NextProtocolTest, NpnTestOverlap) {
466   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
467   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
468                                         GetParam().second);
469
470   connect();
471
472   expectProtocol("baz");
473   expectProtocolType();
474 }
475
476 TEST_P(NextProtocolTest, NpnTestUnset) {
477   // Identical to above test, except that we want unset NPN before
478   // looping.
479   clientCtx->setAdvertisedNextProtocols({"blub", "baz"}, GetParam().first);
480   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
481                                         GetParam().second);
482
483   connect(true /* unset */);
484
485   // if alpn negotiation fails, type will appear as npn
486   expectNoProtocol();
487   EXPECT_EQ(client->protocolType, server->protocolType);
488 }
489
490 TEST_P(NextProtocolMismatchTest, NpnAlpnTestNoOverlap) {
491   clientCtx->setAdvertisedNextProtocols({"foo"}, GetParam().first);
492   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
493                                         GetParam().second);
494
495   connect();
496
497   expectNoProtocol();
498   expectProtocolType(
499       {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
500 }
501
502 // Note: the behavior changed in the ANY/ANY case in OpenSSL 1.0.2h, this test
503 // will fail on 1.0.2 before that.
504 TEST_P(NextProtocolTest, NpnTestNoOverlap) {
505   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
506   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
507                                         GetParam().second);
508
509   connect();
510
511   if (GetParam().first == SSLContext::NextProtocolType::ALPN ||
512       GetParam().second == SSLContext::NextProtocolType::ALPN) {
513     // This is arguably incorrect behavior since RFC7301 states an ALPN protocol
514     // mismatch should result in a fatal alert, but this is OpenSSL's current
515     // behavior and we want to know if it changes.
516     expectNoProtocol();
517   }
518 #if defined(OPENSSL_IS_BORINGSSL)
519   // BoringSSL also doesn't fatal on mismatch but behaves slightly differently
520   // from OpenSSL 1.0.2h+ - it doesn't select a protocol if both ends support
521   // NPN *and* ALPN
522   else if (
523       GetParam().first == SSLContext::NextProtocolType::ANY &&
524       GetParam().second == SSLContext::NextProtocolType::ANY) {
525     expectNoProtocol();
526   }
527 #endif
528   else {
529     expectProtocol("blub");
530     expectProtocolType(
531         {SSLContext::NextProtocolType::NPN, SSLContext::NextProtocolType::NPN});
532   }
533 }
534
535 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterHit) {
536   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
537   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickPony);
538   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
539                                         GetParam().second);
540
541   connect();
542
543   expectProtocol("ponies");
544   expectProtocolType();
545 }
546
547 TEST_P(NextProtocolNPNOnlyTest, NpnTestClientProtoFilterMiss) {
548   clientCtx->setAdvertisedNextProtocols({"blub"}, GetParam().first);
549   clientCtx->setClientProtocolFilterCallback(clientProtoFilterPickNone);
550   serverCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
551                                         GetParam().second);
552
553   connect();
554
555   expectProtocol("blub");
556   expectProtocolType();
557 }
558
559 TEST_P(NextProtocolTest, RandomizedNpnTest) {
560   // Probability that this test will fail is 2^-64, which could be considered
561   // as negligible.
562   const int kTries = 64;
563
564   clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"},
565                                         GetParam().first);
566   serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}}, {1, {"bar"}}},
567                                                   GetParam().second);
568
569   std::set<string> selectedProtocols;
570   for (int i = 0; i < kTries; ++i) {
571     connect();
572
573     EXPECT_NE(client->nextProtoLength, 0);
574     EXPECT_EQ(client->nextProtoLength, server->nextProtoLength);
575     EXPECT_EQ(
576         memcmp(client->nextProto, server->nextProto, server->nextProtoLength),
577         0);
578     string selected((const char*)client->nextProto, client->nextProtoLength);
579     selectedProtocols.insert(selected);
580     expectProtocolType();
581   }
582   EXPECT_EQ(selectedProtocols.size(), 2);
583 }
584
585 INSTANTIATE_TEST_CASE_P(
586     AsyncSSLSocketTest,
587     NextProtocolTest,
588     ::testing::Values(
589         NextProtocolTypePair(
590             SSLContext::NextProtocolType::NPN,
591             SSLContext::NextProtocolType::NPN),
592         NextProtocolTypePair(
593             SSLContext::NextProtocolType::NPN,
594             SSLContext::NextProtocolType::ANY),
595         NextProtocolTypePair(
596             SSLContext::NextProtocolType::ANY,
597             SSLContext::NextProtocolType::ANY)));
598
599 #if FOLLY_OPENSSL_HAS_ALPN
600 INSTANTIATE_TEST_CASE_P(
601     AsyncSSLSocketTest,
602     NextProtocolTLSExtTest,
603     ::testing::Values(
604         NextProtocolTypePair(
605             SSLContext::NextProtocolType::ALPN,
606             SSLContext::NextProtocolType::ALPN),
607         NextProtocolTypePair(
608             SSLContext::NextProtocolType::ALPN,
609             SSLContext::NextProtocolType::ANY),
610         NextProtocolTypePair(
611             SSLContext::NextProtocolType::ANY,
612             SSLContext::NextProtocolType::ALPN)));
613 #endif
614
615 INSTANTIATE_TEST_CASE_P(
616     AsyncSSLSocketTest,
617     NextProtocolNPNOnlyTest,
618     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
619                                            SSLContext::NextProtocolType::NPN)));
620
621 #if FOLLY_OPENSSL_HAS_ALPN
622 INSTANTIATE_TEST_CASE_P(
623     AsyncSSLSocketTest,
624     NextProtocolMismatchTest,
625     ::testing::Values(NextProtocolTypePair(SSLContext::NextProtocolType::NPN,
626                                            SSLContext::NextProtocolType::ALPN),
627                       NextProtocolTypePair(SSLContext::NextProtocolType::ALPN,
628                                            SSLContext::NextProtocolType::NPN)));
629 #endif
630
631 #ifndef OPENSSL_NO_TLSEXT
632 /**
633  * 1. Client sends TLSEXT_HOSTNAME in client hello.
634  * 2. Server found a match SSL_CTX and use this SSL_CTX to
635  *    continue the SSL handshake.
636  * 3. Server sends back TLSEXT_HOSTNAME in server hello.
637  */
638 TEST(AsyncSSLSocketTest, SNITestMatch) {
639   EventBase eventBase;
640   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
641   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
642   // Use the same SSLContext to continue the handshake after
643   // tlsext_hostname match.
644   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
645   const std::string serverName("xyz.newdev.facebook.com");
646   int fds[2];
647   getfds(fds);
648   getctx(clientCtx, dfServerCtx);
649
650   AsyncSSLSocket::UniquePtr clientSock(
651     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
652   AsyncSSLSocket::UniquePtr serverSock(
653     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
654   SNIClient client(std::move(clientSock));
655   SNIServer server(std::move(serverSock),
656                    dfServerCtx,
657                    hskServerCtx,
658                    serverName);
659
660   eventBase.loop();
661
662   EXPECT_TRUE(client.serverNameMatch);
663   EXPECT_TRUE(server.serverNameMatch);
664 }
665
666 /**
667  * 1. Client sends TLSEXT_HOSTNAME in client hello.
668  * 2. Server cannot find a matching SSL_CTX and continue to use
669  *    the current SSL_CTX to do the handshake.
670  * 3. Server does not send back TLSEXT_HOSTNAME in server hello.
671  */
672 TEST(AsyncSSLSocketTest, SNITestNotMatch) {
673   EventBase eventBase;
674   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
675   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
676   // Use the same SSLContext to continue the handshake after
677   // tlsext_hostname match.
678   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
679   const std::string clientRequestingServerName("foo.com");
680   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
681
682   int fds[2];
683   getfds(fds);
684   getctx(clientCtx, dfServerCtx);
685
686   AsyncSSLSocket::UniquePtr clientSock(
687     new AsyncSSLSocket(clientCtx,
688                         &eventBase,
689                         fds[0],
690                         clientRequestingServerName));
691   AsyncSSLSocket::UniquePtr serverSock(
692     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
693   SNIClient client(std::move(clientSock));
694   SNIServer server(std::move(serverSock),
695                    dfServerCtx,
696                    hskServerCtx,
697                    serverExpectedServerName);
698
699   eventBase.loop();
700
701   EXPECT_TRUE(!client.serverNameMatch);
702   EXPECT_TRUE(!server.serverNameMatch);
703 }
704 /**
705  * 1. Client sends TLSEXT_HOSTNAME in client hello.
706  * 2. We then change the serverName.
707  * 3. We expect that we get 'false' as the result for serNameMatch.
708  */
709
710 TEST(AsyncSSLSocketTest, SNITestChangeServerName) {
711    EventBase eventBase;
712   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
713   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
714   // Use the same SSLContext to continue the handshake after
715   // tlsext_hostname match.
716   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
717   const std::string serverName("xyz.newdev.facebook.com");
718   int fds[2];
719   getfds(fds);
720   getctx(clientCtx, dfServerCtx);
721
722   AsyncSSLSocket::UniquePtr clientSock(
723     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
724   //Change the server name
725   std::string newName("new.com");
726   clientSock->setServerName(newName);
727   AsyncSSLSocket::UniquePtr serverSock(
728     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
729   SNIClient client(std::move(clientSock));
730   SNIServer server(std::move(serverSock),
731                    dfServerCtx,
732                    hskServerCtx,
733                    serverName);
734
735   eventBase.loop();
736
737   EXPECT_TRUE(!client.serverNameMatch);
738 }
739
740 /**
741  * 1. Client does not send TLSEXT_HOSTNAME in client hello.
742  * 2. Server does not send back TLSEXT_HOSTNAME in server hello.
743  */
744 TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
745   EventBase eventBase;
746   std::shared_ptr<SSLContext> clientCtx(new SSLContext);
747   std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
748   // Use the same SSLContext to continue the handshake after
749   // tlsext_hostname match.
750   std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
751   const std::string serverExpectedServerName("xyz.newdev.facebook.com");
752
753   int fds[2];
754   getfds(fds);
755   getctx(clientCtx, dfServerCtx);
756
757   AsyncSSLSocket::UniquePtr clientSock(
758     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
759   AsyncSSLSocket::UniquePtr serverSock(
760     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
761   SNIClient client(std::move(clientSock));
762   SNIServer server(std::move(serverSock),
763                    dfServerCtx,
764                    hskServerCtx,
765                    serverExpectedServerName);
766
767   eventBase.loop();
768
769   EXPECT_TRUE(!client.serverNameMatch);
770   EXPECT_TRUE(!server.serverNameMatch);
771 }
772
773 #endif
774 /**
775  * Test SSL client socket
776  */
777 TEST(AsyncSSLSocketTest, SSLClientTest) {
778   // Start listening on a local port
779   WriteCallbackBase writeCallback;
780   ReadCallback readCallback(&writeCallback);
781   HandshakeCallback handshakeCallback(&readCallback);
782   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
783   TestSSLServer server(&acceptCallback);
784
785   // Set up SSL client
786   EventBase eventBase;
787   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1);
788
789   client->connect();
790   EventBaseAborter eba(&eventBase, 3000);
791   eventBase.loop();
792
793   EXPECT_EQ(client->getMiss(), 1);
794   EXPECT_EQ(client->getHit(), 0);
795
796   cerr << "SSLClientTest test completed" << endl;
797 }
798
799
800 /**
801  * Test SSL client socket session re-use
802  */
803 TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
804   // Start listening on a local port
805   WriteCallbackBase writeCallback;
806   ReadCallback readCallback(&writeCallback);
807   HandshakeCallback handshakeCallback(&readCallback);
808   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
809   TestSSLServer server(&acceptCallback);
810
811   // Set up SSL client
812   EventBase eventBase;
813   auto client =
814       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10);
815
816   client->connect();
817   EventBaseAborter eba(&eventBase, 3000);
818   eventBase.loop();
819
820   EXPECT_EQ(client->getMiss(), 1);
821   EXPECT_EQ(client->getHit(), 9);
822
823   cerr << "SSLClientTestReuse test completed" << endl;
824 }
825
826 /**
827  * Test SSL client socket timeout
828  */
829 TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
830   // Start listening on a local port
831   EmptyReadCallback readCallback;
832   HandshakeCallback handshakeCallback(&readCallback,
833                                       HandshakeCallback::EXPECT_ERROR);
834   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
835   TestSSLServer server(&acceptCallback);
836
837   // Set up SSL client
838   EventBase eventBase;
839   auto client =
840       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 1, 10);
841   client->connect(true /* write before connect completes */);
842   EventBaseAborter eba(&eventBase, 3000);
843   eventBase.loop();
844
845   usleep(100000);
846   // This is checking that the connectError callback precedes any queued
847   // writeError callbacks.  This matches AsyncSocket's behavior
848   EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
849   EXPECT_EQ(client->getErrors(), 1);
850   EXPECT_EQ(client->getMiss(), 0);
851   EXPECT_EQ(client->getHit(), 0);
852
853   cerr << "SSLClientTimeoutTest test completed" << endl;
854 }
855
856 // The next 3 tests need an FB-only extension, and will fail without it
857 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
858 /**
859  * Test SSL server async cache
860  */
861 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
862   // Start listening on a local port
863   WriteCallbackBase writeCallback;
864   ReadCallback readCallback(&writeCallback);
865   HandshakeCallback handshakeCallback(&readCallback);
866   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
867   TestSSLAsyncCacheServer server(&acceptCallback);
868
869   // Set up SSL client
870   EventBase eventBase;
871   auto client =
872       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 10, 500);
873
874   client->connect();
875   EventBaseAborter eba(&eventBase, 3000);
876   eventBase.loop();
877
878   EXPECT_EQ(server.getAsyncCallbacks(), 18);
879   EXPECT_EQ(server.getAsyncLookups(), 9);
880   EXPECT_EQ(client->getMiss(), 10);
881   EXPECT_EQ(client->getHit(), 0);
882
883   cerr << "SSLServerAsyncCacheTest test completed" << endl;
884 }
885
886 /**
887  * Test SSL server accept timeout with cache path
888  */
889 TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
890   // Start listening on a local port
891   WriteCallbackBase writeCallback;
892   ReadCallback readCallback(&writeCallback);
893   HandshakeCallback handshakeCallback(&readCallback);
894   SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
895   TestSSLAsyncCacheServer server(&acceptCallback);
896
897   // Set up SSL client
898   EventBase eventBase;
899   // only do a TCP connect
900   std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
901   sock->connect(nullptr, server.getAddress());
902
903   EmptyReadCallback clientReadCallback;
904   clientReadCallback.tcpSocket_ = sock;
905   sock->setReadCB(&clientReadCallback);
906
907   EventBaseAborter eba(&eventBase, 3000);
908   eventBase.loop();
909
910   EXPECT_EQ(readCallback.state, STATE_WAITING);
911
912   cerr << "SSLServerTimeoutTest test completed" << endl;
913 }
914
915 /**
916  * Test SSL server accept timeout with cache path
917  */
918 TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
919   // Start listening on a local port
920   WriteCallbackBase writeCallback;
921   ReadCallback readCallback(&writeCallback);
922   HandshakeCallback handshakeCallback(&readCallback);
923   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
924   TestSSLAsyncCacheServer server(&acceptCallback);
925
926   // Set up SSL client
927   EventBase eventBase;
928   auto client = std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2);
929
930   client->connect();
931   EventBaseAborter eba(&eventBase, 3000);
932   eventBase.loop();
933
934   EXPECT_EQ(server.getAsyncCallbacks(), 1);
935   EXPECT_EQ(server.getAsyncLookups(), 1);
936   EXPECT_EQ(client->getErrors(), 1);
937   EXPECT_EQ(client->getMiss(), 1);
938   EXPECT_EQ(client->getHit(), 0);
939
940   cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
941 }
942
943 /**
944  * Test SSL server accept timeout with cache path
945  */
946 TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
947   // Start listening on a local port
948   WriteCallbackBase writeCallback;
949   ReadCallback readCallback(&writeCallback);
950   HandshakeCallback handshakeCallback(&readCallback,
951                                       HandshakeCallback::EXPECT_ERROR);
952   SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
953   TestSSLAsyncCacheServer server(&acceptCallback, 500);
954
955   // Set up SSL client
956   EventBase eventBase;
957   auto client =
958       std::make_shared<SSLClient>(&eventBase, server.getAddress(), 2, 100);
959
960   client->connect();
961   EventBaseAborter eba(&eventBase, 3000);
962   eventBase.loop();
963
964   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
965       handshakeCallback.closeSocket();});
966   // give time for the cache lookup to come back and find it closed
967   handshakeCallback.waitForHandshake();
968
969   EXPECT_EQ(server.getAsyncCallbacks(), 1);
970   EXPECT_EQ(server.getAsyncLookups(), 1);
971   EXPECT_EQ(client->getErrors(), 1);
972   EXPECT_EQ(client->getMiss(), 1);
973   EXPECT_EQ(client->getHit(), 0);
974
975   cerr << "SSLServerCacheCloseTest test completed" << endl;
976 }
977 #endif // !SSL_ERROR_WANT_SESS_CACHE_LOOKUP
978
979 /**
980  * Verify Client Ciphers obtained using SSL MSG Callback.
981  */
982 TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
983   EventBase eventBase;
984   auto clientCtx = std::make_shared<SSLContext>();
985   auto serverCtx = std::make_shared<SSLContext>();
986   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
987   serverCtx->ciphers("ECDHE-RSA-AES128-SHA:AES128-SHA:AES256-SHA");
988   serverCtx->loadPrivateKey(kTestKey);
989   serverCtx->loadCertificate(kTestCert);
990   serverCtx->loadTrustedCertificates(kTestCA);
991   serverCtx->loadClientCAList(kTestCA);
992
993   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
994   clientCtx->ciphers("AES256-SHA:AES128-SHA");
995   clientCtx->loadPrivateKey(kTestKey);
996   clientCtx->loadCertificate(kTestCert);
997   clientCtx->loadTrustedCertificates(kTestCA);
998
999   int fds[2];
1000   getfds(fds);
1001
1002   AsyncSSLSocket::UniquePtr clientSock(
1003       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1004   AsyncSSLSocket::UniquePtr serverSock(
1005       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1006
1007   SSLHandshakeClient client(std::move(clientSock), true, true);
1008   SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
1009
1010   eventBase.loop();
1011
1012 #if defined(OPENSSL_IS_BORINGSSL)
1013   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA");
1014 #else
1015   EXPECT_EQ(server.clientCiphers_, "AES256-SHA:AES128-SHA:00ff");
1016 #endif
1017   EXPECT_EQ(server.chosenCipher_, "AES256-SHA");
1018   EXPECT_TRUE(client.handshakeVerify_);
1019   EXPECT_TRUE(client.handshakeSuccess_);
1020   EXPECT_TRUE(!client.handshakeError_);
1021   EXPECT_TRUE(server.handshakeVerify_);
1022   EXPECT_TRUE(server.handshakeSuccess_);
1023   EXPECT_TRUE(!server.handshakeError_);
1024 }
1025
1026 TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
1027   EventBase eventBase;
1028   auto ctx = std::make_shared<SSLContext>();
1029
1030   int fds[2];
1031   getfds(fds);
1032
1033   int bufLen = 42;
1034   uint8_t majorVersion = 18;
1035   uint8_t minorVersion = 25;
1036
1037   // Create callback buf
1038   auto buf = IOBuf::create(bufLen);
1039   buf->append(bufLen);
1040   folly::io::RWPrivateCursor cursor(buf.get());
1041   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1042   cursor.write<uint16_t>(0);
1043   cursor.write<uint8_t>(38);
1044   cursor.write<uint8_t>(majorVersion);
1045   cursor.write<uint8_t>(minorVersion);
1046   cursor.skip(32);
1047   cursor.write<uint32_t>(0);
1048
1049   SSL* ssl = ctx->createSSL();
1050   SCOPE_EXIT { SSL_free(ssl); };
1051   AsyncSSLSocket::UniquePtr sock(
1052       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1053   sock->enableClientHelloParsing();
1054
1055   // Test client hello parsing in one packet
1056   AsyncSSLSocket::clientHelloParsingCallback(
1057       0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
1058   buf.reset();
1059
1060   auto parsedClientHello = sock->getClientHelloInfo();
1061   EXPECT_TRUE(parsedClientHello != nullptr);
1062   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1063   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1064 }
1065
1066 TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
1067   EventBase eventBase;
1068   auto ctx = std::make_shared<SSLContext>();
1069
1070   int fds[2];
1071   getfds(fds);
1072
1073   int bufLen = 42;
1074   uint8_t majorVersion = 18;
1075   uint8_t minorVersion = 25;
1076
1077   // Create callback buf
1078   auto buf = IOBuf::create(bufLen);
1079   buf->append(bufLen);
1080   folly::io::RWPrivateCursor cursor(buf.get());
1081   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1082   cursor.write<uint16_t>(0);
1083   cursor.write<uint8_t>(38);
1084   cursor.write<uint8_t>(majorVersion);
1085   cursor.write<uint8_t>(minorVersion);
1086   cursor.skip(32);
1087   cursor.write<uint32_t>(0);
1088
1089   SSL* ssl = ctx->createSSL();
1090   SCOPE_EXIT { SSL_free(ssl); };
1091   AsyncSSLSocket::UniquePtr sock(
1092       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1093   sock->enableClientHelloParsing();
1094
1095   // Test parsing with two packets with first packet size < 3
1096   auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
1097   AsyncSSLSocket::clientHelloParsingCallback(
1098       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1099       ssl, sock.get());
1100   bufCopy.reset();
1101   bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
1102   AsyncSSLSocket::clientHelloParsingCallback(
1103       0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1104       ssl, sock.get());
1105   bufCopy.reset();
1106
1107   auto parsedClientHello = sock->getClientHelloInfo();
1108   EXPECT_TRUE(parsedClientHello != nullptr);
1109   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1110   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1111 }
1112
1113 TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
1114   EventBase eventBase;
1115   auto ctx = std::make_shared<SSLContext>();
1116
1117   int fds[2];
1118   getfds(fds);
1119
1120   int bufLen = 42;
1121   uint8_t majorVersion = 18;
1122   uint8_t minorVersion = 25;
1123
1124   // Create callback buf
1125   auto buf = IOBuf::create(bufLen);
1126   buf->append(bufLen);
1127   folly::io::RWPrivateCursor cursor(buf.get());
1128   cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
1129   cursor.write<uint16_t>(0);
1130   cursor.write<uint8_t>(38);
1131   cursor.write<uint8_t>(majorVersion);
1132   cursor.write<uint8_t>(minorVersion);
1133   cursor.skip(32);
1134   cursor.write<uint32_t>(0);
1135
1136   SSL* ssl = ctx->createSSL();
1137   SCOPE_EXIT { SSL_free(ssl); };
1138   AsyncSSLSocket::UniquePtr sock(
1139       new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
1140   sock->enableClientHelloParsing();
1141
1142   // Test parsing with multiple small packets
1143   for (uint64_t i = 0; i < buf->length(); i += 3) {
1144     auto bufCopy = folly::IOBuf::copyBuffer(
1145         buf->data() + i, std::min((uint64_t)3, buf->length() - i));
1146     AsyncSSLSocket::clientHelloParsingCallback(
1147         0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
1148         ssl, sock.get());
1149     bufCopy.reset();
1150   }
1151
1152   auto parsedClientHello = sock->getClientHelloInfo();
1153   EXPECT_TRUE(parsedClientHello != nullptr);
1154   EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
1155   EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
1156 }
1157
1158 /**
1159  * Verify sucessful behavior of SSL certificate validation.
1160  */
1161 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
1162   EventBase eventBase;
1163   auto clientCtx = std::make_shared<SSLContext>();
1164   auto dfServerCtx = std::make_shared<SSLContext>();
1165
1166   int fds[2];
1167   getfds(fds);
1168   getctx(clientCtx, dfServerCtx);
1169
1170   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1171   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1172
1173   AsyncSSLSocket::UniquePtr clientSock(
1174     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1175   AsyncSSLSocket::UniquePtr serverSock(
1176     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1177
1178   SSLHandshakeClient client(std::move(clientSock), true, true);
1179   clientCtx->loadTrustedCertificates(kTestCA);
1180
1181   SSLHandshakeServer server(std::move(serverSock), true, true);
1182
1183   eventBase.loop();
1184
1185   EXPECT_TRUE(client.handshakeVerify_);
1186   EXPECT_TRUE(client.handshakeSuccess_);
1187   EXPECT_TRUE(!client.handshakeError_);
1188   EXPECT_LE(0, client.handshakeTime.count());
1189   EXPECT_TRUE(!server.handshakeVerify_);
1190   EXPECT_TRUE(server.handshakeSuccess_);
1191   EXPECT_TRUE(!server.handshakeError_);
1192   EXPECT_LE(0, server.handshakeTime.count());
1193 }
1194
1195 /**
1196  * Verify that the client's verification callback is able to fail SSL
1197  * connection establishment.
1198  */
1199 TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
1200   EventBase eventBase;
1201   auto clientCtx = std::make_shared<SSLContext>();
1202   auto dfServerCtx = std::make_shared<SSLContext>();
1203
1204   int fds[2];
1205   getfds(fds);
1206   getctx(clientCtx, dfServerCtx);
1207
1208   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1209   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1210
1211   AsyncSSLSocket::UniquePtr clientSock(
1212     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1213   AsyncSSLSocket::UniquePtr serverSock(
1214     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1215
1216   SSLHandshakeClient client(std::move(clientSock), true, false);
1217   clientCtx->loadTrustedCertificates(kTestCA);
1218
1219   SSLHandshakeServer server(std::move(serverSock), true, true);
1220
1221   eventBase.loop();
1222
1223   EXPECT_TRUE(client.handshakeVerify_);
1224   EXPECT_TRUE(!client.handshakeSuccess_);
1225   EXPECT_TRUE(client.handshakeError_);
1226   EXPECT_LE(0, client.handshakeTime.count());
1227   EXPECT_TRUE(!server.handshakeVerify_);
1228   EXPECT_TRUE(!server.handshakeSuccess_);
1229   EXPECT_TRUE(server.handshakeError_);
1230   EXPECT_LE(0, server.handshakeTime.count());
1231 }
1232
1233 /**
1234  * Verify that the options in SSLContext can be overridden in
1235  * sslConnect/Accept.i.e specifying that no validation should be performed
1236  * allows an otherwise-invalid certificate to be accepted and doesn't fire
1237  * the validation callback.
1238  */
1239 TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
1240   EventBase eventBase;
1241   auto clientCtx = std::make_shared<SSLContext>();
1242   auto dfServerCtx = std::make_shared<SSLContext>();
1243
1244   int fds[2];
1245   getfds(fds);
1246   getctx(clientCtx, dfServerCtx);
1247
1248   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1249   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1250
1251   AsyncSSLSocket::UniquePtr clientSock(
1252     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1253   AsyncSSLSocket::UniquePtr serverSock(
1254     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1255
1256   SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
1257   clientCtx->loadTrustedCertificates(kTestCA);
1258
1259   SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
1260
1261   eventBase.loop();
1262
1263   EXPECT_TRUE(!client.handshakeVerify_);
1264   EXPECT_TRUE(client.handshakeSuccess_);
1265   EXPECT_TRUE(!client.handshakeError_);
1266   EXPECT_LE(0, client.handshakeTime.count());
1267   EXPECT_TRUE(!server.handshakeVerify_);
1268   EXPECT_TRUE(server.handshakeSuccess_);
1269   EXPECT_TRUE(!server.handshakeError_);
1270   EXPECT_LE(0, server.handshakeTime.count());
1271 }
1272
1273 /**
1274  * Verify that the options in SSLContext can be overridden in
1275  * sslConnect/Accept. Enable verification even if context says otherwise.
1276  * Test requireClientCert with client cert
1277  */
1278 TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
1279   EventBase eventBase;
1280   auto clientCtx = std::make_shared<SSLContext>();
1281   auto serverCtx = std::make_shared<SSLContext>();
1282   serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1283   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1284   serverCtx->loadPrivateKey(kTestKey);
1285   serverCtx->loadCertificate(kTestCert);
1286   serverCtx->loadTrustedCertificates(kTestCA);
1287   serverCtx->loadClientCAList(kTestCA);
1288
1289   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1290   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1291   clientCtx->loadPrivateKey(kTestKey);
1292   clientCtx->loadCertificate(kTestCert);
1293   clientCtx->loadTrustedCertificates(kTestCA);
1294
1295   int fds[2];
1296   getfds(fds);
1297
1298   AsyncSSLSocket::UniquePtr clientSock(
1299       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1300   AsyncSSLSocket::UniquePtr serverSock(
1301       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1302
1303   SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
1304   SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
1305
1306   eventBase.loop();
1307
1308   EXPECT_TRUE(client.handshakeVerify_);
1309   EXPECT_TRUE(client.handshakeSuccess_);
1310   EXPECT_FALSE(client.handshakeError_);
1311   EXPECT_LE(0, client.handshakeTime.count());
1312   EXPECT_TRUE(server.handshakeVerify_);
1313   EXPECT_TRUE(server.handshakeSuccess_);
1314   EXPECT_FALSE(server.handshakeError_);
1315   EXPECT_LE(0, server.handshakeTime.count());
1316 }
1317
1318 /**
1319  * Verify that the client's verification callback is able to override
1320  * the preverification failure and allow a successful connection.
1321  */
1322 TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
1323   EventBase eventBase;
1324   auto clientCtx = std::make_shared<SSLContext>();
1325   auto dfServerCtx = std::make_shared<SSLContext>();
1326
1327   int fds[2];
1328   getfds(fds);
1329   getctx(clientCtx, dfServerCtx);
1330
1331   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1332   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1333
1334   AsyncSSLSocket::UniquePtr clientSock(
1335     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1336   AsyncSSLSocket::UniquePtr serverSock(
1337     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1338
1339   SSLHandshakeClient client(std::move(clientSock), false, true);
1340   SSLHandshakeServer server(std::move(serverSock), true, true);
1341
1342   eventBase.loop();
1343
1344   EXPECT_TRUE(client.handshakeVerify_);
1345   EXPECT_TRUE(client.handshakeSuccess_);
1346   EXPECT_TRUE(!client.handshakeError_);
1347   EXPECT_LE(0, client.handshakeTime.count());
1348   EXPECT_TRUE(!server.handshakeVerify_);
1349   EXPECT_TRUE(server.handshakeSuccess_);
1350   EXPECT_TRUE(!server.handshakeError_);
1351   EXPECT_LE(0, server.handshakeTime.count());
1352 }
1353
1354 /**
1355  * Verify that specifying that no validation should be performed allows an
1356  * otherwise-invalid certificate to be accepted and doesn't fire the validation
1357  * callback.
1358  */
1359 TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
1360   EventBase eventBase;
1361   auto clientCtx = std::make_shared<SSLContext>();
1362   auto dfServerCtx = std::make_shared<SSLContext>();
1363
1364   int fds[2];
1365   getfds(fds);
1366   getctx(clientCtx, dfServerCtx);
1367
1368   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1369   dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1370
1371   AsyncSSLSocket::UniquePtr clientSock(
1372     new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1373   AsyncSSLSocket::UniquePtr serverSock(
1374     new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
1375
1376   SSLHandshakeClient client(std::move(clientSock), false, false);
1377   SSLHandshakeServer server(std::move(serverSock), false, false);
1378
1379   eventBase.loop();
1380
1381   EXPECT_TRUE(!client.handshakeVerify_);
1382   EXPECT_TRUE(client.handshakeSuccess_);
1383   EXPECT_TRUE(!client.handshakeError_);
1384   EXPECT_LE(0, client.handshakeTime.count());
1385   EXPECT_TRUE(!server.handshakeVerify_);
1386   EXPECT_TRUE(server.handshakeSuccess_);
1387   EXPECT_TRUE(!server.handshakeError_);
1388   EXPECT_LE(0, server.handshakeTime.count());
1389 }
1390
1391 /**
1392  * Test requireClientCert with client cert
1393  */
1394 TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
1395   EventBase eventBase;
1396   auto clientCtx = std::make_shared<SSLContext>();
1397   auto serverCtx = std::make_shared<SSLContext>();
1398   serverCtx->setVerificationOption(
1399       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1400   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1401   serverCtx->loadPrivateKey(kTestKey);
1402   serverCtx->loadCertificate(kTestCert);
1403   serverCtx->loadTrustedCertificates(kTestCA);
1404   serverCtx->loadClientCAList(kTestCA);
1405
1406   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
1407   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1408   clientCtx->loadPrivateKey(kTestKey);
1409   clientCtx->loadCertificate(kTestCert);
1410   clientCtx->loadTrustedCertificates(kTestCA);
1411
1412   int fds[2];
1413   getfds(fds);
1414
1415   AsyncSSLSocket::UniquePtr clientSock(
1416       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1417   AsyncSSLSocket::UniquePtr serverSock(
1418       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1419
1420   SSLHandshakeClient client(std::move(clientSock), true, true);
1421   SSLHandshakeServer server(std::move(serverSock), true, true);
1422
1423   eventBase.loop();
1424
1425   EXPECT_TRUE(client.handshakeVerify_);
1426   EXPECT_TRUE(client.handshakeSuccess_);
1427   EXPECT_FALSE(client.handshakeError_);
1428   EXPECT_LE(0, client.handshakeTime.count());
1429   EXPECT_TRUE(server.handshakeVerify_);
1430   EXPECT_TRUE(server.handshakeSuccess_);
1431   EXPECT_FALSE(server.handshakeError_);
1432   EXPECT_LE(0, server.handshakeTime.count());
1433 }
1434
1435
1436 /**
1437  * Test requireClientCert with no client cert
1438  */
1439 TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
1440   EventBase eventBase;
1441   auto clientCtx = std::make_shared<SSLContext>();
1442   auto serverCtx = std::make_shared<SSLContext>();
1443   serverCtx->setVerificationOption(
1444       SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
1445   serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1446   serverCtx->loadPrivateKey(kTestKey);
1447   serverCtx->loadCertificate(kTestCert);
1448   serverCtx->loadTrustedCertificates(kTestCA);
1449   serverCtx->loadClientCAList(kTestCA);
1450   clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
1451   clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1452
1453   int fds[2];
1454   getfds(fds);
1455
1456   AsyncSSLSocket::UniquePtr clientSock(
1457       new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
1458   AsyncSSLSocket::UniquePtr serverSock(
1459       new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
1460
1461   SSLHandshakeClient client(std::move(clientSock), false, false);
1462   SSLHandshakeServer server(std::move(serverSock), false, false);
1463
1464   eventBase.loop();
1465
1466   EXPECT_FALSE(server.handshakeVerify_);
1467   EXPECT_FALSE(server.handshakeSuccess_);
1468   EXPECT_TRUE(server.handshakeError_);
1469   EXPECT_LE(0, client.handshakeTime.count());
1470   EXPECT_LE(0, server.handshakeTime.count());
1471 }
1472
1473 TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
1474   auto cert = getFileAsBuf(kTestCert);
1475   auto key = getFileAsBuf(kTestKey);
1476
1477   ssl::BioUniquePtr certBio(BIO_new(BIO_s_mem()));
1478   BIO_write(certBio.get(), cert.data(), cert.size());
1479   ssl::BioUniquePtr keyBio(BIO_new(BIO_s_mem()));
1480   BIO_write(keyBio.get(), key.data(), key.size());
1481
1482   // Create SSL structs from buffers to get properties
1483   ssl::X509UniquePtr certStruct(
1484       PEM_read_bio_X509(certBio.get(), nullptr, nullptr, nullptr));
1485   ssl::EvpPkeyUniquePtr keyStruct(
1486       PEM_read_bio_PrivateKey(keyBio.get(), nullptr, nullptr, nullptr));
1487   certBio = nullptr;
1488   keyBio = nullptr;
1489
1490   auto origCommonName = getCommonName(certStruct.get());
1491   auto origKeySize = EVP_PKEY_bits(keyStruct.get());
1492   certStruct = nullptr;
1493   keyStruct = nullptr;
1494
1495   auto ctx = std::make_shared<SSLContext>();
1496   ctx->loadPrivateKeyFromBufferPEM(key);
1497   ctx->loadCertificateFromBufferPEM(cert);
1498   ctx->loadTrustedCertificates(kTestCA);
1499
1500   ssl::SSLUniquePtr ssl(ctx->createSSL());
1501
1502   auto newCert = SSL_get_certificate(ssl.get());
1503   auto newKey = SSL_get_privatekey(ssl.get());
1504
1505   // Get properties from SSL struct
1506   auto newCommonName = getCommonName(newCert);
1507   auto newKeySize = EVP_PKEY_bits(newKey);
1508
1509   // Check that the key and cert have the expected properties
1510   EXPECT_EQ(origCommonName, newCommonName);
1511   EXPECT_EQ(origKeySize, newKeySize);
1512 }
1513
1514 TEST(AsyncSSLSocketTest, MinWriteSizeTest) {
1515   EventBase eb;
1516
1517   // Set up SSL context.
1518   auto sslContext = std::make_shared<SSLContext>();
1519   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
1520
1521   // create SSL socket
1522   AsyncSSLSocket::UniquePtr socket(new AsyncSSLSocket(sslContext, &eb));
1523
1524   EXPECT_EQ(1500, socket->getMinWriteSize());
1525
1526   socket->setMinWriteSize(0);
1527   EXPECT_EQ(0, socket->getMinWriteSize());
1528   socket->setMinWriteSize(50000);
1529   EXPECT_EQ(50000, socket->getMinWriteSize());
1530 }
1531
1532 class ReadCallbackTerminator : public ReadCallback {
1533  public:
1534   ReadCallbackTerminator(EventBase* base, WriteCallbackBase *wcb)
1535       : ReadCallback(wcb)
1536       , base_(base) {}
1537
1538   // Do not write data back, terminate the loop.
1539   void readDataAvailable(size_t len) noexcept override {
1540     std::cerr << "readDataAvailable, len " << len << std::endl;
1541
1542     currentBuffer.length = len;
1543
1544     buffers.push_back(currentBuffer);
1545     currentBuffer.reset();
1546     state = STATE_SUCCEEDED;
1547
1548     socket_->setReadCB(nullptr);
1549     base_->terminateLoopSoon();
1550   }
1551  private:
1552   EventBase* base_;
1553 };
1554
1555
1556 /**
1557  * Test a full unencrypted codepath
1558  */
1559 TEST(AsyncSSLSocketTest, UnencryptedTest) {
1560   EventBase base;
1561
1562   auto clientCtx = std::make_shared<folly::SSLContext>();
1563   auto serverCtx = std::make_shared<folly::SSLContext>();
1564   int fds[2];
1565   getfds(fds);
1566   getctx(clientCtx, serverCtx);
1567   auto client = AsyncSSLSocket::newSocket(
1568                   clientCtx, &base, fds[0], false, true);
1569   auto server = AsyncSSLSocket::newSocket(
1570                   serverCtx, &base, fds[1], true, true);
1571
1572   ReadCallbackTerminator readCallback(&base, nullptr);
1573   server->setReadCB(&readCallback);
1574   readCallback.setSocket(server);
1575
1576   uint8_t buf[128];
1577   memset(buf, 'a', sizeof(buf));
1578   client->write(nullptr, buf, sizeof(buf));
1579
1580   // Check that bytes are unencrypted
1581   char c;
1582   EXPECT_EQ(1, recv(fds[1], &c, 1, MSG_PEEK));
1583   EXPECT_EQ('a', c);
1584
1585   EventBaseAborter eba(&base, 3000);
1586   base.loop();
1587
1588   EXPECT_EQ(1, readCallback.buffers.size());
1589   EXPECT_EQ(AsyncSSLSocket::STATE_UNENCRYPTED, client->getSSLState());
1590
1591   server->setReadCB(&readCallback);
1592
1593   // Unencrypted
1594   server->sslAccept(nullptr);
1595   client->sslConn(nullptr);
1596
1597   // Do NOT wait for handshake, writing should be queued and happen after
1598
1599   client->write(nullptr, buf, sizeof(buf));
1600
1601   // Check that bytes are *not* unencrypted
1602   char c2;
1603   EXPECT_EQ(1, recv(fds[1], &c2, 1, MSG_PEEK));
1604   EXPECT_NE('a', c2);
1605
1606
1607   base.loop();
1608
1609   EXPECT_EQ(2, readCallback.buffers.size());
1610   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
1611 }
1612
1613 TEST(AsyncSSLSocketTest, ConnResetErrorString) {
1614   // Start listening on a local port
1615   WriteCallbackBase writeCallback;
1616   WriteErrorCallback readCallback(&writeCallback);
1617   HandshakeCallback handshakeCallback(&readCallback,
1618                                       HandshakeCallback::EXPECT_ERROR);
1619   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1620   TestSSLServer server(&acceptCallback);
1621
1622   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1623   socket->open();
1624   uint8_t buf[3] = {0x16, 0x03, 0x01};
1625   socket->write(buf, sizeof(buf));
1626   socket->closeWithReset();
1627
1628   handshakeCallback.waitForHandshake();
1629   EXPECT_NE(
1630       handshakeCallback.errorString_.find("Network error"), std::string::npos);
1631   EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
1632 }
1633
1634 TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
1635   // Start listening on a local port
1636   WriteCallbackBase writeCallback;
1637   WriteErrorCallback readCallback(&writeCallback);
1638   HandshakeCallback handshakeCallback(&readCallback,
1639                                       HandshakeCallback::EXPECT_ERROR);
1640   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1641   TestSSLServer server(&acceptCallback);
1642
1643   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1644   socket->open();
1645   uint8_t buf[3] = {0x16, 0x03, 0x01};
1646   socket->write(buf, sizeof(buf));
1647   socket->close();
1648
1649   handshakeCallback.waitForHandshake();
1650   EXPECT_NE(
1651       handshakeCallback.errorString_.find("Connection EOF"), std::string::npos);
1652   EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
1653 }
1654
1655 TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
1656   // Start listening on a local port
1657   WriteCallbackBase writeCallback;
1658   WriteErrorCallback readCallback(&writeCallback);
1659   HandshakeCallback handshakeCallback(&readCallback,
1660                                       HandshakeCallback::EXPECT_ERROR);
1661   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1662   TestSSLServer server(&acceptCallback);
1663
1664   auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
1665   socket->open();
1666   uint8_t buf[256] = {0x16, 0x03};
1667   memset(buf + 2, 'a', sizeof(buf) - 2);
1668   socket->write(buf, sizeof(buf));
1669   socket->close();
1670
1671   handshakeCallback.waitForHandshake();
1672   EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
1673             std::string::npos);
1674 #if defined(OPENSSL_IS_BORINGSSL)
1675   EXPECT_NE(
1676       handshakeCallback.errorString_.find("ENCRYPTED_LENGTH_TOO_LONG"),
1677       std::string::npos);
1678 #else
1679   EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
1680             std::string::npos);
1681 #endif
1682 }
1683
1684 TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
1685   using folly::ssl::OpenSSLUtils;
1686   EXPECT_EQ(
1687       OpenSSLUtils::getCipherName(0xc02c), "ECDHE-ECDSA-AES256-GCM-SHA384");
1688   // TLS_DHE_RSA_WITH_DES_CBC_SHA - We shouldn't be building with this
1689   EXPECT_EQ(OpenSSLUtils::getCipherName(0x0015), "");
1690   // This indicates TLS_EMPTY_RENEGOTIATION_INFO_SCSV, no name expected
1691   EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
1692 }
1693
1694 #if FOLLY_ALLOW_TFO
1695
1696 class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
1697  public:
1698   using UniquePtr = std::unique_ptr<MockAsyncTFOSSLSocket, Destructor>;
1699
1700   explicit MockAsyncTFOSSLSocket(
1701       std::shared_ptr<folly::SSLContext> sslCtx,
1702       EventBase* evb)
1703       : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
1704
1705   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
1706 };
1707
1708 /**
1709  * Test connecting to, writing to, reading from, and closing the
1710  * connection to the SSL server with TFO.
1711  */
1712 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFO) {
1713   // Start listening on a local port
1714   WriteCallbackBase writeCallback;
1715   ReadCallback readCallback(&writeCallback);
1716   HandshakeCallback handshakeCallback(&readCallback);
1717   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1718   TestSSLServer server(&acceptCallback, true);
1719
1720   // Set up SSL context.
1721   auto sslContext = std::make_shared<SSLContext>();
1722
1723   // connect
1724   auto socket =
1725       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1726   socket->enableTFO();
1727   socket->open();
1728
1729   // write()
1730   std::array<uint8_t, 128> buf;
1731   memset(buf.data(), 'a', buf.size());
1732   socket->write(buf.data(), buf.size());
1733
1734   // read()
1735   std::array<uint8_t, 128> readbuf;
1736   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1737   EXPECT_EQ(bytesRead, 128);
1738   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1739
1740   // close()
1741   socket->close();
1742 }
1743
1744 /**
1745  * Test connecting to, writing to, reading from, and closing the
1746  * connection to the SSL server with TFO.
1747  */
1748 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOWithTFOServerDisabled) {
1749   // Start listening on a local port
1750   WriteCallbackBase writeCallback;
1751   ReadCallback readCallback(&writeCallback);
1752   HandshakeCallback handshakeCallback(&readCallback);
1753   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1754   TestSSLServer server(&acceptCallback, false);
1755
1756   // Set up SSL context.
1757   auto sslContext = std::make_shared<SSLContext>();
1758
1759   // connect
1760   auto socket =
1761       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1762   socket->enableTFO();
1763   socket->open();
1764
1765   // write()
1766   std::array<uint8_t, 128> buf;
1767   memset(buf.data(), 'a', buf.size());
1768   socket->write(buf.data(), buf.size());
1769
1770   // read()
1771   std::array<uint8_t, 128> readbuf;
1772   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
1773   EXPECT_EQ(bytesRead, 128);
1774   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1775
1776   // close()
1777   socket->close();
1778 }
1779
1780 class ConnCallback : public AsyncSocket::ConnectCallback {
1781  public:
1782   void connectSuccess() noexcept override {
1783     state = State::SUCCESS;
1784   }
1785
1786   void connectErr(const AsyncSocketException& ex) noexcept override {
1787     state = State::ERROR;
1788     error = ex.what();
1789   }
1790
1791   enum class State { WAITING, SUCCESS, ERROR };
1792
1793   State state{State::WAITING};
1794   std::string error;
1795 };
1796
1797 template <class Cardinality>
1798 MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
1799     EventBase* evb,
1800     const SocketAddress& address,
1801     Cardinality cardinality) {
1802   // Set up SSL context.
1803   auto sslContext = std::make_shared<SSLContext>();
1804
1805   // connect
1806   auto socket = MockAsyncTFOSSLSocket::UniquePtr(
1807       new MockAsyncTFOSSLSocket(sslContext, evb));
1808   socket->enableTFO();
1809
1810   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
1811       .Times(cardinality)
1812       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
1813         sockaddr_storage addr;
1814         auto len = address.getAddress(&addr);
1815         return connect(fd, (const struct sockaddr*)&addr, len);
1816       }));
1817   return socket;
1818 }
1819
1820 TEST(AsyncSSLSocketTest, ConnectWriteReadCloseTFOFallback) {
1821   // Start listening on a local port
1822   WriteCallbackBase writeCallback;
1823   ReadCallback readCallback(&writeCallback);
1824   HandshakeCallback handshakeCallback(&readCallback);
1825   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
1826   TestSSLServer server(&acceptCallback, true);
1827
1828   EventBase evb;
1829
1830   auto socket = setupSocketWithFallback(&evb, server.getAddress(), 1);
1831   ConnCallback ccb;
1832   socket->connect(&ccb, server.getAddress(), 30);
1833
1834   evb.loop();
1835   EXPECT_EQ(ConnCallback::State::SUCCESS, ccb.state);
1836
1837   evb.runInEventBaseThread([&] { socket->detachEventBase(); });
1838   evb.loop();
1839
1840   BlockingSocket sock(std::move(socket));
1841   // write()
1842   std::array<uint8_t, 128> buf;
1843   memset(buf.data(), 'a', buf.size());
1844   sock.write(buf.data(), buf.size());
1845
1846   // read()
1847   std::array<uint8_t, 128> readbuf;
1848   uint32_t bytesRead = sock.readAll(readbuf.data(), readbuf.size());
1849   EXPECT_EQ(bytesRead, 128);
1850   EXPECT_EQ(memcmp(buf.data(), readbuf.data(), bytesRead), 0);
1851
1852   // close()
1853   sock.close();
1854 }
1855
1856 #if !defined(OPENSSL_IS_BORINGSSL)
1857 TEST(AsyncSSLSocketTest, ConnectTFOTimeout) {
1858   // Start listening on a local port
1859   ConnectTimeoutCallback acceptCallback;
1860   TestSSLServer server(&acceptCallback, true);
1861
1862   // Set up SSL context.
1863   auto sslContext = std::make_shared<SSLContext>();
1864
1865   // connect
1866   auto socket =
1867       std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
1868   socket->enableTFO();
1869   EXPECT_THROW(
1870       socket->open(std::chrono::milliseconds(20)), AsyncSocketException);
1871 }
1872 #endif
1873
1874 #if !defined(OPENSSL_IS_BORINGSSL)
1875 TEST(AsyncSSLSocketTest, ConnectTFOFallbackTimeout) {
1876   // Start listening on a local port
1877   ConnectTimeoutCallback acceptCallback;
1878   TestSSLServer server(&acceptCallback, true);
1879
1880   EventBase evb;
1881
1882   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1883   ConnCallback ccb;
1884   // Set a short timeout
1885   socket->connect(&ccb, server.getAddress(), 1);
1886
1887   evb.loop();
1888   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1889 }
1890 #endif
1891
1892 TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
1893   // Start listening on a local port
1894   EmptyReadCallback readCallback;
1895   HandshakeCallback handshakeCallback(
1896       &readCallback, HandshakeCallback::EXPECT_ERROR);
1897   HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
1898   TestSSLServer server(&acceptCallback, true);
1899
1900   EventBase evb;
1901
1902   auto socket = setupSocketWithFallback(&evb, server.getAddress(), AtMost(1));
1903   ConnCallback ccb;
1904   socket->connect(&ccb, server.getAddress(), 100);
1905
1906   evb.loop();
1907   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1908   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
1909 }
1910
1911 TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
1912   // Start listening on a local port
1913   EventBase evb;
1914
1915   // Hopefully nothing is listening on this address
1916   SocketAddress addr("127.0.0.1", 65535);
1917   auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
1918   ConnCallback ccb;
1919   socket->connect(&ccb, addr, 100);
1920
1921   evb.loop();
1922   EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
1923   EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
1924 }
1925
1926 TEST(AsyncSSLSocketTest, TestPreReceivedData) {
1927   EventBase clientEventBase;
1928   EventBase serverEventBase;
1929   auto clientCtx = std::make_shared<SSLContext>();
1930   auto dfServerCtx = std::make_shared<SSLContext>();
1931   std::array<int, 2> fds;
1932   getfds(fds.data());
1933   getctx(clientCtx, dfServerCtx);
1934
1935   AsyncSSLSocket::UniquePtr clientSockPtr(
1936       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
1937   AsyncSSLSocket::UniquePtr serverSockPtr(
1938       new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
1939   auto clientSock = clientSockPtr.get();
1940   auto serverSock = serverSockPtr.get();
1941   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
1942
1943   // Steal some data from the server.
1944   clientEventBase.loopOnce();
1945   std::array<uint8_t, 10> buf;
1946   recv(fds[1], buf.data(), buf.size(), 0);
1947
1948   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
1949   SSLHandshakeServer server(std::move(serverSockPtr), true, true);
1950   while (!client.handshakeSuccess_ && !client.handshakeError_) {
1951     serverEventBase.loopOnce();
1952     clientEventBase.loopOnce();
1953   }
1954
1955   EXPECT_TRUE(client.handshakeSuccess_);
1956   EXPECT_TRUE(server.handshakeSuccess_);
1957   EXPECT_EQ(
1958       serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
1959 }
1960
1961 TEST(AsyncSSLSocketTest, TestMoveFromAsyncSocket) {
1962   EventBase clientEventBase;
1963   EventBase serverEventBase;
1964   auto clientCtx = std::make_shared<SSLContext>();
1965   auto dfServerCtx = std::make_shared<SSLContext>();
1966   std::array<int, 2> fds;
1967   getfds(fds.data());
1968   getctx(clientCtx, dfServerCtx);
1969
1970   AsyncSSLSocket::UniquePtr clientSockPtr(
1971       new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
1972   AsyncSocket::UniquePtr serverSockPtr(
1973       new AsyncSocket(&serverEventBase, fds[1]));
1974   auto clientSock = clientSockPtr.get();
1975   auto serverSock = serverSockPtr.get();
1976   SSLHandshakeClient client(std::move(clientSockPtr), true, true);
1977
1978   // Steal some data from the server.
1979   clientEventBase.loopOnce();
1980   std::array<uint8_t, 10> buf;
1981   recv(fds[1], buf.data(), buf.size(), 0);
1982   serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
1983   AsyncSSLSocket::UniquePtr serverSSLSockPtr(
1984       new AsyncSSLSocket(dfServerCtx, std::move(serverSockPtr), true));
1985   auto serverSSLSock = serverSSLSockPtr.get();
1986   SSLHandshakeServer server(std::move(serverSSLSockPtr), true, true);
1987   while (!client.handshakeSuccess_ && !client.handshakeError_) {
1988     serverEventBase.loopOnce();
1989     clientEventBase.loopOnce();
1990   }
1991
1992   EXPECT_TRUE(client.handshakeSuccess_);
1993   EXPECT_TRUE(server.handshakeSuccess_);
1994   EXPECT_EQ(
1995       serverSSLSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
1996 }
1997
1998 /**
1999  * Test overriding the flags passed to "sendmsg()" system call,
2000  * and verifying that write requests fail properly.
2001  */
2002 TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
2003   // Start listening on a local port
2004   SendMsgFlagsCallback msgCallback;
2005   ExpectWriteErrorCallback writeCallback(&msgCallback);
2006   ReadCallback readCallback(&writeCallback);
2007   HandshakeCallback handshakeCallback(&readCallback);
2008   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2009   TestSSLServer server(&acceptCallback);
2010
2011   // Set up SSL context.
2012   auto sslContext = std::make_shared<SSLContext>();
2013   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2014
2015   // connect
2016   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2017                                                  sslContext);
2018   socket->open();
2019
2020   // Setting flags to "-1" to trigger "Invalid argument" error
2021   // on attempt to use this flags in sendmsg() system call.
2022   msgCallback.resetFlags(-1);
2023
2024   // write()
2025   std::vector<uint8_t> buf(128, 'a');
2026   ASSERT_EQ(socket->write(buf.data(), buf.size()), buf.size());
2027
2028   // close()
2029   socket->close();
2030
2031   cerr << "SendMsgParamsCallback test completed" << endl;
2032 }
2033
2034 #ifdef MSG_ERRQUEUE
2035 /**
2036  * Test connecting to, writing to, reading from, and closing the
2037  * connection to the SSL server.
2038  */
2039 TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
2040   // This test requires Linux kernel v4.6 or later
2041   struct utsname s_uname;
2042   memset(&s_uname, 0, sizeof(s_uname));
2043   ASSERT_EQ(uname(&s_uname), 0);
2044   int major, minor;
2045   folly::StringPiece extra;
2046   if (folly::split<false>(
2047         '.', std::string(s_uname.release) + ".", major, minor, extra)) {
2048     if (major < 4 || (major == 4 && minor < 6)) {
2049       LOG(INFO) << "Kernel version: 4.6 and newer required for this test ("
2050                 << "kernel ver. " << s_uname.release << " detected).";
2051       return;
2052     }
2053   }
2054
2055   // Start listening on a local port
2056   SendMsgDataCallback msgCallback;
2057   WriteCheckTimestampCallback writeCallback(&msgCallback);
2058   ReadCallback readCallback(&writeCallback);
2059   HandshakeCallback handshakeCallback(&readCallback);
2060   SSLServerAcceptCallback acceptCallback(&handshakeCallback);
2061   TestSSLServer server(&acceptCallback);
2062
2063   // Set up SSL context.
2064   auto sslContext = std::make_shared<SSLContext>();
2065   sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
2066
2067   // connect
2068   auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
2069                                                  sslContext);
2070   socket->open();
2071
2072   // Adding MSG_EOR flag to the message flags - it'll trigger
2073   // timestamp generation for the last byte of the message.
2074   msgCallback.resetFlags(MSG_DONTWAIT|MSG_NOSIGNAL|MSG_EOR);
2075
2076   // Init ancillary data buffer to trigger timestamp notification
2077   union {
2078     uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
2079     struct cmsghdr cmsg;
2080   } u;
2081   u.cmsg.cmsg_level = SOL_SOCKET;
2082   u.cmsg.cmsg_type = SO_TIMESTAMPING;
2083   u.cmsg.cmsg_len = CMSG_LEN(sizeof(uint32_t));
2084   uint32_t flags =
2085       SOF_TIMESTAMPING_TX_SCHED |
2086       SOF_TIMESTAMPING_TX_SOFTWARE |
2087       SOF_TIMESTAMPING_TX_ACK;
2088   memcpy(CMSG_DATA(&u.cmsg), &flags, sizeof(uint32_t));
2089   std::vector<char> ctrl(CMSG_LEN(sizeof(uint32_t)));
2090   memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
2091   msgCallback.resetData(std::move(ctrl));
2092
2093   // write()
2094   std::vector<uint8_t> buf(128, 'a');
2095   socket->write(buf.data(), buf.size());
2096
2097   // read()
2098   std::vector<uint8_t> readbuf(buf.size());
2099   uint32_t bytesRead = socket->readAll(readbuf.data(), readbuf.size());
2100   EXPECT_EQ(bytesRead, buf.size());
2101   EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
2102
2103   writeCallback.checkForTimestampNotifications();
2104
2105   // close()
2106   socket->close();
2107
2108   cerr << "SendMsgDataCallback test completed" << endl;
2109 }
2110 #endif // MSG_ERRQUEUE
2111
2112 #endif
2113
2114 } // namespace
2115
2116 #ifdef SIGPIPE
2117 ///////////////////////////////////////////////////////////////////////////
2118 // init_unit_test_suite
2119 ///////////////////////////////////////////////////////////////////////////
2120 namespace {
2121 struct Initializer {
2122   Initializer() {
2123     signal(SIGPIPE, SIG_IGN);
2124   }
2125 };
2126 Initializer initializer;
2127 } // anonymous
2128 #endif